pytorch的python API略读--tensor(三)

2022-07-04 14:00:01 浏览数 (1)

2.1.2 索引

筛选出符合某种条件的subtensor。

torch.where: 根据布尔变量的值选择tensor中的元素,用法如下:

代码语言:javascript复制
torch.where(condition, x, y)

下面举个简单的例子:

代码语言:javascript复制
>>> import torch
>>> cvtutorials = torch.randn(3, 4)
>>> threshold = torch.zeros(3, 4)
>>> cvtutorials
tensor([[-1.6981,  1.0443,  2.7922, -0.8736],
        [-2.0208, -0.4815, -0.1488, -0.9714],
        [ 1.1035,  0.4089,  0.6279,  2.4600]])
>>> torch.where(cvtutorials > 0, cvtutorials, threshold)
tensor([[0.0000, 1.0443, 2.7922, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000],
        [1.1035, 0.4089, 0.6279, 2.4600]])

上面torch.where函数返回tensor的某个元素的值遵循这样的选择:如果cvtutorials中的某个元素大于0,那么保留,否则设置为0,用数学公式表达如下:

torch.index_select: 沿着某个维度,通过index对输入tensor进行筛选。用法如下:

代码语言:javascript复制
torch.index_select(input, dim, index, *, out=None)

举个例子说明下:

代码语言:javascript复制
>>> cvtutorials = torch.randn(2,3)
>>> cvtutorials
tensor([[-0.9935, -0.9802, -0.6104],
        [ 2.6251, -1.0099,  0.4752]])
>>> indices = torch.tensor([0, 1])
>>> torch.index_select(cvtutorials, 0, indices)
tensor([[-0.9935, -0.9802, -0.6104],
        [ 2.6251, -1.0099,  0.4752]])
>>> torch.index_select(cvtutorials, 1, indices)
tensor([[-0.9935, -0.9802],
        [ 2.6251, -1.0099]])

torch.masked_select: 根据设置的mask,返回一个一维的tensor(向量)。用法如下:

代码语言:javascript复制
torch.masked_select(input, mask, *, out=None)

举个简单的例子:

代码语言:javascript复制
>>> cvtutorials = torch.randn(2, 3)
>>> cvtutorials
tensor([[ 1.1016, -1.5259,  1.1065],
        [ 0.4838, -0.5521,  0.1556]])
>>> mask = torch.tensor([[False, True, True], [True, False, False]])
>>> mask
tensor([[False,  True,  True],
        [ True, False, False]])
>>> torch.masked_select(cvtutorials, mask)
tensor([-1.5259,  1.1065,  0.4838])

从中可以看出,根据mask对输入tensor相应位置的元素进行筛选,mask某位置为True,则取出tensor相应位置的元素,否则,不取出。

还有一点,mask的shape不一定和tensor一样,但是需要broadcast到tensor上,例如:

代码语言:javascript复制
>>> cvtutorials = torch.randn(2, 3)
>>> cvtutorials
tensor([[ 0.8686,  0.0910,  1.8702],
        [ 1.8140, -1.0902,  0.7051]])
>>> mask = torch.tensor([[False, True, True]])
>>> mask
tensor([[False,  True,  True]])
>>> torch.masked_select(cvtutorials, mask)
tensor([ 0.0910,  1.8702, -1.0902,  0.7051])

0 人点赞