PyTorch入门笔记-张量的运算和类型陷阱

2021-03-16 11:00:03 浏览数 (1)

加、减、乘、除

加、减、乘、除是最基本的数学运算,分别通过 torch.addtorch.subtorch.multorch.div 函数实现,Pytorch 已经重载了 、-、* 和 / 运算符。

代码语言:javascript复制
import torch

x = torch.ones(2, 2)
y = torch.arange(4).reshape(2, 2)

# add加法
print(torch.add(x, y))
# tensor([[1., 2.],
#         [3., 4.]])
print(x   y)
# tensor([[1., 2.],
#         [3., 4.]])

# subtraction减法
print(torch.sub(x, y))
# tensor([[ 1.,  0.],
#         [-1., -2.]])
print(x - y)
# tensor([[ 1.,  0.],
#         [-1., -2.]])

# multiplication乘法
print(torch.mul(x, y))
# tensor([[0., 1.],
#         [2., 3.]])
print(x * y)
# tensor([[0., 1.],
#         [2., 3.]])

# division除法
print(torch.div(x, y))
# tensor([[   inf, 1.0000],
#         [0.5000, 0.3333]])
print(x / y)
# tensor([[   inf, 1.0000],
#         [0.5000, 0.3333]])

这里需要注意,张量 y 的第一个元素为 0,而在 x 和 y 进行除法运算时,y 中的 0 作为了除数。在 PyTorch 中,除数为 0 时程序并不会报错,而是的等于 inf。

这些加、减、乘、除基本的数学运算在 PyTorch 中的实现都比较简单,但是在使用过程中还是需要注意以下几点(下面都以乘法为例,其余三种运算同理):

  1. 参与基本数学运算的张量必须形状一致,或者可以通过广播机制扩展到相同的形状;
代码语言:javascript复制
import torch

x = torch.ones(1, 2)
y = torch.arange(4).reshape(2, 2)

# 此时的x通过广播机制形状变成(2, 2)
print(x * y)
# tensor([[0., 1.],
#         [2., 3.]])

# 此时将张量y的形状变成(1, 4)
y = y.reshape(1, 4)
# 此时x和y不满足广播机制
print(x * y)
'''
Traceback (most recent call last):
  File "/home/chenkc/code/pytorch/test01.py", line 224, in <module>
    print(x * y)
RuntimeError: The size of tensor a (2) must match the size of tensor b (4) at non-singleton dimension 1
'''
  1. 基本的数学运算与 NumPy 一样,都是 Element-Wise(逐元素运算),因此 torch.mul 实现的并不是张量乘法(两个张量相乘后的张量形状遵循:中间相等取两头的规则),而是相乘张量中对应位置的元素相乘;
代码语言:javascript复制
import torch

x = torch.ones(2, 2)
y = torch.arange(4).reshape(2, 2)

# 逐元素相乘
print(x * y)
# tensor([[0., 1.],
#         [2., 3.]])

# 矩阵乘法
# 矩阵相乘需要保证张量中元素一致
y = y.float()
print(torch.matmul(x, y))
# tensor([[2., 4.],
#         [2., 4.]])
  1. 基本的数学运算支持两种接口,换句话说,可以使用 tensor.addtensor.subtensor.multensor.div
代码语言:javascript复制
import torch

x = torch.ones(2, 2)
y = torch.arange(4).reshape(2, 2)

# 逐元素相乘
print(x * y)
# tensor([[0., 1.],
#         [2., 3.]])

print(y.mul(x))
# tensor([[0., 1.],
#         [2., 3.]])
  1. 基本的数学运算也支持原地操作(in-place operation);
代码语言:javascript复制
import torch

x = torch.ones(2, 2)
y = torch.arange(4).reshape(2, 2)

# 逐元素相乘
print(y.mul(x))
# tensor([[0., 1.],
#         [2., 3.]])

print(y) # 张量y没有改变
# tensor([[0, 1],
#         [2, 3]])

y = y.float()
print(y.mul_(x))
# tensor([[0., 1.],
#         [2., 3.]])

print(y) # 张量y = x * y
# tensor([[0., 1.],
#         [2., 3.]])

类型陷阱

本小节我们一共使用了 2 次 y = y.float,第一次在第 2 点演示矩阵乘法(torch.matmul(x, y))之前,第二次在第 4 点演示原地操作(y.mul_(x))之前。这是因为生成张量 x 和 y 的类型不一致,当然本小节使用的都是 torch.arange 函数生成张量 y,这也是为了说明类型陷阱的问题。

代码语言:javascript复制
import torch

x = torch.ones(2, 2)
y = torch.arange(4).reshape(2, 2)

print(x.dtype)
# torch.float32

print(y.dtype)
# torch.int64

虽然加减乘除基本运算对张量的类型没有要求,但是有一些运算操作对运算的张量类型还是比较敏感的。

  • 矩阵乘法要求相乘的张量类型一致;
  • 原地操作由于将运算后的张量赋值给原始张量,但是如果运算后的张量和原始张量的类型不一样,也会抛出错误。比如张量 y 为 torch.int64x * y 后的张量为 torch.float32 类型,将 torch.float32 类型的张量赋值给 torch.int64 的张量 y,程序会抛出错误;

0 人点赞