PyTorch入门笔记-索引和切片

2022-04-26 15:49:32 浏览数 (1)

前言

切片其实也是索引操作,所以切片经常被称为切片索引,为了更方便叙述,本文将切片称为切片索引。索引和切片操作可以帮助我们快速提取张量中的部分数据。

1. 基本索引

PyTorch 支持与 Python 和 NumPy 类似的基本索引操作,PyTorch 中的基本索引可以通过整数值来索引张量。

代码语言:javascript复制
>>> import torch
>>> # 构造形状为3x3,元素值从0到8的2D张量
>>> a = torch.arange(0, 9).view([3, 3])
>>> print(a)

tensor([[0, 1, 2],
        [3, 4, 5],
        [6, 7, 8]]) 

>>> print(a[0]) # 索引张量a的第一行

tensor([0, 1, 2])

>>> print(a[0][1]) # 索引张量a的第一行和第二列

tensor(1)

变量 a 是一个(3 x 3)的 2D 张量,即张量 a 包含两个维度:

  • 第一个维度,在 2D 张量中称为行维度;
  • 第二个维度,在 2D 张量中称为列维度;

a[0]表示在张量 a 的行维度上取索引号为 0 的元素(第一行);a[0][1]表示在张量 a 的行维度上取索引号为 0 的元素(第一行)以及在列维度上取索引号为 1 的元素(第二列),获取行维度和列维度上的元素集合的交集(位于第一行第二列上的元素集合)即为最终的索引结果。简单来说,[i][j]...[k]中的每一个[]都表示张量的一个维度,从左边开始维度依次增加,而[]中的元素值代表对应维度的索引号,「此时的索引号可以为负数,相当于从后向前索引。」

代码语言:javascript复制
>>> import torch
>>> # 构造形状为3x3,元素值从0到8的2D张量
>>> a = torch.arange(0, 9).view([3, 3])
>>> print(a)

tensor([[0, 1, 2],
        [3, 4, 5],
        [6, 7, 8]])

>>> print(a[-1]) # 索引张量a的最后一行

tensor([6, 7, 8])

「当张量的维度数较高的时候,使用[i][j]...[k]**的方式书写非常不方便,可以采用[i, j,...,k]的方式,两种方式是等价的。」

代码语言:javascript复制
>>> import torch
>>> # 构造形状为2x2x3,元素值从0到11的3D张量
>>> a = torch.arange(12).view([2, 2, 3])
>>> print(a)

tensor([[[ 0,  1,  2],
         [ 3,  4,  5]],

        [[ 6,  7,  8],
         [ 9, 10, 11]]])

>>> # 第一个维度取索引号为0的元素
>>> # 第二个维度取索引号为1的元素
>>> # 第三个维度取索引号为2的元素
>>> # 满足这三个条件的元素即为索引结果
>>> print(a[0, 1, 2])

tensor(5)

>>> # 通过基本索引修改元素值
>>> a[0, 1, 2] = 100
>>> print(a)

tensor([[[  0,   1,   2],
         [  3,   4, 100]],

        [[  6,   7,   8],
         [  9,  10,  11]]])

通过对比原始张量 a 和通过基本索引的方式修改元素值之后的张量 a 可以发现,「通过基本索引出来的结果与原始的张量共享内存,如果修改一个,另一个也会被修改。」

2. 切片索引

通过 [start: end: steps](起始位置为start,终止位置为end,步长为steps)的方式索引连续的张量子集。以形状为 [4, 3, 28, 28] 的图片张量为例,在 PyTorch 中图片张量的格式为 [batch_size, channel, width, hight],[4, 3, 28, 28] 的图片张量表示 4 张拥有 RGB 三个通道且每个通道为 (28 x 28) 的像素矩阵。

代码语言:javascript复制
>>> import torch
>>> # 模拟4张拥有RGB三个通道且每个通道为(28 x 28)的像素矩阵
>>> a = torch.rand(4, 3, 28, 28)
>>> # 读取前2张图片
>>> print(a[:2].size())

torch.Size([2, 3, 28, 28])

>>> # 读取前两张图片的R通道的28x28的像素矩阵
>>> print(a[:2, :1, :, :].size())

torch.Size([2, 1, 28, 28])

>>> # 读取前两张图片的GB通道的28x28的像素矩阵
>>> print(a[:2, 1:, :, :].size())

torch.Size([2, 2, 28, 28])

>>> # 读取前两张图片的B通道的28x28的像素矩阵
>>> print(a[:2, -1:, :, :].size())

torch.Size([2, 1, 28, 28])

start: end: step切片方式有很多简写方式,其中 start、end、step 3 个参数可以根据需要选择性的省略,全部省略时即为::,表示从最考试读取到最末尾,步长为 1,即不跳过任何元素。如 x[0,::] 表示读取第一张图片的的所有通道的像素矩阵,其中::表示在通道维度上读取所有RGB三个通道,它等价于 x[0] 的写法。通常为了简洁,将::简写成单个冒号。

代码语言:javascript复制
>>> import torch
>>> # 模拟4张拥有RGB三个通道且每个通道为(28 x 28)的像素矩阵
>>> a = torch.rand(4, 3, 28, 28)
>>> # 读取第一张图片
>>> print(a[0,::].size())

torch.Size([3, 28, 28])

>>> # 为了更加简介,::可以简写为单个冒号:
>>> print(a[0,:].size())

torch.Size([3, 28, 28])

接下来总结一下start: end: step 切片的简写方式,其中从第一个元素读取时 start 可以省略,即 start = 0 是可以省略的,取到最后一个元素时 end 可以省略,步长为 1 时 step 可以省略,简写方式总结如表 4.1:

「还有点需要注意,在 PyTorch 中切片索引中的步长不能小于0,即不能为负数。」

代码语言:javascript复制
>>> import torch
>>> # 创建元素值为0~8的1D张量
>>> a = torch.arange(9)
>>> print(a)

tensor([0, 1, 2, 3, 4, 5, 6, 7, 8])

>>> print(a[4: 0: -2])

Traceback (most recent call last):
  File "/home/chenkc/code/tensor.py", line 44, in <module>
    print(a[4: 0: -2])
ValueError: step must be greater than zero

当张量的维度数量较多时,不需要采样的维度一般用单冒号 : 表示采样所有元素,此时有可能出现大量的 : 出现。

代码语言:javascript复制
>>> import torch
>>> # 模拟4张拥有RGB三个通道且每个通道为(28 x 28)的像素矩阵
>>> a = torch.rand(4, 3, 28, 28)
>>> # 获取4张图片的RGB三个通道的所有行和第三列像素矩阵
>>> print(a[:, :, :, 2].size())

torch.Size([4, 3, 28])

「为了避免出现像x[:, :, :, 2] 这样过多冒号的情况,可以使用...符号表示取多个维度上所有数据,其中维度的数量需要根据规则自动推断:当切片方式出现...符号时,...符号左边的维度将自动对齐到最左边,...符号右边的维度将自动对齐到最右边,此时系统再自动推断...符号代表的维度张量,」 它的切片方式总结如表 4.2 所示(「其中表中的···都为...」)。

3. 高级索引

PyTorch 支持绝大多数 NumPy 的高级索引,高级索引可以看成是基本索引的扩展。

代码语言:javascript复制
>>> import torch
>>> a = torch.arange(9).view([3, 3])

>>> print(a)

tensor([[0, 1, 2],
        [3, 4, 5],
        [6, 7, 8]])

>>> print(a[[0, 1],...])

tensor([[0, 1, 2],
        [3, 4, 5]])

>>> print(a[[0, 1], [1, 2]])

tensor([1, 5])

>>> print(a[[1, 0, 2], [0]])

tensor([3, 0, 6])

这里给出了 PyTorch 中的三种高级索引方式,通过这些高级索引的输出结果,可以看出这些高级索引的本质。

  • a[[0, 1, ...]] 等价 a[0] 和 a[1],相当于索引张量的第一行和第二行元素;
  • a[[0, 1, 1, 2]] 等价 a[0, 1] 和 a[1, 2],相当于索引张量的第一行的第二列和第二行的第三列元素;
  • a[[1, 0, 2, 0]] 等价 a[1, 0] 和 a[0, 0] 和 a[2, 0],相当于索引张量的第二行第一列的元素、张量第一行和第一列的元素以及张量第三行和第一列的元素;

References:

1. 龙良曲深度学习与PyTorch入门实战:https://study.163.com/course/introduction/1208894818.htm

2. 初探Numpy中的花式索引

原文地址:https://mp.weixin.qq.com/s?

0 人点赞