关于"知识蒸馏",你想知道的都在这里!

2021-12-20 16:07:40 浏览数 (1)

作者:炼丹小生

简介

"蒸馏",一个化学用语,在不同的沸点下提取出不同的成分。知识蒸馏就是指一个很大很复杂的模型,有着非常好的效果和泛化能力,这是缺乏表达能力的小模型所不能拥有的。因此从大模型学到的知识用于指导小模型,使得小模型具有大模型的泛化能力,并且参数量显著降低,压缩了模型提升了性能,这就是知识蒸馏。<Distilling the Knowledge in a Neural Network>这篇论文首次提出了知识蒸馏的概念,核心思想就是训练一个复杂模型,把这个复杂模型的输出和有label的数据一并喂给了小网络,所以知识蒸馏一定会有个复杂的大模型(teacher model)和一个小模型(student model)。

为什么要蒸馏?

模型越来越深,网络越来越大,参数越来越多,效果越来越好,但是计算复杂度呢?一并上升,蒸馏就是个特别好的方法,用于压缩模型的大小。

  • 提升模型准确率:如果你不满意现有小模型的效果,可以训练一个高度复杂效果极佳的大模型(teacher model),然后用它指导小模型达到你满意的效果。
  • 降低模型延迟,压缩网络参数:网络延迟大?像是bert这种大模型,是否可以用一个一层,减少head数的简单网络去学习bert呢,这样不仅提升了简单网络的准确率,也实现了延迟的降低。
  • 迁移学习:比方说一个老师知道分辨猫狗,另一个老师知道分辨香蕉苹果,那学生从这两个老师学习就能同时分辨猫狗和香蕉苹果。

顺便回顾下之前探讨过的模型压缩5种方法:

  • Model pruning
  • Quantification
  • Knowledge distillation
  • Parameter sharing
  • Parameter matrix approximation

理想情况下,我们是希望同样一份训练数据,无论是大模型还是小模型,他们收敛的空间重合度很高,但实际情况由于大模型搜索空间较大,小模型较小,他们收敛的重合度就较低,知识蒸馏能提升他们之间的重合度使得小模型有更好的泛化能力。

知识蒸馏的方法

知识蒸馏最基础的框架:

使用Teacher-Student model,用一个非常大而复杂的老师模型,辅助学生模型训练。老师模型巨大复杂,因此不用于在线,学生模型部署在线上,灵活小巧易于部署。知识蒸馏可以简单的分为两个主要的方向:target-based蒸馏,feature-based蒸馏。

Target distillation-Logits method

上文提到的那篇最经典的论文就是该方法一个很好的例子。这篇论文解决的是一个分类问题,既然是分类问题模型就会有个softmax层,该层输出值直接就是每个类别的概率,在知识蒸馏中,因为我们有个很好的老师模型,一个最直接的方法就是让学生模型去拟合老师模型输出的每个类别的概率,也就是我们常说的"Soft-target"。

Hard-target and Soft-target

模型要能训练,必须定义loss函数,目标就是让预测值更接近真实值,真实值就是Hard-target,loss函数会使得偏差越来越小。在知识蒸馏中,直接学习每个类别的概率(老师模型预估的)就是soft-target。

Hard-target:类似one-hot的label,比如二分类,正例是1,负例是0。

Soft-target:老师模型softmax层输出的概率分布,概率最大的就是正类别。

知识蒸馏使得老师模型的soft-target去指导用hard-target学习的学生模型,为什么是有效的呢?因为老师模型输出的softmax层携带的信息要远多于hard-target,老师模型给学生模型不仅提供了正例的信息,也提供了负例的概率,所以学生模型可以学到更多hard-target学不到的东西。

知识蒸馏具体方法:

神经网络用softmax层去计算各类的概率:

但是直接使用softmax的输出作为soft-target会有其他问题,当softmax输出的概率分布的熵相对较小时,负类别的label就接近0,对loss函数的共享就非常小,小到可以忽略。所以可以新增个变量"temperature",用下式去计算softmax函数:

当T是1,就是以前的softmax模型,当T非常大,那输出的概率会变的非常平滑,会有很大的熵,模型就会更加关注负类别。

具体蒸馏流程如下:

1.训练老师模型;

2.使用个较高的温度去构建Soft-target;

3.同时使用较高温度的Soft-target和T=1的Soft-target去训练学生模型;

4.把T改为1在学生模型上做预估。

老师模型的训练过程非常简单。学生模型的目标函数可以同时使用两个loss,一个是蒸馏loss,另一个是本身的loss,用权重控制,如下式所示:

老师和学生使用相同的温度T,vi适合zi指softmax输出的logits。L_hard用的就是温度1。

L_hard的重要性不言而喻,老师也可能会教错!使用L_hard能避免老师的错误传递给学生。L_soft和L_hard之前的权重也比较重要,实验表明L_hard权重较小往往带来更好的效果,因为L_soft的梯度贡献大约是1/T^2,所以L_soft最好乘上一个T^2去确保两个loss的梯度贡献等同。

一种特殊形式的蒸馏方式:Direct Matching Logits

直接使用softmax层产出的logits作为soft-target,目标函数直接使用均方误差,如下所示:

和传统蒸馏方法相比,T趋向于无穷大时,直接拟合logits和拟合概率是等同的(证明略),所以这是一种特殊形式的蒸馏方式。

关于温度:

一个较高的温度,往往能蒸馏出更多知识,但是怎么去调节温度呢?

  • 最原始的softmax函数就是T=1,当T < 1,概率分布更"陡",当T->0,输出值就变成了Hard-target,当T > 1,概率分布就会更平滑。
  • 当T变大,概率分布熵会变大,当T趋于无穷,softamx结果就均匀分布了。
  • 不管T是多少,Soft-target会携带更多具有倾向性的信息。

T的变化程度决定了学生模型有多少attention在负类别上,当温度很低,模型就不太关注负类别,特别是那些小于均值的负类别,当温度很高,模型就更多的关注负类别。事实上负类别携带更多信息,特别是大于均值的负类别。因此选对温度很重要,需要更多实验去选择。T的选择和学生模型的大小关系也很大,当学生模型相对较小,一个较小的T就足够了,因为学生模型没有能力学习老师模型全部的知识,一些负类别信息就可以忽略。

除此以外,还有很多特别的蒸馏思想,如intermediate based蒸馏,如下图所示,蒸馏的不仅仅是softmax层,连中间层一并蒸馏。

0 人点赞