深度学习基础:6.Batch Normalization简介/作用

2022-07-14 15:04:48 浏览数 (1)

数据标准化

由于Batch Normalization包含数据标准化的操作,因此在了解BN前,首先要对数据标准化有个简单认识。 数据标准化通常包括两种:0-1标准化Z-score标准化,深度学习中的标准化往往指代的是后者。

0-1标准化

0-1标准化的公式如下:

Z-score标准化

Z-score标准化的公式如下:

数据通过Z-Score处理之后将服从标准正态分布。

BN背景

BN是由Sergey loffe和Christian Szegedy在2015年发表的《Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift》中提出的一种方法,该方法通过修改每一次带入训练的数据分布(每一个Batch)的数据分布,来提升模型各层梯度的平稳性,从而提升模型学习效率、提高模型训练结果。由于是修改每一个Batch的数据分布,因此该方法也被称为Batch Normalization(BN),小批量数据归一化方法。

对于BN可以提升模型性能方面,作者表示根本原因是因为BN能够消除内部协方差偏移(ICS),然而2018年,来自MIT的研究团队发表论文,《How Does Batch Normalization Help Optimization?》表示,通过严谨的实验可以证明,BN方法对模型优化的有效性和原论文中所描述的消除ICS没有任何关系。

BN计算公式

BN计算公式如下:

计算过程如下: 1.计算样本均值。 2.计算样本方差。 3.样本数据标准化处理。 4.进行平移和缩放处理。

注:第四步中的

gamma

beta

都是模型参数,在实际模型训练过程中是需要作为模型整体参数的一部分,带入损失函数、进而通过梯度下降计算得出的。

BN作用

目前,大多数的主流模型都使用了BN操作,说明BN非常实用。BN主要有三个作用:

  • 1.加快网络的训练和收敛的速度 这个作用比较容易理解,因为在BN的计算过程中,把数据进行了标准化,转化成了标准正态分布(均值为0,方差为1),因此这样的数据会更容易收敛。
  • 2.控制梯度爆炸和梯度消失 梯度爆炸主要是因为网络加深之后需要对多层进行链式求导,如果输出过大,则会导致梯度爆炸。梯度消失相反,输出太小,导致相乘之后更小,接近0。使用BN层标准化数据后,网络的输出就不会很大,梯度就不会很小,因此能控制梯度爆炸和梯度消失。
  • 3.防止过拟合 BN的使用使得一个batch中的样本都被关联在了一起,因此网络不会从某一个训练样本中生成确定的结果,即同样一个样本的输出不再仅仅取决于样本的本身,也取决于跟这个样本同属一个batch的其他样本。每次网络都是随机取batch,这样就会使得整个网络不会朝这一个方向使劲学习。一定程度上避免了过拟合。

通常来说,BN层放置在线性层和卷积层之后,激活函数之前。

BN编程实践

在Pytorch中,提供了两个BN实现接口:BatchNorm1d用来处理1维数据,BatchNorm2d用来处理2维数据。 BatchNorm1dBatchNorm2d有下列几个参数

代码语言:javascript复制
nn.BatchNorm1d(
    num_features,
    eps=1e-05,
    momentum=0.1,
    affine=True,
    track_running_stats=True,
    device=None,
    dtype=None,
)

核心参数是:

  • num_features:输入数据的特征数量,也就是前一层神经元数量或原始数据集特征数量。
  • eps:方差分母修正项,为了防止分母为0,一般取值为1e-5,也就是默认值。
  • affine:是否进行仿射变换,需要注意的是,此时进行仿射变换时将使用无偏估计进行期望和方差的计算,并且初始条件下
gamma=1,beta=0

,当参数取值为True时,会显式设置

gamma和beta

参数并带入进行梯度下降迭代计算,取值为False时,参数不显示,实际的数据归一化过程就是对原数据进行无偏估计下的Z-Score变换。

具体实例就先不放了,后续在图像处理中会频繁用。

0 人点赞