PyTorch入门笔记-常见的矩阵乘法

2021-03-16 11:06:14 浏览数 (1)

前言

前文介绍了根据传入参数的张量维度决定其实现功能的 torch.matmul 函数。torch.matmul 函数功能强大,虽然可以使用其重载的运算符 @,但是使用起来比较麻烦,并且在实际使用场景中,常用的矩阵乘积运算就那么几种。为了方便使用这些常用的矩阵乘积运算,PyTorch 提供了一些更为方便的函数。

二维矩阵乘法

神经网络中包含大量的 2D 张量矩阵乘法运算,而使用 torch.matmul 函数比较复杂,因此 PyTorch 提供了更为简单方便的 torch.mm(input, other, out = None) 函数。下表是 torch.matmul 函数和 torch.mm 函数的简单对比。

torch.matmul 函数支持广播,主要指的是当参与矩阵乘积运算的两个张量中其中有一个是 1D 张量,torch.matmul 函数会将其广播成 2D 张量参与运算,最后将广播添加的维度删除作为最终 torch.matmul 函数的返回结果。torch.mm 函数不支持广播,相对应的输入的两个张量必须为 2D。

代码语言:javascript复制
import torch

input = torch.tensor([[1., 2.], [3., 4.]])
other = torch.tensor([[5., 6., 7.], [8., 9., 10.]])

result = torch.mm(input, other)
print(result)
# tensor([[21., 24., 27.],
#         [47., 54., 61.]])

批量矩阵乘法

同理,由于 torch.bmm 函数不支持广播,相对应的输入的两个张量必须为 3D。

代码语言:javascript复制
import torch

input = torch.randn(10, 3, 4)
other = torch.randn(10, 4, 2)

result = torch.bmm(input, other)

print(result.size())
# torch.Size([10, 3, 2])

0 人点赞