01 笔记分享 : pytorch中max和nonzero使用

2021-05-28 17:29:20 浏览数 (1)

第一部分:torch.max()

1. 官网链接

https://pytorch.org/docs/stable/generated/torch.max.html#torch.max

2. 解释 案例

2.1 torch.max(input) → Tensor

返回input中所有元素中的最大值

案例:

a = torch.randn(1, 3)

a

输出:

tensor([[ 0.0557, -0.7400, -0.8941]])

torch.max(a)

tensor(0.0557)

a.max()

tensor(0.0557)

2.2 torch.max(input, dim, keepdim=False, *, out=None) -> (Tensor, LongTensor)

返回一个namedtuple(values, indices), values是在指定dim下,input中每行的最大值;indices是最大值所在索引。

如果keepdim=True,output的size与input size保持一致(此种情况除外: input的维度为1,即dim=1)

如果keepdim=False,dim会被torch.squeeze缩小/压缩,导致output tensors的维度为1,少于input的维度。

案例1:

a = torch.randn(4,4 ) # 随机生成4行4列的数据

a

输出:

tensor([[ 1.1982, -0.2496, -0.3671, -1.2475], [-1.6641, 0.6409, 0.9440, -0.1829], [ 0.9641, -0.1747, -1.1281, 0.4016], [ 0.3706, 0.8722, -1.1174, -0.5317]])

torch.max(a, dim=0) # 横轴方向

输出:

torch.return_types.max( values=tensor([1.1982, 0.8722, 0.9440, 0.4016]), indices=tensor([0, 3, 1, 2]))

torch.max(a, dim=0, keepdim=True) # 横轴方向,keepdim=True

输出:

torch.return_types.max( values=tensor([[1.1982, 0.8722, 0.9440, 0.4016]]), indices=tensor([[0, 3, 1, 2]]))

torch.max(a, dim=1) # 纵轴方向

输出:

torch.return_types.max( values=tensor([1.1982, 0.9440, 0.9641, 0.8722]), indices=tensor([0, 2, 0, 1]))

torch.max(a, dim=1, keepdim=True) # 纵轴方向,keepdim=True

输出:

torch.return_types.max( values=tensor([[1.1982], [0.9440], [0.9641], [0.8722]]), indices=tensor([[0], [2], [0], [1]]))

第二部分:torch.nonzero()

1. 官网参考链接 :

https://pytorch.org/docs/stable/generated/torch.nonzero.html#torch-nonzero

2. 方法

torch.nonzero(input, *, out=None, as_tuple=False)

-> LongTensor or tuple of LongTensors

返回值: 默认返回一个2-D的tensor,包含非零值的索引。

官网解释:

torch.nonzero(..., as_tuple=False) (default) returns a 2-D tensor where each row is the index for a nonzero value.

torch.nonzero(..., as_tuple=True) returns a tuple of 1-D index tensors, allowing for advanced indexing, so x[x.nonzero(as_tuple=True)] gives all nonzero values of tensor x. Of the returned tuple, each index tensor contains nonzero indices for a certain dimension.

案例1:

torch.nonzero(torch.tensor([1,1,1,0,1]))

输出:

tensor([[0], [1], [2], [4]])

torch.nonzero(torch.tensor([1,1,1,0,1]), as_tuple=True) # 设置as_tuple

输出:

(tensor([0, 1, 2, 4]),)

案例2:

torch.nonzero(torch.tensor([[0.6, 0.0, 0.0, 0.0],

[0.0, 0.0, 0.4, 0.0],

[0.0, 0.0, 1.2, 0.0],

[0.0, 0.0, 0.0,-0.4]]))

输出:

tensor([[0, 0], [1, 2], [2, 2], [3, 3]])

torch.nonzero(torch.tensor([[0.6, 0.0, 0.0, 0.0],

[0.0, 0.0, 0.4, 0.0],

[0.0, 0.0, 1.2, 0.0],

[0.0, 0.0, 0.0,-0.4]]), as_tuple=True) # 设置as_tuple

# 返回值: 指定行索引,包含非零值

输出:

(tensor([0, 1, 2, 3]), tensor([0, 2, 2, 3]))

案例3:

t = torch.tensor([[0.6, 1.5, 2.3, 3.7],

[2.4, 0.0, 0.4, 1.8],

[5.1, 0.0, 1.2, 3.4],

[6.3, 4.8, 0.0,-0.4]])

t

输出:

tensor([[ 0.6000, 1.5000, 2.3000, 3.7000], [ 2.4000, 0.0000, 0.4000, 1.8000], [ 5.1000, 0.0000, 1.2000, 3.4000], [ 6.3000, 4.8000, 0.0000, -0.4000]])

t[:, 2:] > 0

输出:

tensor([[ True, True], [ True, True], [ True, True], [False, False]])

(t[:, 2:] > 0).nonzero(as_tuple=False)

# 返回一个2-D的tensor,包含非零值的索引

输出:

tensor([[0, 0], [0, 1], [1, 0], [1, 1], [2, 0], [2, 1]])

(t[:, 2:] > 0).nonzero(as_tuple=False).T # 转置

输出:

tensor([[0, 0, 1, 1, 2, 2], [0, 1, 0, 1, 0, 1]])

i, j = (t[:, 2:] > 0).nonzero(as_tuple=False).T

输出:

i # 行索引

tensor([0, 0, 1, 1, 2, 2])

j # 列索引

tensor([0, 1, 0, 1, 0, 1])

0 人点赞