补充介绍一下转置操作
先建立矩阵a,分别输出a和a的转置矩阵
代码语言:javascript复制a = torch.randn(3, 4)
print(a)
print(a.t())
代码语言:javascript复制tensor([[-0.4018, -1.4217, 0.5778, -1.0832],
[ 0.9451, 0.2730, 0.2420, 1.3747],
[-1.3293, 1.5332, -1.1212, 0.8263]])
tensor([[-0.4018, 0.9451, -1.3293],
[-1.4217, 0.2730, 1.5332],
[ 0.5778, 0.2420, -1.1212],
[-1.0832, 1.3747, 0.8263]])
需要注意的是转置功能只适用于2D的矩阵,而不适用于3D或4D的矩阵。
代码语言:javascript复制a = torch.randn([3, 4, 3, 3])
print(a.t())
此时输出会报错
代码语言:javascript复制RuntimeError: t() expects a tensor with <= 2 dimensions, but self is 4D
除了.t()方法外,还可以使用.transpose(d1, d2)函数。在使用时需要给d1
d2赋值,以给出调换的位置。
代码语言:javascript复制a = torch.randn([4, 3, 28, 28])
b = a.transpose(1, 3).view(4, 3*28*28).view(4, 3, 28, 28)
# 将原来的[b, c, h, w]=>[b, w, h, c]后,再将后面三个维度连在一起来理解,再展开成[b, c, w, h]
print(b)
此时输出会报错
代码语言:javascript复制RuntimeError: invalid argument 2: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Call .contiguous() before .view(). at ..atensrcTH/generic/THTensor.cpp:203
报错原因在于view函数会破坏原来的元素顺序,展开时channel元素跑到了前面。因此在使用transpose和view函数时,要格外注意数据的维度顺序和存储顺序需保持一致。
这里可以使用.contiguous函数,将数据重新变成连续。
代码语言:javascript复制b = a.transpose(1, 3).contiguous().view(4, 3*28*28).view(4, 3, 28, 28)
print('b=', b.shape)
输出
代码语言:javascript复制b= torch.Size([4, 3, 28, 28])
但这里有一个问题,通过以上转换,矩阵经历了[b,c,h,w]=>[b,w,h,c]=>[b,c,w,h],这样虽然数据连续了,但这种转换方式会造成数据污染。
这里再介绍法2
代码语言:javascript复制c = a.transpose(1, 3).contiguous().view(4, 3*28*28).view(4, 28, 28, 3).transpose(1, 3)
# 以上经历了[b,c,h,w]=>[b,w,h,c]=>[b,w,h,c]=>[b,c,h,w]
print('c=', c.shape)
输出
代码语言:javascript复制c= torch.Size([4, 3, 28, 28])
以上两种方法虽然输出均为一致,但为验证有没有数据污染,使用torch.eq函数进行分析
代码语言:javascript复制print(torch.all(torch.eq(a, b)))
# 添加torch.all,确保所有数据均一致
print(torch.all(torch.eq(a, c)))
输出
代码语言:javascript复制tensor(0, dtype=torch.uint8)
tensor(1, dtype=torch.uint8)
返回0说明数据不一致,返回1说明数据一致。由此看出b虽然各数据维度与a相同,但已造成了数据污染,而c没有数据污染。
下面介绍一种更加方便的转置API: permute
与transpose每次只能两两交换不同的是,permute可以一次性给出四个维度上的位置。
.permute(d1, d2, d3, d4) 通过输入d1、d2、d3、d4的顺序即可完成。
如原[b,c,h,w]想要变成[b,h,w,c],只要输入.permute(0, 2, 3, 1)即可实现,而.transpose需要重复两两调换好几次。