加、减、乘、除
加、减、乘、除是最基本的数学运算,分别通过 torch.add
、torch.sub
、torch.mul
和 torch.div
函数实现,Pytorch 已经重载了 、-、* 和 / 运算符。
import torch
x = torch.ones(2, 2)
y = torch.arange(4).reshape(2, 2)
# add加法
print(torch.add(x, y))
# tensor([[1., 2.],
# [3., 4.]])
print(x y)
# tensor([[1., 2.],
# [3., 4.]])
# subtraction减法
print(torch.sub(x, y))
# tensor([[ 1., 0.],
# [-1., -2.]])
print(x - y)
# tensor([[ 1., 0.],
# [-1., -2.]])
# multiplication乘法
print(torch.mul(x, y))
# tensor([[0., 1.],
# [2., 3.]])
print(x * y)
# tensor([[0., 1.],
# [2., 3.]])
# division除法
print(torch.div(x, y))
# tensor([[ inf, 1.0000],
# [0.5000, 0.3333]])
print(x / y)
# tensor([[ inf, 1.0000],
# [0.5000, 0.3333]])
这里需要注意,张量 y 的第一个元素为 0,而在 x 和 y 进行除法运算时,y 中的 0 作为了除数。在 PyTorch 中,除数为 0 时程序并不会报错,而是的等于 inf。
这些加、减、乘、除基本的数学运算在 PyTorch 中的实现都比较简单,但是在使用过程中还是需要注意以下几点(下面都以乘法为例,其余三种运算同理):
- 参与基本数学运算的张量必须形状一致,或者可以通过广播机制扩展到相同的形状;
import torch
x = torch.ones(1, 2)
y = torch.arange(4).reshape(2, 2)
# 此时的x通过广播机制形状变成(2, 2)
print(x * y)
# tensor([[0., 1.],
# [2., 3.]])
# 此时将张量y的形状变成(1, 4)
y = y.reshape(1, 4)
# 此时x和y不满足广播机制
print(x * y)
'''
Traceback (most recent call last):
File "/home/chenkc/code/pytorch/test01.py", line 224, in <module>
print(x * y)
RuntimeError: The size of tensor a (2) must match the size of tensor b (4) at non-singleton dimension 1
'''
- 基本的数学运算与 NumPy 一样,都是 Element-Wise(逐元素运算),因此
torch.mul
实现的并不是张量乘法(两个张量相乘后的张量形状遵循:中间相等取两头的规则),而是相乘张量中对应位置的元素相乘;
import torch
x = torch.ones(2, 2)
y = torch.arange(4).reshape(2, 2)
# 逐元素相乘
print(x * y)
# tensor([[0., 1.],
# [2., 3.]])
# 矩阵乘法
# 矩阵相乘需要保证张量中元素一致
y = y.float()
print(torch.matmul(x, y))
# tensor([[2., 4.],
# [2., 4.]])
- 基本的数学运算支持两种接口,换句话说,可以使用
tensor.add
、tensor.sub
、tensor.mul
和tensor.div
;
import torch
x = torch.ones(2, 2)
y = torch.arange(4).reshape(2, 2)
# 逐元素相乘
print(x * y)
# tensor([[0., 1.],
# [2., 3.]])
print(y.mul(x))
# tensor([[0., 1.],
# [2., 3.]])
- 基本的数学运算也支持原地操作(in-place operation);
import torch
x = torch.ones(2, 2)
y = torch.arange(4).reshape(2, 2)
# 逐元素相乘
print(y.mul(x))
# tensor([[0., 1.],
# [2., 3.]])
print(y) # 张量y没有改变
# tensor([[0, 1],
# [2, 3]])
y = y.float()
print(y.mul_(x))
# tensor([[0., 1.],
# [2., 3.]])
print(y) # 张量y = x * y
# tensor([[0., 1.],
# [2., 3.]])
类型陷阱
本小节我们一共使用了 2 次 y = y.float
,第一次在第 2 点演示矩阵乘法(torch.matmul(x, y))之前,第二次在第 4 点演示原地操作(y.mul_(x))之前。这是因为生成张量 x 和 y 的类型不一致,当然本小节使用的都是 torch.arange
函数生成张量 y,这也是为了说明类型陷阱的问题。
import torch
x = torch.ones(2, 2)
y = torch.arange(4).reshape(2, 2)
print(x.dtype)
# torch.float32
print(y.dtype)
# torch.int64
虽然加减乘除基本运算对张量的类型没有要求,但是有一些运算操作对运算的张量类型还是比较敏感的。
- 矩阵乘法要求相乘的张量类型一致;
- 原地操作由于将运算后的张量赋值给原始张量,但是如果运算后的张量和原始张量的类型不一样,也会抛出错误。比如张量 y 为
torch.int64
,x * y
后的张量为torch.float32
类型,将torch.float32
类型的张量赋值给torch.int64
的张量 y,程序会抛出错误;