增加维度
增加一个长度为 1 的维度相当于给原有的张量添加一个新维度的概念。由于增加的新维度长度为 1,因此张量中的元素并没有发生改变,仅仅改变了张量的理解方式。比如一张 大小的灰度图片保存为形状为 的张量,在张量的头部增加一个长度为 1 的新维度,定义为通道数维度,此时张量的形状为 。
“图片张量的形状有两种约定:
- 通道在后的约定。TensorFlow 将通道维度放在最后: ;
- 通道在前的约定。PyTorch 将通道维度放在前面:
”
使用 torch.unsqueeze(input, dim)
可以在指定的 dim 维度前插入一个长度为 1 的新维度。
>>> import torch
>>> # 使用随机生成的正态分布模拟没有通道维度的图片张量
>>> input = torch.randn(28, 28)
>>> print(input.size())
torch.Size([28, 28])
>>> # 指定第0个维度前面插入新的维度
>>> image = torch.unsqueeze(input, dim = 0)
>>> print(image.size())
torch.Size([1, 28, 28])
「需要注意的是,torch.unsqueeze(input, dim)
的 dim 参数既可以为正整数也可以为负整数:」
- 当 dim 为正整数时,表示在当前维度之前插入一个长度为 1 的新维度;
- 当 dim 为负整数时,表示在当前维度之后插入一个长度为 1 的新维度;
以 张量为例 (为了方便叙述将其简写成 ),不同 dim 参数的实际插入位置如下所示。
通过上图可以看出,无论 dim 参数值是正整数还是负整数,其具体范围都和输入张量的维度有关。对于输入张量为 的图片张量而言,张量的维度为 4,其 dim 参数的取值范围为 ,对比不同维度的输入张量:
- 输入张量的维度 input.dim() = 2 时,dim 参数的取值范围为
- 输入张量的维度 input.dim() = 3 时,dim 参数的取值范围为
得到 dim 参数的取值范围为 ,其中 input.dim() 为输入张量的维度。
如果指定 dim 参数超过其取值范围,会抛出 IndexError。
代码语言:javascript复制>>> import torch
>>> # 使用随机生成的正态分布模拟[b,c,h,w]
>>> input = torch.randn(1, 1, 28, 28)
>>> print(input.size())
torch.Size([1, 1, 28, 28])
>>> print(input.dim())
4
>>> # input.dim() = 4
>>> # [-4-1, 4 1) = [-5, 5)
>>> # 将dim设置为5,超出dim参数的取值范围
>>> # x = torch.unsqueeze(input, dim = 5) error
>>> # print(x.size())
Traceback (most recent call last):
File "/home/chenkc/code/pytorch/test_02.py", line 19, in <module>
x = torch.unsqueeze(input, dim = 5)
IndexError: Dimension out of range (expected to be in range of [-5, 4], but got 5)
删除维度
删除维度是增加维度的逆操作,与增加维度一样,「删除维度只能删除长度为 1 的维度,同时也不会改变张量的存储」。对于形状为 的张量来说,如果希望将批量维度删除 (batch_size 通常称为批量维度),可以通过 torch.squeeze(input, dim)
函数,「dim 参数为待删除维度的索引号。」
例如,删除形状为 图片张量的批量维度。
代码语言:javascript复制>>> import torch
>>> # 使用随机生成的正态分布模拟[b,c,h,w]
>>> input = torch.randn(1, 1, 28, 28)
>>> print(input.size())
torch.Size([1, 1, 28, 28])
>>> # squeeze函数中dim参数为待删除维度的索引号
>>> # [b,c,h,w]中批量维度的索引为0
>>> x = torch.squeeze(input, dim = 0)
>>> print(x.size())
torch.Size([1, 28, 28])
与增加维度的 torch.unsqueeze(input, dim)
中 dim 参数不同,在 torch.squeeze(input, dim)
中 dim 参数表示待删除维度的索引号。同样以 张量为例 (为了方便叙述将其简写成 ),不同 dim 参数的实际删除的维度如下所示。
如果不指定维度参数 dim,即 torch.squeeze(input)
,它会默认的删除所有长度为 1 的维度。
>>> import torch
>>> # 使用随机生成的正态分布模拟[b,c,h,w]
>>> input = torch.randn(1, 1, 28, 28)
>>> print(input.size())
torch.Size([1, 1, 28, 28])
>>> # 不指定dim参数默认删除所有长度为1的唯独
>>> x = torch.squeeze(input)
>>> print(x.size())
torch.Size([28, 28])
小结
Tips: 在 torch.squeeze(input, dim)
函数中,如果不指定维度参数 dim,即 dim = None 时,它默认会删除输入张量中所有长度为 1 的维度。
References:
- 《TensorFlow深度学习》