nonzero
前面已经介绍了 index_select 和 mask_select 两个选择函数,这两个函数通过一定的索引规则从输入张量中筛选出满足条件的元素值,只不过 index_select 函数使用索引 index 的索引规则,而 mask_select 函数使用布尔掩码 mask 的索引规则。
本小节介绍的 torch.nonzero(input, out = None, as_tuple = False) 函数与前面两个选择函数最大的不同是:「nonzero 函数返回的是输入张量中非零元素的索引而不是输入张量中符合索引规则的元素值 (index_select 和 mask_select)」,nonzero 函数的参数有:
- input (Tensor) - 输入张量;
- out (Tensor, optional) - 指定输出的张量。比如执行 torch.zeros([2, 2], out = tensor_a),相当于执行 tensor_a = torch.zeros([2, 2]);
- as_tuple = False (Boolean) - 如果 as_tuple 为 False (默认值),返回一个包含输入张量中非零元素的索引的 2D 张量;如果 as_tuple 为 True,对于输入张量的每一个维度都返回一个 1D 张量,1D 张量中的元素是沿着该维度上非零元素的索引;
参数 as_tuple 的取值决定了 nonzero 函数最终呈现的输出形式,接下来以参数 as_tuple 的参数值为 False 或 True 来分别介绍 nonzero 函数。
1. 当 as_tuple = False (默认)
torch.nonzero(input, out = None, as_tuple = False) 函数返回一个 2D 张量,2D 张量中的每一行都是输入张量中非零元素值的索引。
代码语言:javascript复制>>> import torch
>>> # 输入张量为1D张量
>>> input_1d = torch.tensor([1, 1, 0, 1])
>>> output_1d = torch.nonzero(input_1d, as_tuple = False)
>>> print(output_1d.size())
torch.Size([3, 1])
>>> print(output_1d)
tensor([[0],
[1],
[3]])
>>> # 输入张量为2D张量
>>> input_2d = torch.tensor([[0, 1], [2, 3]])
>>> output_2d = torch.nonzero(input_2d, as_tuple = False)
>>> print(output_2d.size())
torch.Size([3, 2])
>>> print(output_2d)
tensor([[0, 1],
[1, 0],
[1, 1]])
>>> # 输入张量为3D张量
>>> input_3d = torch.tensor([[[0, 1],[3, 4]],
[[1, 0],[0, 0]]])
>>> output_3d = torch.nonzero(input_3d, as_tuple = False)
>>> print(output_3d.size())
torch.Size([4, 3])
>>> print(output_3d)
tensor([[0, 0, 1],
[0, 1, 0],
[0, 1, 1],
[1, 0, 0]])
这里以 2D 张量为例,简单分析当 as_tuple = False 时的 nonzero 函数,此时的 2D 输入张量为:
2D 输入张量可以看成大家熟悉的矩阵,通过矩阵中的行和列可以索引矩阵中任意元素,此时矩阵中有 3 个非零元素:
- 1: 位于矩阵的第一行第二列,index_1 = [0, 1]
- 2: 位于矩阵的第二行第一列,index_2 = [1, 0]
- 3: 位于矩阵的第二行第二列,index_3 = [1, 1]
使用 torch.nonzero(input, out = None, as_tuple = False) 函数返回的是一个形状为 (3 x 2) 的 2D 张量 torch.tensor([[0, 1], [1, 0], [1, 1]])
,2D 张量一共有 3 个行,每一个行都是一个非零元素的索引,即 torch.tensor([index_1, index_2, index_3])
。
当 as_tuple = False (默认) 时的 nonzero 函数需要注意两点:
- 函数总是返回 2D 张量;
- 如果输入张量的维度为 n,且非零元素个数为 z,则 nonzero 函数返回的是一个形状为 (z x n) 的 2D 张量。比如对于一个非零元素个数为 4 的 3D 输入张量来说,输入张量的维度为 3 且一共有 4 个非零元素,因此 nonzero 函数返回的是一个形状为 (4 x 3) 的 2D 张量;
2. 当 as_tuple = True
torch.nonzero(input, out = None, as_tuple = True) 函数返回元素为 1D 张量的元组,每一个 1D 张量对应输入张量的一个维度,而 1D 张量中的每个元素值表示输入张量中的非零元素在该维度上的索引。
代码语言:javascript复制>>> import torch
>>> # 输入张量为1D张量
>>> input_1d = torch.tensor([1, 1, 0, 1])
>>> output_1d = torch.nonzero(input_1d, as_tuple = True)
>>> print(output_1d)
(tensor([0, 1, 3]),)
>>> # 输入张量为2D张量
>>> input_2d = torch.tensor([[0, 1], [1, 2]])
>>> output_2d = torch.nonzero(input_2d, as_tuple = True)
>>> print(output_2d)
(tensor([0, 1, 1]), tensor([1, 0, 1]))
>>> # 输入张量为3D张量
>>> input_3d = torch.tensor([[[0, 1],[3, 4]],
[[1, 0],[0, 0]]])
>>> output_3d = torch.nonzero(input_3d, as_tuple = True)
>>> print(output_3d)
(tensor([0, 0, 0, 1]), tensor([0, 1, 1, 0]), tensor([1, 0, 1, 0]))
此处的代码只是对前面 as_tuple = False 时的代码进行了两处修改:
- 将 as_tuple 的参数值从 False 改成 True;
- 删除了打印输出张量形状的语句,因为当 as_tuple = True 时,nonzero 函数返回的是一个元组,而元组并有形状一说;
因为 2D 张量可以看成矩阵方便描述,因此同样以大家熟悉的 2D 张量为例,简单分析当 as_tuple = True 时的 nonzero 函数。此时 nonzero 函数返回的元组为 (tensor([0, 1, 1]), tensor([1, 0, 1]))
,元组中的两个 1D 张量分别对应矩阵的行和列:
- 对应矩阵行的 1D 张量中的 3 个元素值分别对应矩阵中 3 个非零元素的行索引;
- 对应矩阵列的 1D 张量中的 3 个元素值分别对应矩阵中 3 个非零元素的列索引;
此时矩阵中有 3 个非零元素:
- 1: 位于矩阵的第一行第二列,index_1_row = 0, index_1_col = 1
- 2: 位于矩阵的第二行第一列,index_2_row = 1, index_2_col = 0
- 3: 位于矩阵的第二行第二列,index_3_row = 1, index_3_col = 1
使用 torch.nonzero(input, out = None, as_tuple = True) 函数返回长度为 2 的元组,元组中的每一个元素都是一个形状为 (3, ) 的 1D 张量 torch.tensor([0, 1, 1])
和 torch.tensor([1, 0, 1])
,元组中的每 1D 张量对应输入张量的一个维度,而每个 1D 张量的元素值分别对应输入张量中非零元素在对应维度上的索引,即 (torch.tensor([index_1_row, index_2_row, index_3_row]), torch.tensor([index_1_col, index_2_col, index_3_col]))
。
当 as_tuple = True 时的 nonzero 函数需要注意三点:
- 函数总是返回一个元组;
- 如果输入张量的维度为 n,且非零元素个数为 z,则 nonzero 函数返回的是一个长度为 n 的元组,元组中的每一个元素都是一个形状为 (z, ) 的 1D 张量。比如对于一个非零元素个数为 4 的 3D 输入张量来说,输入张量的维度为 3 且一共有 4 个非零元素,因此 nonzero 函数返回的是一个长度为 3 的元组,元组中的每一个元素都是一个形状为 (4, ) 的 1D 张量;
- 如果了解高级索引会发现其实当 as_tuple = True 时的 nonzero 函数返回的是一个高级索引。
>>> import torch
>>> # 输入张量为2D张量
>>> input_2d = torch.tensor([[0, 1], [1, 2]])
>>> output_2d = torch.nonzero(input_2d, as_tuple = True)
>>> print(input_2d)
tensor([[0, 1],
[1, 2]])
>>> print(output_2d)
(tensor([0, 1, 1]), tensor([1, 0, 1]))
>>> # 使用高级索引索引输入张量中的非零元素
>>> # 通过索引元组获取其中的1D张量
>>> # output_2d[0] = tensor([0, 1, 1])
>>> # output_2d[1] = tensor([1, 0, 1])
>>> print(input_2d[output_2d[0], output_2d[1]])
tensor([1, 1, 2])
- 还有一种特殊的情况,当输入张量为一个非零元素值的 0D 张量 (非零标量),此时的 nonzero 函数将这个 0D 张量看成是只有一个非零元素值的 1D 张量;
>>> import torch
>>> input_0d = torch.tensor(2)
>>> output_0d = torch.nonzero(input_0d, as_tuple = True)
>>> print(output_0d)
(tensor([0]),)
>>> input_1d = torch.tensor([2])
>>> output_1d = torch.nonzero(input_1d, as_tuple = True)
>>> print(output_1d)
(tensor([0]),)
>>> print(output_0d == output_1d)
True