前言
前文介绍了根据传入参数的张量维度决定其实现功能的 torch.matmul 函数。torch.matmul 函数功能强大,虽然可以使用其重载的运算符 @
,但是使用起来比较麻烦,并且在实际使用场景中,常用的矩阵乘积运算就那么几种。为了方便使用这些常用的矩阵乘积运算,PyTorch 提供了一些更为方便的函数。
二维矩阵乘法
神经网络中包含大量的 2D 张量矩阵乘法运算,而使用 torch.matmul 函数比较复杂,因此 PyTorch 提供了更为简单方便的 torch.mm(input, other, out = None)
函数。下表是 torch.matmul 函数和 torch.mm 函数的简单对比。
torch.matmul 函数支持广播,主要指的是当参与矩阵乘积运算的两个张量中其中有一个是 1D 张量,torch.matmul 函数会将其广播成 2D 张量参与运算,最后将广播添加的维度删除作为最终 torch.matmul 函数的返回结果。torch.mm 函数不支持广播,相对应的输入的两个张量必须为 2D。
代码语言:javascript复制import torch
input = torch.tensor([[1., 2.], [3., 4.]])
other = torch.tensor([[5., 6., 7.], [8., 9., 10.]])
result = torch.mm(input, other)
print(result)
# tensor([[21., 24., 27.],
# [47., 54., 61.]])
批量矩阵乘法
同理,由于 torch.bmm 函数不支持广播,相对应的输入的两个张量必须为 3D。
代码语言:javascript复制import torch
input = torch.randn(10, 3, 4)
other = torch.randn(10, 4, 2)
result = torch.bmm(input, other)
print(result.size())
# torch.Size([10, 3, 2])