PyTorch入门笔记-增删张量的维度

2021-01-03 15:55:49 浏览数 (1)

增加维度

增加一个长度为 1 的维度相当于给原有的张量添加一个新维度的概念。由于增加的新维度长度为 1,因此张量中的元素并没有发生改变,仅仅改变了张量的理解方式。比如一张 大小的灰度图片保存为形状为 的张量,在张量的头部增加一个长度为 1 的新维度,定义为通道数维度,此时张量的形状为 。

“图片张量的形状有两种约定:

  • 通道在后的约定。TensorFlow 将通道维度放在最后: ;
  • 通道在前的约定。PyTorch 将通道维度放在前面:

使用 torch.unsqueeze(input, dim) 可以在指定的 dim 维度前插入一个长度为 1 的新维度。

代码语言:javascript复制
>>> import torch
>>> # 使用随机生成的正态分布模拟没有通道维度的图片张量
>>> input = torch.randn(28, 28)
>>> print(input.size())

torch.Size([28, 28])

>>> # 指定第0个维度前面插入新的维度
>>> image = torch.unsqueeze(input, dim = 0)
>>> print(image.size())

torch.Size([1, 28, 28])

「需要注意的是,torch.unsqueeze(input, dim) 的 dim 参数既可以为正整数也可以为负整数:」

  • 当 dim 为正整数时,表示在当前维度之前插入一个长度为 1 的新维度;
  • 当 dim 为负整数时,表示在当前维度之后插入一个长度为 1 的新维度;

以 张量为例 (为了方便叙述将其简写成 ),不同 dim 参数的实际插入位置如下所示。

通过上图可以看出,无论 dim 参数值是正整数还是负整数,其具体范围都和输入张量的维度有关。对于输入张量为 的图片张量而言,张量的维度为 4,其 dim 参数的取值范围为 ,对比不同维度的输入张量:

  • 输入张量的维度 input.dim() = 2 时,dim 参数的取值范围为
  • 输入张量的维度 input.dim() = 3 时,dim 参数的取值范围为

得到 dim 参数的取值范围为 ,其中 input.dim() 为输入张量的维度。

如果指定 dim 参数超过其取值范围,会抛出 IndexError。

代码语言:javascript复制
>>> import torch
>>> # 使用随机生成的正态分布模拟[b,c,h,w]
>>> input = torch.randn(1, 1, 28, 28)
>>> print(input.size())

torch.Size([1, 1, 28, 28])

>>> print(input.dim())

4

>>> # input.dim() = 4
>>> # [-4-1, 4 1) = [-5, 5)
>>> # 将dim设置为5,超出dim参数的取值范围
>>> # x = torch.unsqueeze(input, dim = 5) error
>>> # print(x.size())

Traceback (most recent call last):
  File "/home/chenkc/code/pytorch/test_02.py", line 19, in <module>
    x = torch.unsqueeze(input, dim = 5)
IndexError: Dimension out of range (expected to be in range of [-5, 4], but got 5)

删除维度

删除维度是增加维度的逆操作,与增加维度一样,「删除维度只能删除长度为 1 的维度,同时也不会改变张量的存储」。对于形状为 的张量来说,如果希望将批量维度删除 (batch_size 通常称为批量维度),可以通过 torch.squeeze(input, dim) 函数,「dim 参数为待删除维度的索引号。」

例如,删除形状为 图片张量的批量维度。

代码语言:javascript复制
>>> import torch
>>> # 使用随机生成的正态分布模拟[b,c,h,w]
>>> input = torch.randn(1, 1, 28, 28)
>>> print(input.size())

torch.Size([1, 1, 28, 28])

>>> # squeeze函数中dim参数为待删除维度的索引号
>>> # [b,c,h,w]中批量维度的索引为0
>>> x = torch.squeeze(input, dim = 0)
>>> print(x.size())

torch.Size([1, 28, 28])

与增加维度的 torch.unsqueeze(input, dim) 中 dim 参数不同,在 torch.squeeze(input, dim) 中 dim 参数表示待删除维度的索引号。同样以 张量为例 (为了方便叙述将其简写成 ),不同 dim 参数的实际删除的维度如下所示。

如果不指定维度参数 dim,即 torch.squeeze(input),它会默认的删除所有长度为 1 的维度。

代码语言:javascript复制
>>> import torch
>>> # 使用随机生成的正态分布模拟[b,c,h,w]
>>> input = torch.randn(1, 1, 28, 28)
>>> print(input.size())

torch.Size([1, 1, 28, 28])

>>> # 不指定dim参数默认删除所有长度为1的唯独
>>> x = torch.squeeze(input)
>>> print(x.size())

torch.Size([28, 28])

小结

Tips: 在 torch.squeeze(input, dim) 函数中,如果不指定维度参数 dim,即 dim = None 时,它默认会删除输入张量中所有长度为 1 的维度。

References:

  1. 《TensorFlow深度学习》

0 人点赞