前言
切片其实也是索引操作,所以切片经常被称为切片索引,为了更方便叙述,本文将切片称为切片索引。索引和切片操作可以帮助我们快速提取张量中的部分数据。
1. 基本索引
PyTorch 支持与 Python 和 NumPy 类似的基本索引操作,PyTorch 中的基本索引可以通过整数值来索引张量。
代码语言:javascript复制>>> import torch
>>> # 构造形状为3x3,元素值从0到8的2D张量
>>> a = torch.arange(0, 9).view([3, 3])
>>> print(a)
tensor([[0, 1, 2],
[3, 4, 5],
[6, 7, 8]])
>>> print(a[0]) # 索引张量a的第一行
tensor([0, 1, 2])
>>> print(a[0][1]) # 索引张量a的第一行和第二列
tensor(1)
变量 a 是一个(3 x 3)的 2D 张量,即张量 a 包含两个维度:
- 第一个维度,在 2D 张量中称为行维度;
- 第二个维度,在 2D 张量中称为列维度;
a[0]
表示在张量 a 的行维度上取索引号为 0 的元素(第一行);a[0][1]
表示在张量 a 的行维度上取索引号为 0 的元素(第一行)以及在列维度上取索引号为 1 的元素(第二列),获取行维度和列维度上的元素集合的交集(位于第一行第二列上的元素集合)即为最终的索引结果。简单来说,[i][j]...[k]
中的每一个[]
都表示张量的一个维度,从左边开始维度依次增加,而[]
中的元素值代表对应维度的索引号,「此时的索引号可以为负数,相当于从后向前索引。」
>>> import torch
>>> # 构造形状为3x3,元素值从0到8的2D张量
>>> a = torch.arange(0, 9).view([3, 3])
>>> print(a)
tensor([[0, 1, 2],
[3, 4, 5],
[6, 7, 8]])
>>> print(a[-1]) # 索引张量a的最后一行
tensor([6, 7, 8])
「当张量的维度数较高的时候,使用[i][j]...[k]
**的方式书写非常不方便,可以采用[i, j,...,k]
的方式,两种方式是等价的。」
>>> import torch
>>> # 构造形状为2x2x3,元素值从0到11的3D张量
>>> a = torch.arange(12).view([2, 2, 3])
>>> print(a)
tensor([[[ 0, 1, 2],
[ 3, 4, 5]],
[[ 6, 7, 8],
[ 9, 10, 11]]])
>>> # 第一个维度取索引号为0的元素
>>> # 第二个维度取索引号为1的元素
>>> # 第三个维度取索引号为2的元素
>>> # 满足这三个条件的元素即为索引结果
>>> print(a[0, 1, 2])
tensor(5)
>>> # 通过基本索引修改元素值
>>> a[0, 1, 2] = 100
>>> print(a)
tensor([[[ 0, 1, 2],
[ 3, 4, 100]],
[[ 6, 7, 8],
[ 9, 10, 11]]])
通过对比原始张量 a 和通过基本索引的方式修改元素值之后的张量 a 可以发现,「通过基本索引出来的结果与原始的张量共享内存,如果修改一个,另一个也会被修改。」
2. 切片索引
通过 [start: end: steps](起始位置为start,终止位置为end,步长为steps)的方式索引连续的张量子集。以形状为 [4, 3, 28, 28] 的图片张量为例,在 PyTorch 中图片张量的格式为 [batch_size, channel, width, hight],[4, 3, 28, 28] 的图片张量表示 4 张拥有 RGB 三个通道且每个通道为 (28 x 28) 的像素矩阵。
代码语言:javascript复制>>> import torch
>>> # 模拟4张拥有RGB三个通道且每个通道为(28 x 28)的像素矩阵
>>> a = torch.rand(4, 3, 28, 28)
>>> # 读取前2张图片
>>> print(a[:2].size())
torch.Size([2, 3, 28, 28])
>>> # 读取前两张图片的R通道的28x28的像素矩阵
>>> print(a[:2, :1, :, :].size())
torch.Size([2, 1, 28, 28])
>>> # 读取前两张图片的GB通道的28x28的像素矩阵
>>> print(a[:2, 1:, :, :].size())
torch.Size([2, 2, 28, 28])
>>> # 读取前两张图片的B通道的28x28的像素矩阵
>>> print(a[:2, -1:, :, :].size())
torch.Size([2, 1, 28, 28])
start: end: step切片方式有很多简写方式,其中 start、end、step 3 个参数可以根据需要选择性的省略,全部省略时即为::,表示从最考试读取到最末尾,步长为 1,即不跳过任何元素。如 x[0,::] 表示读取第一张图片的的所有通道的像素矩阵,其中::表示在通道维度上读取所有RGB三个通道,它等价于 x[0] 的写法。通常为了简洁,将::简写成单个冒号。
代码语言:javascript复制>>> import torch
>>> # 模拟4张拥有RGB三个通道且每个通道为(28 x 28)的像素矩阵
>>> a = torch.rand(4, 3, 28, 28)
>>> # 读取第一张图片
>>> print(a[0,::].size())
torch.Size([3, 28, 28])
>>> # 为了更加简介,::可以简写为单个冒号:
>>> print(a[0,:].size())
torch.Size([3, 28, 28])
接下来总结一下start: end: step 切片的简写方式,其中从第一个元素读取时 start 可以省略,即 start = 0 是可以省略的,取到最后一个元素时 end 可以省略,步长为 1 时 step 可以省略,简写方式总结如表 4.1:
「还有点需要注意,在 PyTorch 中切片索引中的步长不能小于0,即不能为负数。」
代码语言:javascript复制>>> import torch
>>> # 创建元素值为0~8的1D张量
>>> a = torch.arange(9)
>>> print(a)
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8])
>>> print(a[4: 0: -2])
Traceback (most recent call last):
File "/home/chenkc/code/tensor.py", line 44, in <module>
print(a[4: 0: -2])
ValueError: step must be greater than zero
当张量的维度数量较多时,不需要采样的维度一般用单冒号 : 表示采样所有元素,此时有可能出现大量的 : 出现。
代码语言:javascript复制>>> import torch
>>> # 模拟4张拥有RGB三个通道且每个通道为(28 x 28)的像素矩阵
>>> a = torch.rand(4, 3, 28, 28)
>>> # 获取4张图片的RGB三个通道的所有行和第三列像素矩阵
>>> print(a[:, :, :, 2].size())
torch.Size([4, 3, 28])
「为了避免出现像x[:, :, :, 2] 这样过多冒号的情况,可以使用...符号表示取多个维度上所有数据,其中维度的数量需要根据规则自动推断:当切片方式出现...符号时,...符号左边的维度将自动对齐到最左边,...符号右边的维度将自动对齐到最右边,此时系统再自动推断...符号代表的维度张量,」 它的切片方式总结如表 4.2 所示(「其中表中的···都为...」)。
3. 高级索引
PyTorch 支持绝大多数 NumPy 的高级索引,高级索引可以看成是基本索引的扩展。
代码语言:javascript复制>>> import torch
>>> a = torch.arange(9).view([3, 3])
>>> print(a)
tensor([[0, 1, 2],
[3, 4, 5],
[6, 7, 8]])
>>> print(a[[0, 1],...])
tensor([[0, 1, 2],
[3, 4, 5]])
>>> print(a[[0, 1], [1, 2]])
tensor([1, 5])
>>> print(a[[1, 0, 2], [0]])
tensor([3, 0, 6])
这里给出了 PyTorch 中的三种高级索引方式,通过这些高级索引的输出结果,可以看出这些高级索引的本质。
- a[[0, 1, ...]] 等价 a[0] 和 a[1],相当于索引张量的第一行和第二行元素;
- a[[0, 1, 1, 2]] 等价 a[0, 1] 和 a[1, 2],相当于索引张量的第一行的第二列和第二行的第三列元素;
- a[[1, 0, 2, 0]] 等价 a[1, 0] 和 a[0, 0] 和 a[2, 0],相当于索引张量的第二行第一列的元素、张量第一行和第一列的元素以及张量第三行和第一列的元素;
References:
1. 龙良曲深度学习与PyTorch入门实战:https://study.163.com/course/introduction/1208894818.htm
2. 初探Numpy中的花式索引
原文地址:https://mp.weixin.qq.com/s?