代码语言:javascript复制
import torch
a=torch.randint(-1,2,(10,),dtype=torch.int)
print(a)
print(a.size())
print(torch.nonzero(a))
print(torch.nonzero(a).size())
Output:
-------------------------------------------------------------------------
tensor([ 0, -1, 1, 1, -1, 0, 1, -1, -1, -1], dtype=torch.int32)
torch.Size([10])
tensor([[1],
[2],
[3],
[4],
[6],
[7],
[8],
[9]])
torch.Size([8, 1])
-------------------------------------------------------------------------
也就是说torch.nonezero()的作用就是找到tensor中所有不为0的索引。(要注意返回值的size)