接下来介绍维度挤压
squeeze的用法与unsqueeze类似,同样需要给出要操作的维度参数,但若不给出维度的话,会把所有能删减的维度都去掉。
先设定b为[1, 32, 1, 1]
代码语言:javascript复制import torch
b = torch.rand([1, 32, 1, 1])
进行如下维度变换操作
代码语言:javascript复制print(b.shape)
print(b.squeeze().shape)
print(b.squeeze(0).shape)
print(b.squeeze(-1).shape)
依次输出为
代码语言:javascript复制torch.Size([1, 32, 1, 1])
torch.Size([32])
torch.Size([32, 1, 1])
torch.Size([1, 32, 1])
下面介绍维度扩展
Expand: broadcasting
Repeat: memory copied
两者均可实现维度扩展。但区别在于第一种只是改变了理解方式,并没有增加数据。而第二种是实实在在的增加数据(拷贝数据的方式)。第一种方式只是在有需要的时候才会扩展,因此节约了内存,故推荐第一种。
举例
先设定需要扩展的
代码语言:javascript复制b = torch.rand([1, 32, 1, 1])
# b为需要扩展的
a = torch.rand(4, 32, 14, 14)
# a为b要经过扩展后成为的目标
进行扩展时,使用.expand即可
下面分别输入两段代码
代码语言:javascript复制print(b.expand(4, 32, 14, 14).shape)
print(b.shape)
输出分别为
代码语言:javascript复制torch.Size([4, 32, 14, 14])
torch.Size([1, 32, 1, 1])
由此可见b本质并没有增加,只是在程序需要时才扩展了维度。另外需要注意的是他们对应的dim均需相同,且只能是1=>N的扩展,不能是其他数字=>N的扩展。
另外
代码语言:javascript复制print(b.expand(-1, 32, -1, 14).shape)
# 括号内为-1时,表明保持原信息不动
输出
代码语言:javascript复制torch.Size([1, 32, 1, 14])
而repeat函数作用的方式则与expand不同,
代码语言:javascript复制b = torch.rand([1, 32, 1, 1])
c = b.repeat(4, 32, 1, 1)
# 这里的每一个维度上的数值代表了原数值需要复制的次数
print(c.shape)
输出
代码语言:javascript复制torch.Size([4, 1024, 1, 1])
做了[4*1, 32*32, 1*1, 1*1]的运算
类似的
代码语言:javascript复制b = torch.rand([1, 32, 1, 1])
c = b.repeat(4, 16, 32, 1)
print(c.shape)
输出
代码语言:javascript复制torch.Size([4, 512, 32, 1])
这里不推荐使用repeat操作,因为使用时会使内存里的数据急剧增加。