gather
torch.gather(*input,dim,index,sparse_grad=False, out=None*)
函数沿着指定的轴 dim 上的索引 index 采集输入张量 input 中的元素值,函数的参数有:
- input (Tensor) - 输入张量
- dim (int) - 需要进行索引的轴
- index (LongTensor) - 要采集元素的索引
- sparse_grad (bool, optional) - 如果为 True,输入张量 input 会变成离散张量
- out (Tensor, optional) - 指定输出的张量。比如执行 torch.zeros(2, 2, out = tensor_a),相当于执行 tensor_a = torch.zeros(2, 2)
除了 sparse_grad 和 out 两个可选参数,其余三个参数都是必选参数。为了方便这里只考虑必选参数,即 torch.gather(input, dim, index)。
简单介绍完 gather 函数之后,来看一个简单的小例子:一次将下面 2D 张量中所有红色的元素采集出来。
2D 张量可以看成矩阵,2D 张量的第一个维度为矩阵的行 (dim = 0),2D 张量的第二个维度为矩阵的列 (dim = 1),从左向右依次看三个红色元素在矩阵中的具体位置:
- 6: 第 2 行的第 0 列
- 1: 第 0 行的第 1 列
- 5: 第 1 行的第 2 列
通过红色元素的具体位置可以看出,三个红色元素的列索引号是有规律的:从 0 到 2 逐渐递增。假设此时列索引的规律是已知并且固定的,我们只需要给出这些红色元素在行上的索引号就可以将这些红色元素全部采集出来。
至此,对于这个 2D 张量的小例子,已知了输入张量和指定行上的索引号。回顾 torch.gather(input, dim, index) 函数沿着指定轴上的索引采集输入张量的元素值,貌似现在已知的条件和 gather 函数中所需要的参数有些谋和。下面我们来尝试一下使用 gather 函数来采集红色元素。
代码语言:txt复制>>> import torch
>>> x = torch.arange(9).view(3, 3)
>>> print(x)
tensor([[0, 1, 2],
[3, 4, 5],
[6, 7, 8]])
>>> index = torch.tensor([[2, 0, 1]])
>>> # dim=0: 行上的索引
>>> out = torch.gather(x, dim = 0, index = index)
>>> print(out)
tensor([[6, 1, 5]])
gather 函数的输出结果和我们在小例子中分析的结果一致。
如果按照从上到下来看三个红色元素,采集元素的顺序和从前面从左向右看的时候不同,此时采集元素的顺序为 1, 5, 6,现在看看此时这三个红色元素在矩阵中的具体位置:
- 1: 第 0 行的第 1 列
- 5: 第 1 行的第 2 列
- 6: 第 2 行的第 0 列
现在行索引号是有规律的:从 0 到 2 逐渐递增。现在假设此时行索引的规律是已知并且固定的,我们只需要给出这些红色元素在列上的索引号就可以将这些红色元素全部采集出来了。
代码语言:txt复制>>> import torch
>>> x = torch.arange(9).view(3, 3)
>>> print(x)
tensor([[0, 1, 2],
[3, 4, 5],
[6, 7, 8]])
>>> index = torch.tensor([[1, 2, 0]]).t()
>>> # dim=1: 在列方向上索引
>>> out = torch.gather(x, dim = 1, index = index)
>>> print(out)
tensor([[1],
[5],
[6]])
在不同轴上 (行或列) 进行索引传入的 index 参数的张量形状不同,在 gather 函数中规定:
- 传入 index 的张量维度数要和输入张量 input 的维度数相同;
- 输出张量的形状和传入 index 张量的形状相同;
- 如果沿着轴的每个维度采集 N 个元素,则 index 对应轴上的长度为 N (N ≥ 1)。比如对于前面的 2D 张量,对行索引且每一行只采集一个元素,则 index 在行上的长度为 1,index 的形状即为 (1 x 3);
接下来使用一个形状为 (3 x 5) 2D 张量来详细的分析 gather 函数的原理。
2D 张量有两个轴,假定现在只采集一个元素:
- dim = 0
dim = 0 表示在行上索引,此时假定已知且固定了在列上的索引,即 (其中 ? 为待采集元素在行上的索引号):
- 在 ? 行的第 0 列
- 在 ? 行的第 1 列
- 在 ? 行的第 2 列
- 在 ? 行的第 3 列
- 在 ? 行的第 4 列
如果想要使用 gather 函数采集元素,需要在 index 中指定 5 个行索引号,而每列只索引一个元素且在行上索引 (dim = 0),因此最终我们需要传入 index 张量的形状为 (1, 5),其中的元素值为待采集元素的行索引号。
- dim = 1
dim = 1 表示在列上索引,此时假定已知且固定了在行上的索引,即 (其中 ? 为待采集元素在列上的索引号):
- 在 0 行的第 ? 列
- 在 1 行的第 ? 列
- 在 2 行的第 ? 列
如果想要使用 gather 函数采集元素,需要在 index 中指定 3 个列索引号,而每行只索引一个元素且在列上索引 (dim = 1),因此最终我们需要传入 index 张量的形状为 (1, 3),其中的元素值为待采集元素的列索引号。
最后来看看如何使用 gather 函数每行采集两个元素:
代码语言:txt复制>>> import torch
>>> x = torch.arange(15).view(3, 5)
>>> index = torch.LongTensor([[0, 1], [2, 3], [1, 2]])
>>> out = torch.gather(x, dim = 1, index = index)
>>> print(out)
tensor([[ 0, 1],
[ 7, 8],
[11, 12]])
传入 index 的张量形状为 (3 x 2),因此最终输出张量的形状也为 (3 x 2)。dim = 1 表示在列上索引,此时假定已知且固定了在行上的索引:
- 在 0 行的第 0 列,在 0 行的第 1 列
- 在 1 行的第 2 列,在 1 行的第 3 列
- 在 2 行的第 1 列,在 2 行的第 2 列