前言
交换维度顾名思义就是交换不同的维度,线性代数中矩阵的转置操作可以看成是交换第 0 个和第 1 个维度。比如下图形状为 (3, 4) 的矩阵。
交换第 0 个维度和第 1 个维度 (转置) 为形状为 (4, 3) 的矩阵。
不仅是在线性代数中经常会遇到交换维度的操作,在深度学习中交换维度的操作也非常常见。比如对于图片张量来说,在 PyTorch 中将通道维度放到最后面 [b, h, w, c],而在 TensorFlow 中将通道维度放在前面 [b, c, h, w],如果需要将 [b, h, w, c] 转换为 [b, c, h, w] 则需要使用交换维度的操作。
PyTorch 中交换维度的操作有 transpose 和 permute 两种方式。交换维度的操作至少要求张量拥有两个以及两个以上的维度才有意义,因此在介绍交换维度的方式时不再考虑 0D 和 1D 张量。
transpose
torch.transpose(input, dim0, dim1)
函数将输入张量 input 的第 dim0 个维度和第 dim1 个维度进行交换,并将交换维度后的张量返回。transpose 函数作用非常直观,使用起来也非常简单,因此使用方法不再过多的赘述。
下面是在使用 transpose 函数时的几个注意事项。
- transpose 函数能够交换 nD 张量 () 的任意两个不同的维度 (交换相同的维度没有意义);
- transpose 函数中的三个参数都是必选参数。换句话说,如果不为三个参数都指定具体的值,代码会抛出异常;
- 交换维度后的张量与原始张量共享内存。换句话说,如果修改了交换维度后的张量,原始张量也会发生对应的改变;
- 由于 2D 张量仅有两个维度,交换维度的操作固定,类似对矩阵进行转置操作,因此 PyTorch 提供了一个更方便的方法
torch.t(input)
。当 input 为 2D 张量时torch.t(input)
等价torch.transpose(input, 0, 1)
(或torch.transpose(input, 1, 0)
);
交换 nD 张量 ( n geq 2) 的任意两个不同的维度在很多时候并不能满足我们的需求。比如将图片张量 [b, h, w, c] 转换为 [b, c, h, w]。
代码语言:txt复制>>> import torch
>>> # 使用[0, 1)均匀分布模拟图片张量
>>> # (batch_size, height, width, channels)
>>> imgs = torch.randn([1, 32, 28, 3])
>>> # 交换height和channels两个维度
>>> imgs_swap = torch.transpose(imgs, 1, 3)
>>> # (batch_size, channels, width, height)
>>> print(imgs_swap.shape)
torch.Size([1, 3, 28, 32])
>>> # 交换width和height两个维度
>>> imgs = torch.transpose(imgs_swap, 2, 3)
>>> # (batch_size, channels, height, width)
>>> print(imgs.shape)
torch.Size([1, 3, 32, 28])
虽然能够将图片张量 [b, h, w, c] 转换为 [b, c, h, w],但是使用了两次 transpose 函数,并且需要熟知每次变换后对应维度的位置,非常容易出错。PyTorch 针对这种多次交换维度的方式提供 permute 函数。
permute
前面提到过 PyTorch 从接口的角度将张量的操作分成两种方式。比如对于 transpose 函数来说,可以使用 torch.transpose(input, dim0, dim1)
或者 input.transpose(dim0, dim1)
,两种定义方式本质上是一样的。但是 permute 函数只有 input.permute(*dims)
一种定义方式,其中 *dims 为期望维度的顺序。 来看看如何通过 permute 函数将图片张量 [b, h, w, c] 转换为 [b, c, h, w]。
>>> import torch
>>> # 使用[0, 1)均匀分布模拟图片张量
>>> # (batch_size, height, width, channels)
>>> imgs = torch.randn([1, 32, 28, 3])
>>> # 重新排列维度顺序
>>> print(imgs.permute(0, 3, 1, 2).shape)
torch.Size([1, 3, 32, 28])
[b, h, w, c] 维度序号为 (0, 1, 2, 3)
,如果想要将 [b, h, w, c] 转换为 [b, c, h, w],只需要重新排列一下维度序号 (0, 3, 1, 2)
,这也是 permute 函数的设计原理。
原文地址:
PyTorch入门笔记-交换维度