Matmul 函数
torch.matmul(input, other, out = None)
函数对 input 和 other 两个张量进行矩阵相乘。torch.matmul 函数根据传入参数的张量维度有很多重载函数。为了方便后续的介绍,将传入 input 参数中的张量命名为 a,而传入 other 参数的张量命名为 b。
- 若 a 为 1D 张量,b 为 2D 张量,torch.matmul 函数:
- 首先,在 1D 张量 a 的前面插入一个长度为 1 的新维度变成 2D 张量;
- 然后,在满足第一个 2D 张量(矩阵)的列数(column)和第二个 2D 张量(矩阵)的行数(row)相同的条件下,两个 2D 张量矩阵乘积,否则会抛出错误;
- 最后,将矩阵乘积结果中长度为 1 的维度(前面插入的长度为 1 的新维度)删除作为最终 torch.matmul 函数返回的结果;
import torch
# a为1D张量,b为2D张量
a = torch.tensor([1., 2.])
b = torch.tensor([[5., 6., 7.], [8., 9., 10.]])
result = torch.matmul(a, b)
print(result.size())
# torch.Size([3])
print(result)
# tensor([21., 24., 27.])
- 若 a 为 2D 张量,b 为 1D 张量,torch.matmul 函数:
- 首先,在 1D 张量 b 的后面插入一个长度为 1 的新维度变成 2D 张量;
- 然后,在满足第一个 2D 张量(矩阵)的列数(column)和第二个 2D 张量(矩阵)的行数(row)相同的条件下,两个 2D 张量矩阵乘积,否则会抛出错误;
- 最后,将矩阵乘积结果中长度为 1 的维度(后面插入的长度为 1 的新维度)删除作为最终 torch.matmul 函数返回的结果;
import torch
# a为2D张量,b为1D张量
a = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
b = torch.tensor([7., 8., 9.])
result = torch.matmul(a, b)
print(result.size())
# torch.Size([2])
print(result)
# tensor([50., 122.])
具体细节和 a 为 1D 张量,b 为 2D 张量的情况差不多,只不过,一个在 1D 张量的前面插入长度为 1 的新维度(a 为 1D 张量,b 为 2D 张量),另一个是在 1D 张量的后面插入长度为 1 的新维度(a 为 2D 张量,b 为 1D 张量)。