pytorch维度变换-补充知识

2019-11-17 23:08:28 浏览数 (1)

补充介绍一下转置操作

先建立矩阵a,分别输出a和a的转置矩阵

代码语言:javascript复制
a = torch.randn(3, 4)
print(a)
print(a.t())

代码语言:javascript复制
tensor([[-0.4018, -1.4217,  0.5778, -1.0832],
        [ 0.9451,  0.2730,  0.2420,  1.3747],
        [-1.3293,  1.5332, -1.1212,  0.8263]])
tensor([[-0.4018,  0.9451, -1.3293],
        [-1.4217,  0.2730,  1.5332],
        [ 0.5778,  0.2420, -1.1212],
        [-1.0832,  1.3747,  0.8263]])

需要注意的是转置功能只适用于2D的矩阵,而不适用于3D或4D的矩阵。

代码语言:javascript复制
a = torch.randn([3, 4, 3, 3])
print(a.t())

此时输出会报错

代码语言:javascript复制
RuntimeError: t() expects a tensor with <= 2 dimensions, but self is 4D

除了.t()方法外,还可以使用.transpose(d1, d2)函数。在使用时需要给d1

d2赋值,以给出调换的位置。

代码语言:javascript复制
a = torch.randn([4, 3, 28, 28])
b = a.transpose(1, 3).view(4, 3*28*28).view(4, 3, 28, 28)
# 将原来的[b, c, h, w]=>[b, w, h, c]后,再将后面三个维度连在一起来理解,再展开成[b, c, w, h]
print(b)

此时输出会报错

代码语言:javascript复制
RuntimeError: invalid argument 2: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Call .contiguous() before .view(). at ..atensrcTH/generic/THTensor.cpp:203

报错原因在于view函数会破坏原来的元素顺序,展开时channel元素跑到了前面。因此在使用transpose和view函数时,要格外注意数据的维度顺序和存储顺序需保持一致。

这里可以使用.contiguous函数,将数据重新变成连续。

代码语言:javascript复制
b = a.transpose(1, 3).contiguous().view(4, 3*28*28).view(4, 3, 28, 28)
print('b=', b.shape)

输出

代码语言:javascript复制
b= torch.Size([4, 3, 28, 28])

但这里有一个问题,通过以上转换,矩阵经历了[b,c,h,w]=>[b,w,h,c]=>[b,c,w,h],这样虽然数据连续了,但这种转换方式会造成数据污染。

这里再介绍法2

代码语言:javascript复制
c = a.transpose(1, 3).contiguous().view(4, 3*28*28).view(4, 28, 28, 3).transpose(1, 3)
# 以上经历了[b,c,h,w]=>[b,w,h,c]=>[b,w,h,c]=>[b,c,h,w]
print('c=', c.shape)

输出

代码语言:javascript复制
c= torch.Size([4, 3, 28, 28])

以上两种方法虽然输出均为一致,但为验证有没有数据污染,使用torch.eq函数进行分析

代码语言:javascript复制
print(torch.all(torch.eq(a, b)))
# 添加torch.all,确保所有数据均一致
print(torch.all(torch.eq(a, c)))

输出

代码语言:javascript复制
tensor(0, dtype=torch.uint8)
tensor(1, dtype=torch.uint8)

返回0说明数据不一致,返回1说明数据一致。由此看出b虽然各数据维度与a相同,但已造成了数据污染,而c没有数据污染。

下面介绍一种更加方便的转置API: permute

与transpose每次只能两两交换不同的是,permute可以一次性给出四个维度上的位置。

.permute(d1, d2, d3, d4) 通过输入d1、d2、d3、d4的顺序即可完成。

如原[b,c,h,w]想要变成[b,h,w,c],只要输入.permute(0, 2, 3, 1)即可实现,而.transpose需要重复两两调换好几次。

0 人点赞