pytorch基础知识-维度变换-(下)

2019-11-17 23:11:52 浏览数 (2)

接下来介绍维度挤压

squeeze的用法与unsqueeze类似,同样需要给出要操作的维度参数,但若不给出维度的话,会把所有能删减的维度都去掉。

先设定b为[1, 32, 1, 1]

代码语言:javascript复制
import torch
b = torch.rand([1, 32, 1, 1])

进行如下维度变换操作

代码语言:javascript复制
print(b.shape)
print(b.squeeze().shape)
print(b.squeeze(0).shape)
print(b.squeeze(-1).shape)

依次输出为

代码语言:javascript复制
torch.Size([1, 32, 1, 1])
torch.Size([32])
torch.Size([32, 1, 1])
torch.Size([1, 32, 1])

下面介绍维度扩展

Expand: broadcasting

Repeat: memory copied

两者均可实现维度扩展。但区别在于第一种只是改变了理解方式,并没有增加数据。而第二种是实实在在的增加数据(拷贝数据的方式)。第一种方式只是在有需要的时候才会扩展,因此节约了内存,故推荐第一种。

举例

先设定需要扩展的

代码语言:javascript复制
b = torch.rand([1, 32, 1, 1])
# b为需要扩展的
a = torch.rand(4, 32, 14, 14)
# a为b要经过扩展后成为的目标

进行扩展时,使用.expand即可

下面分别输入两段代码

代码语言:javascript复制
print(b.expand(4, 32, 14, 14).shape)
print(b.shape)

输出分别为

代码语言:javascript复制
torch.Size([4, 32, 14, 14])
torch.Size([1, 32, 1, 1])

由此可见b本质并没有增加,只是在程序需要时才扩展了维度。另外需要注意的是他们对应的dim均需相同,且只能是1=>N的扩展,不能是其他数字=>N的扩展。

另外

代码语言:javascript复制
print(b.expand(-1, 32, -1, 14).shape)
# 括号内为-1时,表明保持原信息不动

输出

代码语言:javascript复制
torch.Size([1, 32, 1, 14])

而repeat函数作用的方式则与expand不同,

代码语言:javascript复制
b = torch.rand([1, 32, 1, 1])
c = b.repeat(4, 32, 1, 1)
# 这里的每一个维度上的数值代表了原数值需要复制的次数
print(c.shape)

输出

代码语言:javascript复制
torch.Size([4, 1024, 1, 1])

做了[4*1, 32*32, 1*1, 1*1]的运算

类似的

代码语言:javascript复制
b = torch.rand([1, 32, 1, 1])
c = b.repeat(4, 16, 32, 1)
print(c.shape)

输出

代码语言:javascript复制
torch.Size([4, 512, 32, 1])

这里不推荐使用repeat操作,因为使用时会使内存里的数据急剧增加。

0 人点赞