PyTorch入门笔记-张量相乘matmul函数02

2021-03-16 11:02:32 浏览数 (1)

Matmul 函数

torch.matmul(input, other, out = None) 函数对 input 和 other 两个张量进行矩阵相乘。torch.matmul 函数根据传入参数的张量维度有很多重载函数。为了方便后续的介绍,将传入 input 参数中的张量命名为 a,而传入 other 参数的张量命名为 b。

  • 若 a 为 1D 张量,b 为 2D 张量,torch.matmul 函数:
    • 首先,在 1D 张量 a 的前面插入一个长度为 1 的新维度变成 2D 张量;
    • 然后,在满足第一个 2D 张量(矩阵)的列数(column)和第二个 2D 张量(矩阵)的行数(row)相同的条件下,两个 2D 张量矩阵乘积,否则会抛出错误;
    • 最后,将矩阵乘积结果中长度为 1 的维度(前面插入的长度为 1 的新维度)删除作为最终 torch.matmul 函数返回的结果;
代码语言:javascript复制
import torch

# a为1D张量,b为2D张量
a = torch.tensor([1., 2.])
b = torch.tensor([[5., 6., 7.], [8., 9., 10.]])

result = torch.matmul(a, b)
print(result.size())
# torch.Size([3])

print(result)
# tensor([21., 24., 27.])

  • 若 a 为 2D 张量,b 为 1D 张量,torch.matmul 函数:
    • 首先,在 1D 张量 b 的后面插入一个长度为 1 的新维度变成 2D 张量;
    • 然后,在满足第一个 2D 张量(矩阵)的列数(column)和第二个 2D 张量(矩阵)的行数(row)相同的条件下,两个 2D 张量矩阵乘积,否则会抛出错误;
    • 最后,将矩阵乘积结果中长度为 1 的维度(后面插入的长度为 1 的新维度)删除作为最终 torch.matmul 函数返回的结果;
代码语言:javascript复制
import torch

# a为2D张量,b为1D张量
a = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
b = torch.tensor([7., 8., 9.])

result = torch.matmul(a, b)
print(result.size())
# torch.Size([2])

print(result)
# tensor([50., 122.])

具体细节和 a 为 1D 张量,b 为 2D 张量的情况差不多,只不过,一个在 1D 张量的前面插入长度为 1 的新维度(a 为 1D 张量,b 为 2D 张量),另一个是在 1D 张量的后面插入长度为 1 的新维度(a 为 2D 张量,b 为 1D 张量)。

0 人点赞