版权声明:本文为博主原创文章,未经博主允许不得转载。 https://cloud.tencent.com/developer/article/1433787
torch.gather(input, dim, index, out=None) 和 torch.scatter_(dim, index, src)是一对作用相反的方法
先来看torch.gather, 核心操作其实就是这样:
outik = inputindex[i][j][k]k # if dim == 0
outik = inputiindexik]k # if dim == 1
outik = inputi[indexik] # if dim == 2
是对于out指定位置上的值,去寻找input里面对应的索引位置,根据是index
官方文档给的例子是:
代码语言:javascript复制>>> t = torch.Tensor([[1,2],[3,4]])
>>> torch.gather(t, 1, torch.LongTensor([[0,0],[1,0]]))
1 1
4 3
[torch.FloatTensor of size 2x2]
具体过程就是这里的input = [1,2,3,4], index = [0,0,1,0], dim = 1, 则
out0 = input0 index0 ] = input0 = 1
out0 = input0 index0 ] = input0 = 1
out1 = input1 index1 ] = input1 = 4
out1 = input1 index1 ] = input1 = 3
torch.scatter_(dim, index, src)
核心操作:
self index[i][j][k] k = srcik # if dim == 0
self i indexik ] k = srcik # if dim == 1
self i [ indexik ] = srcik # if dim == 2
这个就是对于src(或者说input)指定位置上的值,去分配给output对应索引位置,根据是index,所以其实把src放在左边更容易理解,官方给的例子如下:
代码语言:javascript复制x = torch.rand(2, 5)
>>> x
0.4319 0.6500 0.4080 0.8760 0.2355
0.2609 0.4711 0.8486 0.8573 0.1029
[torch.FloatTensor of size 2x5]
>>> torch.zeros(3, 5).scatter_(0, torch.LongTensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x)
0.4319 0.4711 0.8486 0.8760 0.2355
0.0000 0.6500 0.0000 0.8573 0.0000
0.2609 0.0000 0.4080 0.0000 0.1029
[torch.FloatTensor of size 3x5]
此例中,src就是x,index就是[0, 1, 2, 0, 0, 2, 0, 0, 1, 2], dim=0
我们把src写在左边,把self写在右边,这样好理解一些,
但要注意是把src的值赋给self,所以用箭头指过去:
0.4319 = Src0 ----->self index[0][0] ----> self0
0.6500 = Src0 ----->self index[0][1] ----> self1
0.4080 = Src0 ----->self index[0][2] ----> self2
0.8760 = Src0 ----->self index[0][3] ----> self0
0.2355 = Src0 ----->self index[0][4] ----> self0
0.2609 = Src1 ----->self index[1][0] ----> self2
0.4711 = Src1 ----->self index[1][1] ----> self0
0.8486 = Src1 ----->self index[1][2] ----> self0
0.8573 = Src1 ----->self index[1][3] ----> self1
0.1029 = Src1 ----->self index[1][4] ----> self2
则我们把src也就是 x的每个值都成功的分配了出去,然后我们再把self对应位置填好,
剩下的未得到分配的位置,就填0补充。