torch.gather() 和torch.sactter_()的用法简析

2019-05-25 22:42:06 浏览数 (1)

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://cloud.tencent.com/developer/article/1433787

torch.gather(input, dim, index, out=None)torch.scatter_(dim, index, src)是一对作用相反的方法

先来看torch.gather, 核心操作其实就是这样:

outik = inputindex[i][j][k]k # if dim == 0

outik = inputiindexik]k # if dim == 1

outik = inputi[indexik] # if dim == 2

是对于out指定位置上的值,去寻找input里面对应的索引位置,根据是index

官方文档给的例子是:

代码语言:javascript复制
>>> t = torch.Tensor([[1,2],[3,4]])
>>> torch.gather(t, 1, torch.LongTensor([[0,0],[1,0]]))
 1  1
 4  3
[torch.FloatTensor of size 2x2]

具体过程就是这里的input = [1,2,3,4], index = [0,0,1,0], dim = 1, 则

out0 = input0 index0 ] = input0 = 1

out0 = input0 index0 ] = input0 = 1

out1 = input1 index1 ] = input1 = 4

out1 = input1 index1 ] = input1 = 3

torch.scatter_(dim, index, src)

核心操作:

self index[i][j][k] k = srcik # if dim == 0

self i indexik ] k = srcik # if dim == 1

self i [ indexik ] = srcik # if dim == 2

这个就是对于src(或者说input)指定位置上的值,去分配给output对应索引位置,根据是index,所以其实把src放在左边更容易理解,官方给的例子如下:

代码语言:javascript复制
x = torch.rand(2, 5)
>>> x

 0.4319  0.6500  0.4080  0.8760  0.2355
 0.2609  0.4711  0.8486  0.8573  0.1029
[torch.FloatTensor of size 2x5]

>>> torch.zeros(3, 5).scatter_(0, torch.LongTensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x)

 0.4319  0.4711  0.8486  0.8760  0.2355
 0.0000  0.6500  0.0000  0.8573  0.0000
 0.2609  0.0000  0.4080  0.0000  0.1029
[torch.FloatTensor of size 3x5]

此例中,src就是x,index就是[0, 1, 2, 0, 0, 2, 0, 0, 1, 2], dim=0

我们把src写在左边,把self写在右边,这样好理解一些,

但要注意是把src的值赋给self,所以用箭头指过去:

0.4319 = Src0 ----->self index[0][0] ----> self0

0.6500 = Src0 ----->self index[0][1] ----> self1

0.4080 = Src0 ----->self index[0][2] ----> self2

0.8760 = Src0 ----->self index[0][3] ----> self0

0.2355 = Src0 ----->self index[0][4] ----> self0

0.2609 = Src1 ----->self index[1][0] ----> self2

0.4711 = Src1 ----->self index[1][1] ----> self0

0.8486 = Src1 ----->self index[1][2] ----> self0

0.8573 = Src1 ----->self index[1][3] ----> self1

0.1029 = Src1 ----->self index[1][4] ----> self2

则我们把src也就是 x的每个值都成功的分配了出去,然后我们再把self对应位置填好

剩下的未得到分配的位置,就填0补充。

0 人点赞