模型蒸馏升级!高温蒸馏:Softmax With Temperature

2022-12-06 15:40:41 浏览数 (1)

单位 | 上海交通大学博士生

转自| paperweekly

问题来源

最近读到一篇模型蒸馏的文章 [1],其中在设计软标签的损失函数时使用了一种特殊的 softmax:

文章中只是简单的提了一下,其中 T 是 softmax 函数的温度超参数,而没有做过多解释。这说明这种用法并非其首创,应该是流传已久。经过一番调研和学习,发现知乎上最高赞的文章《深度学习中的 temperature parameter 是什么》[13] 对超参数 T 的讲解具有很强的误导性,所以在此重新写一篇文章为其正名。

本文的标题有两个双关。一个是知识蒸馏的方法用于深度学习,同时也需要深入学习;另一个则是本文的核心:蒸馏中如何合理运用温度,让隐藏的知识更好地挥发和凝结。下面我将详细讲解以上 softmax 公式中温度系数的由来以及它起到的作用。

蒸馏模型

模型蒸馏或知识蒸馏,最早在 2006 年由 Buciluǎ 在文章 Model Compression [14] 中提出(很多博主把人名都写错了。其后,Hinton 进行了归纳和发展,并在 2015 年发表了经典之作 Distilling the Knowledge in a Neural Network [15]。正是在这篇文章 [2] 中,Hinton 首次提出了 Softmax with Temperature 的方法。

先简要概括一下模型蒸馏在做什么。出于计算资源的限制或效率的要求,深度学习模型在部署推断时往往需要进行压缩,模型蒸馏是其中一种常见方法。将原始数据集上训练的重量级(cumbersome)模型作为教师,让一个相对更轻量的模型作为学生。

对于相同的输入,让学生输出的概率分布尽可能的逼近教师输出的分布,则大模型的知识就通过这种监督训练的方式「蒸馏」到了小模型里。小模型的准确率往往下降很小,却能大幅度减少参数量,从而降低推断时对 CPU、内存、能耗等资源的需求。

我们知道模型在训练收敛后,往往通过 softmax 的输出不会是完全符合 one-hot 向量那种极端分布的,而是在各个类别上均有概率,推断时通过 argmax 取得概率最大的类别。Hinton 的文章就指出,教师模型中在这些负类别(非正确类别)上输出的概率分布包含了一定的隐藏信息。比如 MNIST 手写数字识别,标签为 7 的样本在输出时,类别 7 的概率虽然最大,但和类别 1 的概率更加接近,这就说明 1 和 7 很像,这是模型已经学到的隐藏的知识。

我们在使用 softmax 的时候往往会将一个差别不大的输出变成很极端的分布,用一个三分类模型的输出举例:

可以看到原本的分布很接近均匀分布,但经过 softmax,不同类别的概率相差很大。这就导致类别间的隐藏的相关性信息不再那么明显,有谁知道 0.09 和 0.24 对应的类别很像呢?为了解决这个问题,我们就引入了温度系数。

温度系数

我们看看对于随机生成的相同的模型输出,经过不同的函数处理,分布会如何变化:

灵感来源:https://www.youtube.com/watch?v=tOItokBZSfU

反对意见

在最高赞的那篇文章中提到:

图源:https://nni.readthedocs.io/en/stable/sharings/kd_example.html

交叉熵的梯度

softmax 的梯度

当 时

当 时

代入链式法则,最终的梯度为(推导参考了 [6][7])

随着训练的进行,我们将 t 变小,也可以称作降温,类似于模拟退火算法,这也是为什么要把 t 称作温度参数的原因。变小模型才能收敛。

可以这样理解,温度系数较大时,模型需要训练得到一个很陡峭的输出,经过 softmax 之后才能获得一个相对陡峭的结果;温度系数较小时,模型输出稍微有点起伏,softmax 就很敏感地把分布变得尖锐,认为模型学到了知识。

所以,使用一个固定的小于 1 的温度系数是合理的,这也是那篇文章里提到的推荐系统所做的,它没有降温过程,直接设置了 T=0.05 。如果大家在哪篇文章中看到了降温过程,还请在评论区指正。

其他场景

这里我们天马行空地设想一个场景:在一些序列生成任务中,比如 seq2seq 的机器翻译模型,或者是验证码识别的 CTC 算法 [9] 中,输出的每一个时间步都会有一个分布。最终的序列会使用 BeamSearch [10] 或者 Viterbi [11] 等算法搜索 Top-K 概率的序列。

这类方法介于逐时间步 argmax 的完全贪心策略和全局动态规划的优化策略之间。虽然 BeamSearch 中我们不需要提前 softmax,但假如我们做了带温度系数的 softmax,就可以控制输出分布的尖锐程度。对于这类逐步计算累积概率的算法,在每个时间步的概率分布较为均匀时就容易输出不同的结果。所以在这类问题下,高温可能导致输出序列的多样性。

对于这类场景,我没有进行严格证明也没有很深的经验,只是一个猜想。这里有类似的说法 [12],但都不能作为参考依据。大家感兴趣的话可以将 softmax with temperature 引入 BeamSearch 看看会不会对输出的丰富性造成影响。假如算法只依赖每个时间步的概率大小关系,那输出就是确定的,说明我们猜想失败。或者有相关经验的同学也可以在评论区给出参考文献。

后话

写完这篇文章才发现,潘小小【经典简读】知识蒸馏(Knowledge Distillation)经典之作 [17] 一文中已有类似的探讨。尽管如此,我相信这篇文章还是可以起到一定的科普作用,让那些和我一样对知识蒸馏不太了解的同学,从温度系数这个关键词入手,能够快速得到想要的答案。

读完 Hinton 的文章,有两个强烈的感受:一是感觉他太牛了,3 句话让我读了 18 遍,全文很少用公式,基本没有配图,但把算法讲得清清楚楚;二就是,他的写作中长从句实在太多了,一句话 60 个单词,读起来很不友好。如果对这篇文章感兴趣,也可以看上面潘小小的那篇解读。文章最后讲到了一种和 MOE 很像的分布式集成学习方法,在潘的文章中没有介绍,由于这不是今天的主题,所以我也没用笔墨,大家如果对这部分感兴趣也可以来找我讨论。

说出来很难相信,我其实不是做 AI 方向的,我是做系统的,所以欢迎大家怼我(°ー°〃)。

参考文献

[1] Group knowledge transfer: Federated learning of large cnns at the edgehttps://proceedings.neurips.cc/paper/2020/file/a1d4c20b182ad7137ab3606f0e3fc8a4-Paper.pdf

[2]Distilling the Knowledge in a Neural Network https://arxiv.org/abs/1503.02531

[3] PR-009: Distilling the Knowledge in a Neural Network (Slide: English, Speaking: Korean) https://www.youtube.com/watch?v=tOItokBZSfU

[4] What is the role of temperature in Softmax?https://stats.stackexchange.com/questions/527080/what-is-the-role-of-temperature-in-softmax#answer-527082

[5] Knowledge Distillation on NNIhttps://nni.readthedocs.io/en/stable/sharings/kd_example.html

[6] softmax, CrossEntropyLoss 与梯度计算公式https://blog.csdn.net/jiongjiongai/article/details/88324000

[7] 关于Softmax的数值稳定性和梯度反向传播https://zhuanlan.zhihu.com/p/92714192

[8] What is the temperature parameter in deep learning?https://www.quora.com/What-is-the-temperature-parameter-in-deep-learning

[9] 详解CTChttps://zhuanlan.zhihu.com/p/42719047

[10] 文本生成解码之 Beam Searchhttps://zhuanlan.zhihu.com/p/43703136

[11] 如何通俗地讲解 viterbi 算法?https://www.zhihu.com/question/20136144/answer/763021768

[12]What is Temperature in LSTM? https://www.quora.com/What-is-Temperature-in-LSTM

[13] https://zhuanlan.zhihu.com/p/132785733

[14] https://dl.acm.org/doi/abs/10.1145/1150402.1150464

[15] https://arxiv.org/abs/1503.02531

[16] https://nni.readthedocs.io/en/stable/sharings/kd_example.html

[17] https://zhuanlan.zhihu.com/p/102038521

0 人点赞