一、Batch Normalization是什么?
Batch Normalization (BN) 是最早出现的,也通常是效果最好的归一化方式。feature map:
包含 N 个样本,每个样本通道数为 C,高为 H,宽为 W。
对其求均值和方差时,将在 N、H、W上操作,而保留通道 C 的维度。具体来说,就是把第1个样本的第1个通道,加上第2个样本第1个通道 ...... 加上第 N 个样本第1个通道,求平均,得到通道 1 的均值(注意是除以 N×H×W 而不是单纯除以 N,最后得到的是一个代表这个 batch 第1个通道平均值的数字,而不是一个 H×W 的矩阵)。
求通道 1 的方差也是同理。对所有通道都施加一遍这个操作,就得到了所有通道的均值和方差。具体公式为:
如果把
类比为一摞书,这摞书总共有 N 本,每本有 C 页,每页有 H 行,每行 W 个字符。BN 求均值时,相当于把这些书按页码一一对应地加起来(例如第1本书第36页,第2本书第36页......),再除以每个页码下的字符总数:N×H×W,因此可以把 BN 看成求“平均书”的操作(注意这个“平均书”每页只有一个字),求标准差时也是同理。
我们可以在 pytorch 下自己写一个 BN ,看看和官方的版本是否一致,以检验上述理解是否正确:
代码语言:javascript复制print('diff={}'.format(diff)) # 差别是 10-5 级的,证明和官方版本基本一致
二、BN的优势与作用
BN使得网络中每层输入数据的分布相对稳定,加速模型学习速度
BN通过规范化与线性变换使得每一层网络的输入数据的均值与方差都在一定范围内,使得后一层网络不必不断去适应底层网络中输入的变化,从而实现了网络中层与层之间的解耦,允许每一层进行独立学习,有利于提高整个神经网络的学习速度。当学习率设置太高时,会使得参数更新步伐过大,容易出现震荡和不收敛。但是使用BN的网络将不会受到参数数值大小的影响。
BN使得模型对网络中的参数不那么敏感,简化调参过程,使得网络学习更加稳定
在神经网络中,我们经常会谨慎地采用一些权重初始化方法(例如Xavier)或者合适的学习率来保证网络稳定训练。
BN允许网络使用饱和性激活函数(例如sigmoid,tanh等),缓解梯度消失问题
在不使用BN层的时候,由于网络的深度与复杂性,很容易使得底层网络变化累积到上层网络中,导致模型的训练很容易进入到激活函数的梯度饱和区;通过normalize操作可以让激活函数的输入数据落在梯度非饱和区,缓解梯度消失的问题;另外通过自适应学习 与又让数据保留更多的原始信息。 BN具有一定的正则化效果
在Batch Normalization中,由于我们使用mini-batch的均值与方差作为对整体训练样本均值与方差的估计,尽管每一个batch中的数据都是从总体样本中抽样得到,但不同mini-batch的均值与方差会有所不同,这就为网络的学习过程中增加了随机噪音,与Dropout通过关闭神经元给网络训练带来噪音类似,在一定程度上对模型起到了正则化的效果。另外,原作者通过也证明了网络加入BN后,可以丢弃Dropout,模型也同样具有很好的泛化效果。
三、BN的局限
局限1:如果Batch Size太小,则BN效果明显下降。
BN是严重依赖Mini-Batch中的训练实例的,如果Batch Size比较小则任务效果有明显的下降。那么多小算是太小呢?图10给出了在ImageNet数据集下做分类任务时,使用ResNet的时候模型性能随着BatchSize变化时的性能变化情况,可以看出当BatchSize小于8的时候开始对分类效果有明显负面影响。之所以会这样,是因为在小的BatchSize意味着数据样本少,因而得不到有效统计量,也就是说噪音太大。这个很好理解,这就类似于我们国家统计局在做年均收入调查的时候,正好把你和马云放到一个Batch里算平均收入,那么当你为下个月房租发愁之际,突然听到你所在组平均年薪1亿美金时,你是什么心情,那小Mini-Batch里其它训练实例就是啥心情。
BN的Batch Size大小对ImageNet分类任务效果的影响(From GN论文)
BN的Batch Size大小设置是由调参师自己定的,调参师只要把Batch Size大小设置大些就可以避免上述问题。但是有些任务比较特殊,要求batch size必须不能太大,在这种情形下,普通的BN就无能为力了。比如BN无法应用在Online Learning中,因为在线模型是单实例更新模型参数的,难以组织起Mini-Batch结构。
局限2:对于有些像素级图片生成任务来说,BN效果不佳;
对于图片分类等任务,只要能够找出关键特征,就能正确分类,这算是一种粗粒度的任务,在这种情形下通常BN是有积极效果的。但是对于有些输入输出都是图片的像素级别图片生成任务,比如图片风格转换等应用场景,使用BN会带来负面效果,这很可能是因为在Mini-Batch内多张无关的图片之间计算统计量,弱化了单张图片本身特有的一些细节信息。
局限3:RNN等动态网络使用BN效果不佳且使用起来不方便
对于RNN来说,尽管其结构看上去是个静态网络,但在实际运行展开时是个动态网络结构,因为输入的Sequence序列是不定长的,这源自同一个Mini-Batch中的训练实例有长有短。对于类似RNN这种动态网络结构,BN使用起来不方便,因为要应用BN,那么RNN的每个时间步需要维护各自的统计量,而Mini-Batch中的训练实例长短不一,这意味着RNN不同时间步的隐层会看到不同数量的输入数据,而这会给BN的正确使用带来问题。假设Mini-Batch中只有个别特别长的例子,那么对较深时间步深度的RNN网络隐层来说,其统计量不方便统计而且其统计有效性也非常值得怀疑。另外,如果在推理阶段遇到长度特别长的例子,也许根本在训练阶段都无法获得深层网络的统计量。综上,在RNN这种动态网络中使用BN很不方便,而且很多改进版本的BN应用在RNN效果也一般。
局限4:训练时和推理时统计量不一致
对于BN来说,采用Mini-Batch内实例来计算统计量,这在训练时没有问题,但是在模型训练好之后,在线推理的时候会有麻烦。因为在线推理或预测的时候,是单实例的,不存在Mini-Batch,所以就无法获得BN计算所需的均值和方差,一般解决方法是采用训练时刻记录的各个Mini-Batch的统计量的数学期望,以此来推算全局的均值和方差,在线推理时采用这样推导出的统计量。虽说实际使用并没大问题,但是确实存在训练和推理时刻统计量计算方法不一致的问题。
上面所列BN的四大罪状,表面看是四个问题,其实深入思考,都指向了幕后同一个黑手,这个隐藏在暗处的黑手是谁呢?就是BN要求计算统计量的时候必须在同一个Mini-Batch内的实例之间进行统计,因此形成了Batch内实例之间的相互依赖和影响的关系。如何从根本上解决这些问题?一个自然的想法是:把对Batch的依赖去掉,转换统计集合范围。在统计均值方差的时候,不依赖Batch内数据,只用当前处理的单个训练数据来获得均值方差的统计量,这样因为不再依赖Batch内其它训练数据,那么就不存在因为Batch约束导致的问题。在BN后的几乎所有改进模型都是在这个指导思想下进行的。
但是这个指导思路尽管会解决BN带来的问题,又会引发新的问题,新的问题是:我们目前已经没有Batch内实例能够用来求统计量了,此时统计范围必须局限在一个训练实例内,一个训练实例看上去孤零零的无依无靠没有组织,怎么看也无法求统计量,所以核心问题是对于单个训练实例,统计范围怎么算?
四、BN的折叠优化
折叠Batch Normalization,也叫作折叠BN。我们知道一般BN是跟在卷积层后面,一般还会接上激活函数,也就是conv BN relu这种基本组件,但在部署的时候前向推理框架一般都会自动的将BN和它前面的卷积层折叠在一起,实现高效的前向推理网络。
我们知道卷积层的计算可以表示为:
然后BN层的计算可以表示为:
我们把二者组合一下,公式如下:
然后令
那么,合并BN层后的卷积层的权重和偏置可以表示为:
值得一提的是,一般Conv后面接BN的时候很多情况下是不带Bias的,这个时候上面的公式就会少第二项。
由于完整代码太长,完整可以参考这个工程:https://github.com/BBuf/cv_tools/blob/master/merge_bn.py
代码语言:javascript复制def fuse(conv, bn):
global i
i = i 1
# ********************BN参数*********************
mean = bn.running_mean
var_sqrt = torch.sqrt(bn.running_var bn.eps)
gamma = bn.weight
beta = bn.bias
# *******************conv参数********************
w = conv.weight
if conv.bias is not None:
b = conv.bias
else:
b = mean.new_zeros(mean.shape)
if(i >= 2 and i <= 7):
b = b - mean beta * var_sqrt
else:
w = w * (gamma / var_sqrt).reshape([conv.out_channels, 1, 1, 1])
b = (b - mean)/var_sqrt * gamma beta
fused_conv = nn.Conv2d(conv.in_channels,
conv.out_channels,
conv.kernel_size,
conv.stride,
conv.padding,
groups=conv.groups,
bias=True)
fused_conv.weight = nn.Parameter(w)
fused_conv.bias = nn.Parameter(b)
return fused_conv