标签平滑 Label Smoothing 详解及 pytorch tensorflow实现

2022-01-27 16:24:55 浏览数 (2)

定义

标签平滑(Label smoothing),像L1、L2和dropout一样,是机器学习领域的一种正则化方法,通常用于分类问题,目的是防止模型在训练时过于自信地预测标签,改善泛化能力差的问题。

背景

对于分类问题,我们通常认为训练数据中标签向量的目标类别概率应为1,非目标类别概率应为0。传统的one-hot编码的标签向量y_i为,

y_i=begin{cases}1,quad i=target\ 0,quad ine target end{cases}\

在训练网络时,最小化损失函数H(y,p)=-sumlimits_i^K{y_ilogp_i},其中p_i由对模型倒数第二层输出的logits向量z应用Softmax函数计算得到,

p_i=dfrac{exp(z_i)}{sum_j^Kexp(z_j)}

传统one-hot编码标签的网络学习过程中,鼓励模型预测为目标类别的概率趋近1,非目标类别的概率趋近0,即最终预测的logits向量(logits向量经过softmax后输出的就是预测的所有类别的概率分布)中目标类别z_i的值会趋于无穷大,使得模型向预测正确与错误标签的logit差值无限增大的方向学习,而过大的logit差值会使模型缺乏适应性,对它的预测过于自信。

在训练数据不足以覆盖所有情况下,这就会导致网络过拟合,泛化能力差,而且实际上有些标注数据不一定准确,这时候使用交叉熵损失函数作为目标函数也不一定是最优的了。

数学定义

label smoothing结合了均匀分布,用更新的标签向量hat y^i来替换传统的ont-hot编码的标签向量y_{hat}

hat y_{i}=y_{hot}(1-alpha) alpha/K \

其中K为多分类的类别总个数,αα是一个较小的超参数(一般取0.1),即

hat y_i=begin{cases}1-alpha,quad i=target\ alpha/K,quad ine target end{cases}\

这样,标签平滑后的分布就相当于往真实分布中加入了噪声,避免模型对于正确标签过于自信,使得预测正负样本的输出值差别不那么大,从而避免过拟合,提高模型的泛化能力。

效果

NIPS 2019上的这篇论文When Does Label Smoothing Help?用实验说明了为什么Label smoothing可以work,指出标签平滑可以让分类之间的cluster更加紧凑,增加类间距离,减少类内距离,提高泛化性,同时还能提高Model Calibration(模型对于预测值的confidences和accuracies之间aligned的程度)。但是在模型蒸馏中使用Label smoothing会导致性能下降。

从标签平滑的定义我们可以看出,它鼓励神经网络选择正确的类,并且正确类和其余错误的类的差别是一致的。与之不同的是,如果我们使用硬目标,则会允许不同的错误类之间有很大不同。基于此论文作者提出了一个结论:标签平滑鼓励倒数第二层激活函数之后的结果靠近正确的类的模板,并且同样的远离错误类的模板。

作者设计了一个可视化的方案来证明这件事情,具体方案为:(1)挑选3个类;(2)选取通过这三个类的模板的标准正交基的平面;(3)将倒数第二层激活函数之后的结果映射到该平面。作者做了4组实验,第一组实验为在CIFAR-10/AlexNet(数据集/模型)上面“飞机”、“汽车”和“鸟”三类的结果,可视化结果如下所示:

从中我们可以看出,加了标签平滑之后(后两张图),每个类聚的更紧了,而且和其余类的距离大致一致。第二组实验为在CIFAR-100/ResNet-56(数据集/模型)上的实验结果,三个类分别为“河狸”、“海豚”与“水獭”,我们可以得到类似的结果:

在第三组实验中,作者测试了在ImageNet/Inception-v4(数据集/模型)上的表现,三个类分别为“猫鼬”、“鲤鱼”和“切刀肉”,结果如下:

因为ImageNet有很多细粒度的分类,可以用来测试比较相似的类之间的关系。作者在第四组实验中选择的三个类分别为“玩具贵宾犬”、“ 迷你贵宾犬”和“鲤鱼”,可以看出前两个类是很相似的,最后一个差别比较大的类在图中用蓝色表示,结果如下:

可以看出在使用硬目标的情况下,两个相似的类彼此比较靠近。但是标签平滑强制要求每个示例与所有剩余类的模板之间的距离相等,这就导致了后两张图中两个类距离较远,这在一定程度上造成了信息的损失。

代码实现

pytorch部分代码

代码语言:python代码运行次数:0复制
class LabelSmoothing(nn.Module):
    def __init__(self, size, smoothing=0.0):
        super(LabelSmoothing, self).__init__()
        self.criterion = nn.KLDivLoss(size_average=False)
        #self.padding_idx = padding_idx
        self.confidence = 1.0 - smoothing#if i=y的公式
        self.smoothing = smoothing
        self.size = size
        self.true_dist = None
    
    def forward(self, x, target):
        """
        x表示输入 (N,M)N个样本,M表示总类数,每一个类的概率log P
        target表示label(M,)
        """
        assert x.size(1) == self.size
        true_dist = x.data.clone()#先深复制过来
        #print true_dist
        true_dist.fill_(self.smoothing / (self.size - 1))#otherwise的公式
        #print true_dist
        #变成one-hot编码,1表示按列填充,
        #target.data.unsqueeze(1)表示索引,confidence表示填充的数字
        true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
        self.true_dist = true_dist

        return self.criterion(x, Variable(true_dist, requires_grad=False))
        
loss_function = LabelSmoothing(num_labels, 0.1)

tensorflow代码实现

代码语言:txt复制
def smoothing_cross_entropy(logits,labels,vocab_size,confidence):
  with tf.name_scope("smoothing_cross_entropy", values=[logits, labels]):
    # Low confidence is given to all non-true labels, uniformly.
    low_confidence = (1.0 - confidence) / to_float(vocab_size - 1)

    # Normalizing constant is the best cross-entropy value with soft targets.
    # We subtract it just for readability, makes no difference on learning.
    normalizing = -(
        confidence * tf.log(confidence)   to_float(vocab_size - 1) *
        low_confidence * tf.log(low_confidence   1e-20))

    soft_targets = tf.one_hot(
          tf.cast(labels, tf.int32),
          depth=vocab_size,
          on_value=confidence,
          off_value=low_confidence)
    xentropy = tf.nn.softmax_cross_entropy_with_logits_v2(
        logits=logits, labels=soft_targets)
    return xentropy - normalizing

Ref

  1. https://www.jiqizhixin.com/articles/2019-07-09-7
  2. https://www.cnblogs.com/irvingluo/p/13873699.html
  3. https://proceedings.neurips.cc/paper/2019/file/f1748d6b0fd9d439f71450117eba2725-Paper.pdf

0 人点赞