本篇文章将要总结下Pytorch常用的一些张量操作,并说明其作用,接着使用这些操作实现归一化操作的算法,如BN,GN,LN,IN等!
1
Pytorch中常用张量操作
torch.cat
对数据沿着某一维度进行拼接,cat后的总维度数不变,需要注意两个张量进行cat时某一维的维数要相同,否则会报错!
代码语言:javascript复制import torch
x = torch.randn(,)
y = torch.randn(,)
torch.cat((x, y), ) # 维度为(3, 3)
z = torch.randn(, )
torch.cat((x, z), ) # 报错
stack
相比于Cat,Stack则会增加新的维度,并且将两个矩阵在新的维度上进行堆叠,一般要求两个矩阵的维度是相同的!
代码语言:javascript复制import torch
x = torch.randn(,)
y = torch.randn(,)
torch.stack((x, y), ) # 在0维度进行堆叠,维度为(2, 1, 2)
torch.stack((x, y), ) # 维度为(1, 2, 2)
transpose
其作用为交换两个维度,类似于二维矩阵的转置作用!
代码语言:javascript复制import torch
x = torch.randn(,)
x.transpose(, ) # 维度为(3, 2)
permute
其相当于增强版的transpose,适合于多维数据,更加灵活一点!
代码语言:javascript复制import torch
x = torch.randn(,,,)
x_p = x.permute(,,,) # 维度变为(2,1,3,4)
squeeze和unsqueeze
squeeze(dim)为压缩的意思,即去掉维度数为1的dim,默认是去掉所有为1的,当然也可以自己指定,但如果指定的维度数不为1,则不会发生任何改变。unsqueeze(dim)则与squeeze(dim)正好相反,为添加一个维度的作用。
代码语言:javascript复制import torch
x = torch.randn(,)
x.squeeze() # 维度(2,)
x.squeeze() # 维度(2,)
x.unsqueeze() # 维度(2,1,1)
x.unsqueeze() # 维度(1,2,1)
view、contigous和reshape
有些tensor并不是占用一整块内存,而是由不同的数据块组成,而tensor的view()操作依赖于内存是整块的,这时只需要执行contiguous()这个函数,把tensor变成在内存中连续分布的形式。特别的在Pytorch0.4中,在使用了permute和transpose后,内存就不连续了,因此不能直接使用view函数,应该先contigous()变成连续内存后,再使用view。 Pytorch0.4中,增加了一个reshape函数,就相当于contigous().view()的功能了!
2
归一化操作的实现
我们今天只来考虑如何实现,至于归一化的原理我们就不再赘述,知乎和博客都写的很多了,对于这几种归一化的方法,比如BN(Batch),LN(Layer),IN(Instance),GN(Group)这四种,在GN的论文中有一幅图可以清晰的描述,我们不用看公式,只要把下面这个图记住就好了!(蓝色区域即为其归一化的区域,说白了我们每个归一化时使用的均值和方差就是由蓝色区域计算得来的,然后作用到这个蓝色区域进行归一化,从而对整体X进行归一化)。
那么我们可以看下简单实现(仅归一化)
Batch Normalization
代码语言:javascript复制import torch
from torch import nn
bn = nn.BatchNorm2d(num_features=, eps=, affine=False, track_running_stats=False)
x = torch.rand(, , , )*
official_bn = bn(x) # 官方代码
x1 = x.permute(, , , ).reshape(, -1) # 对(N, H, W)计算均值方差
mean = x1.mean(dim=).reshape(, , , )
# x1.mean(dim=1)后维度为(3,)
std = x1.std(dim=, unbiased=False).reshape(, , , )
my_bn = (x - mean)/std
print((official_bn-my_bn).sum()) # 输出误差
Layer Normalization
代码语言:javascript复制import torch
from torch import nn
ln = nn.LayerNorm(normalized_shape=[, , ], eps=, elementwise_affine=False)
x = torch.rand(, , , )*
official_ln = ln(x) # 官方代码
x1 = x.reshape(, -1) # 对(C,H,W)计算均值方差
mean = x1.mean(dim=).reshape(, , , )
std = x1.std(dim=, unbiased=False).reshape(, , , )
my_ln = (x - mean)/std
print((official_ln-my_ln).sum())
Instance Normalization
代码语言:javascript复制import torch
from torch import nn
In = nn.InstanceNorm2d(num_features=, eps=, affine=False, track_running_stats=False)
x = torch.rand(, , , )*
official_In = In(x) # 官方代码
x1 = x.reshape(, -1) # 对(H,W)计算均值方差
mean = x1.mean(dim=).reshape(, , , )
std = x1.std(dim=, unbiased=False).reshape(, , , )
my_In = (x - mean)/std
print((official_In-my_In).sum())
Group Normalization
代码语言:javascript复制import torch
from torch import nn
gn = nn.GroupNorm(num_groups=, num_channels=, eps=, affine=False)
# 分成了4组,也就是说蓝色区域为(5,5, 5)
x = torch.rand(, , , )*
official_gn = gn(x) # 官方代码
x1 = x.reshape(,,-1) # 对(H,W)计算均值方差
mean = x1.mean(dim=).reshape(, , -1)
std = x1.std(dim=, unbiased=False).reshape(, , -1)
my_gn = ((x1 - mean)/std).reshape(, , , )
print((official_gn-my_gn).sum())
以上代码参考并修改自知乎专栏文章(https://zhuanlan.zhihu.com/p/69659844)
完