知识蒸馏

2022-03-24 11:22:57 浏览数 (1)

知识蒸馏概述

蒸馏指的是把不纯净的水加热变成蒸汽,蒸汽冷凝之后变成冷凝水。知识蒸馏也是把有杂质的东西,大的东西变成小的东西,纯的东西。把一个大的模型(教师模型)里面的知识给萃取蒸馏出来浓缩到一个小的学生模型上。

教师网络是一个比较大的神经网络,它把知识传递给了小的学生网络。这个过程称之为蒸馏或者叫迁移。

在现在的人工智能社会,比如说计算机视觉,语音识别,自然语言处理,算法都是很大的。而真正落地的终端的算力都是非常有限的,比如说手机、智能手表、无人驾驶的汽车、物联网设备等等。教师网络可能是用海量的数据,海量的能源训练出来的一个非常臃肿的模型,现在要部署到移动终端设备上,在算力非常受限的场景下,所以需要把大模型变成小模型,把模型部署到终端上。这就是知识蒸馏的核心目的。

现在的预训练大模型的参数往往都是成指数型增加。轻量化网络模型分为四个主要路线

  1. 压缩一个已经训练好的大模型,包括知识蒸馏、权值量化、剪枝(包括权重剪枝、通道剪枝)、注意力迁移。
  2. 直接训练轻量化网络,如SqueezeNet、MobileNet(v1、v2、v3)、Mnasmet、ShuffleNet(v1、v2)、XCeption、EfficientNet、EfficentDet
  3. 加速卷积运算,如im2col GEMM(将卷积操作转成矩阵操作)、Winograd、低秩分解
  4. 硬件部署,TensorRT、JetSon、Tensorflow-slim、Tensorflow-Lite、Openvino、FPGA集成电路

知识的表示与迁移

我们以这幅图为例,假如说我们将左边的这匹马的图片喂给一个图像分类的神经网络模型,它可能会有很多的类别,每一个类别都识别出一个概率。我们在训练网络的时候只是告诉网络这是一匹马,至于它是不是驴,是不是汽车,这些概率一律为0。这里我们称之为hard targets——马(1)、驴(0)、汽车(0)。我们是用hard targets来训练网络的,但是这个并不科学,这样的标签等同于告诉网络这只是一匹马,不是驴也不是车,并且不是驴不是车的程度是相等的。通过我们人眼可以看出来,这匹马和驴子是有一些相似性的,它更像驴子而更不像汽车。

如果我们把这匹马的图片喂到一个已经训练好的神经网络里面,网络可能会给出一个这样的结果——soft targets,马(0.7)、驴(0.25)、汽车(0.05)。表示它是马的概率为0.7,预测为驴的概率为0.25,预测为汽车的概率为0.05。那么这个soft targets就足够的科学,它不仅告诉我们这个马的概率是最大的,那么它七成是一匹马,还有二点五成的概率是一头驴,而汽车的概率只有零点五成。soft targets就传递了更多的信息。所以我们在训练教师网络的时候可以用hard targets去训练,而产生出的soft targets能够传递出更多的信息,那么就可以用soft targets去训练学生网络。

比方说手写数字集中的这个2,第一个2像2但也像3,第二个2像2但也像7。由此可见,Soft Label包含了更多的“知识”和“信息”,像谁不像谁,有多像,有多不像,特别是非正确类别概率的相对大小。而在Hard Label里面非正确类别的概率一律被抹成了0.

在上图中,除了3和7以外,其他的数字的概率都很小,现在我们要把这些其他数字的概率放大,充分暴露出来它们的差别,需要引入一个温度蒸馏T

温度蒸馏T

一般我们做多分类的时候都使用的是softmax这个激活函数,softmax的函数如下

现在我们给该函数增加一个温度蒸馏T

这里T=1时即为softmax,比如我们在做动物图像分类的时候,下图中红线的部分即为原始的softmax的概率。

而蓝线的部分是加入了温度蒸馏T=3的情况,则原来比较hard的部分就变得更加soft,高的概率被压低,而低的概率被提升。但是它的相对大小依然是固定的。T越高,这个targets就会变的越soft。

从上图中我们可以看到当T=1的时候,曲线非常的陡峭,是马的概率是非常明显的,T=100的时候,我们几乎看不出这几种动物的分类有啥区别的,曲线非常的平缓。由此可知,我们可以增加T来使得原本比较hard的targets变得更soft。用更软的targets去训练学生网络,那些非正确类别的信息就暴露的越彻底。

在上图中,在学生网络中,通过神经网络前向运算得出来的logit分别为猫-5,狗2,驴7,马9,当我们使用T=1的softmax进行分类,那么得出来的概率中马的概率非常的高,而使用T=3中,马的概率有所下降,但还是最高,而其他种类的动物有所上升。教师网络也是一样。

知识蒸馏的过程

我们来看一下学生网络、教师网络到底是怎么样来进行蒸馏学习的,首先有一个已经训练好的教师网络,然后我们把很多数据喂给教师网络,这里会给一个温度为t的时候的softmax。我们再把数据喂给学生网络,学生网络可能是还没有开始训练的网络,也给它一个温度t的softmax,然后用温度为t的softmax的教师网络的soft labels与温度为t的学生网络的softmax的预测值做一个损失函数,让它俩越接近越好。就是学生在模拟老师的预测结果了。

学生网络自己经过一个T=1的普通的softmax,和ground truth,也就是hard label再做一个损失函数,希望让它俩更接近。也就是说学生网络即要在温度为t的时候的结果跟教师网络的预测结果更接近,也要兼顾T=1的时候的预测结果和标准答案更接近。也就是总体的损失函数为distillation loss和student loss两项的加权求和。那么对于之前那个动物分类的总体损失函数如下

知识蒸馏的应用场景

  1. 模型压缩
  2. 优化训练,防止过拟合(潜在的正则化)
  3. 无限大、无监督数据集的数据挖掘
  4. 少样本、零样本学
  5. 迁移学和知识蒸馏

上图中的Baseline我们可以理解成完全使用hard targets进行训练的网络,当我们取大量的数据100%的投入到模型中,我们可以看到它训练集的准确率是63.4%,测试集的准确率是58.9%。而当我们使用3%的数据量投入到网络中训练,它的训练集的准确率为67.3%,而测试集的准确率只有44.5%,这说明过拟合了。但是当我们使用教师网络传递过来的soft targets来训练该网络,也只使用3%的数据集,则训练数据集的准确率为65.4%,但是它测试集的准确率却达到了接近100%数据集的hard targets的水平,达到了57.0%的水平。

当我们把海量的图片,没有标注的图片喂给一个已经训练好的教师网络,此时教师网络可以通过正向计算将结果soft targets直接作为标注信息传递给学生网络,此时即便在没有labels的情况下,学生网络依然可以进行训练,达到无监督学的目的。

迁移学指的是把一个领域的模型泛化到另一个领域,比如说用一个识别X光胸片的数据集去训练一个原本识别猫狗的模型,那么这个识别猫狗的模型就慢慢学会了识别X光胸片的各种疾病。知识蒸馏是把一个模型的知识迁移到另一个模型上。

知识蒸馏的原理

上图中大的绿色的矩形为非常大的教师网络,中间的蓝色的矩形是学生网络。如果大家都使用hard targets来训练模型,那么教师网络会慢慢收敛到红色部分的椭圆,而学生网络会慢慢收敛到浅绿色的椭圆。这两个收敛点不同,相隔比较远,并且教师网络的能力要远远高于学生网络。当使用知识蒸馏的时候,教师网络会告诉学生该如何收敛,并且慢慢靠近教师网络的收敛点,使得学生网络收敛于红色椭圆旁边的粉红色椭圆,此时学生网络的收敛点就很接近于教师网络,从而学生网络的能力可以提升到与教师网络相当。

一些开箱即用的工具

MMRazor模型压缩工具箱,包括了剪枝(Pruning)、知识蒸馏(KD)、神经架构搜索(NAS)、量化(Quantization)

https://github.com/open-mmlab/MMRazor

MMDeploy模型转换于部署工具箱,可以把各种算法变成主流模型压缩厂商的中间格式,包括ONNX,Intel的Openvino、英伟达的TensorRT、商汤的OpenPPL、腾讯优图ncnn PPLNN。

https://github.com/open-mmlab/MMDeploy

RepDistiller包含了12个SOTA知识蒸馏算法的Pytorch复现,这也是一些最先进的算法。

https://github.com/HobbitLong/RepDistiller

知识蒸馏温度T可视化

代码语言:javascript复制
import numpy as np
import matplotlib.pyplot as plt

if __name__ == '__main__':

    logits = np.array([-5, 2, 7, 9])
    # softmax(T=1)
    softmax_1 = np.exp(logits) / sum(np.exp(logits))
    print(softmax_1)
    plt.plot(softmax_1, label='T=1')
    T = 3
    softmax_3 = np.exp(logits / T) / sum(np.exp(logits / T))
    plt.plot(softmax_3, label='T=3')
    T = 5
    softmax_5 = np.exp(logits / T) / sum(np.exp(logits / T))
    plt.plot(softmax_5, label='T=5')
    T = 10
    softmax_10 = np.exp(logits / T) / sum(np.exp(logits / T))
    plt.plot(softmax_10, label='T=10')
    T = 100
    softmax_100 = np.exp(logits / T) / sum(np.exp(logits / T))
    plt.plot(softmax_100, label='T=100')
    plt.xticks(np.arange(4), ['Cat', 'Dog', 'Donkey', 'Sorse'])
    plt.legend()
    plt.show()

运行结果

知识蒸馏训练学生教师网络

代码语言:javascript复制
import torch
from torch import nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
from torchinfo import summary
from tqdm import tqdm

if __name__ == '__main__':

    torch.manual_seed(0)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)
    train_dataset = torchvision.datasets.MNIST(
        root="mnist/",
        train=True,
        transform=transforms.ToTensor(),
        download=True
    )
    test_dataset = torchvision.datasets.MNIST(
        root="mnist/",
        train=False,
        transform=transforms.ToTensor(),
        download=True
    )
    train_loader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)
    test_loader = DataLoader(dataset=test_dataset, batch_size=32, shuffle=False)
    # 教师网络训练模型
    class TeacherModel(nn.Module):

        def __init__(self, in_channels=1, num_classes=10):
            super(TeacherModel, self).__init__()
            self.relu = nn.ReLU()
            self.fc1 = nn.Linear(784, 1200)
            self.fc2 = nn.Linear(1200, 1200)
            self.fc3 = nn.Linear(1200, num_classes)
            self.dropout = nn.Dropout(p=0.5)

        def forward(self, x):
            out = x.view(-1, 784)
            out = self.fc1(out)
            out = self.dropout(out)
            out = self.relu(out)
            out = self.fc2(out)
            out = self.dropout(out)
            out = self.relu(out)
            out = self.fc3(out)
            return out

    model = TeacherModel()
    model = model.to(device)
    print(summary(model))

    loss_function = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    epochs = 6
    for epoch in range(epochs):
        model.train()
        for data, targets in tqdm(train_loader):
            data = data.to(device)
            targets = targets.to(device)
            preds = model(data)
            loss = loss_function(preds, targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        model.eval()
        num_correct = 0
        num_samples = 0
        with torch.no_grad():
            for x, y in test_loader:
                x = x.to(device)
                y = y.to(device)
                preds = model(x)
                predictions = preds.max(1).indices
                num_correct  = (predictions == y).sum()
                num_samples  = predictions.size(0)
            acc = (num_correct / num_samples).item()
        model.train()
        print("Epoch:{}t Accuracy:{:.4f}".format(epoch   1, acc))

    teacher_model = model
    # 学生网络训练模型
    class StudentModel(nn.Module):

        def __init__(self, in_channel=1, num_classes=10):
            super(StudentModel, self).__init__()
            self.relu = nn.ReLU()
            self.fc1 = nn.Linear(784, 20)
            self.fc2 = nn.Linear(20, 20)
            self.fc3 = nn.Linear(20, num_classes)

        def forward(self, x):
            out = x.view(-1, 784)
            out = self.fc1(out)
            out = self.relu(out)
            out = self.fc2(out)
            out = self.fc2(out)
            out = self.relu(out)
            out = self.fc3(out)
            return out


    model = StudentModel()
    model = model.to(device)
    print(summary(model))

    loss_function = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    epochs = 3
    for epoch in range(epochs):
        model.train()
        for data, targets in tqdm(train_loader):
            data = data.to(device)
            targets = targets.to(device)
            preds = model(data)
            loss = loss_function(preds, targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        model.eval()
        num_correct = 0
        num_samples = 0
        with torch.no_grad():
            for x, y in test_loader:
                x = x.to(device)
                y = y.to(device)
                preds = model(x)
                predictions = preds.max(1).indices
                num_correct  = (predictions == y).sum()
                num_samples  = predictions.size(0)
            acc = (num_correct / num_samples).item()
        model.train()
        print("Epoch:{}t Accuracy:{:.4f}".format(epoch   1, acc))

    student_model_scratch = model
    # 知识蒸馏训练模型
    teacher_model.eval()
    model = StudentModel()
    model = model.to(device)
    model.train()
    # 温度为7
    temp = 7
    hard_loss = nn.CrossEntropyLoss()
    # hard_loss权重
    alpha = 0.3
    # KL散度
    soft_loss = nn.KLDivLoss(reduction='batchmean')
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    epochs = 3
    for epoch in range(epochs):
        model.train()
        for data, targets in tqdm(train_loader):
            data = data.to(device)
            targets = targets.to(device)
            # 教师模型预测
            with torch.no_grad():
                teacher_preds = teacher_model(data)
            # 学生模型预测
            student_preds = model(data)
            # 计算学生网络的hard_loss
            student_loss = hard_loss(student_preds, targets)
            # 计算蒸馏后的预测结果及soft_loss
            ditillation_loss = soft_loss(F.softmax(student_preds / temp, dim=1),
                                         F.softmax(teacher_preds / temp, dim=1))
            # 将hard_loss和soft_loss加权求和
            loss = alpha * student_loss   (1 - alpha) * ditillation_loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        model.eval()
        num_correct = 0
        num_samples = 0
        with torch.no_grad():
            for x, y in test_loader:
                x = x.to(device)
                y = y.to(device)
                preds = model(x)
                predictions = preds.max(1).indices
                num_correct  = (predictions == y).sum()
                num_samples  = predictions.size(0)
            acc = (num_correct / num_samples).item()
        model.train()
        print("Epoch:{}t Accuracy:{:.4f}".format(epoch   1, acc))

运行结果

代码语言:javascript复制
cpu
=================================================================
Layer (type:depth-idx)                   Param #
=================================================================
TeacherModel                             --
├─ReLU: 1-1                              --
├─Linear: 1-2                            942,000
├─Linear: 1-3                            1,441,200
├─Linear: 1-4                            12,010
├─Dropout: 1-5                           --
=================================================================
Total params: 2,395,210
Trainable params: 2,395,210
Non-trainable params: 0
=================================================================
=================================================================
Layer (type:depth-idx)                   Param #
=================================================================
TeacherModel                             --
├─ReLU: 1-1                              --
├─Linear: 1-2                            942,000
├─Linear: 1-3                            1,441,200
├─Linear: 1-4                            12,010
├─Dropout: 1-5                           --
=================================================================
Total params: 2,395,210
Trainable params: 2,395,210
Non-trainable params: 0
=================================================================
100%|██████████| 1875/1875 [00:40<00:00, 46.59it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]Epoch:1	 Accuracy:0.9419
100%|██████████| 1875/1875 [00:40<00:00, 46.55it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]Epoch:2	 Accuracy:0.9620
100%|██████████| 1875/1875 [00:37<00:00, 50.20it/s]
Epoch:3	 Accuracy:0.9689
100%|██████████| 1875/1875 [00:37<00:00, 50.17it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]Epoch:4	 Accuracy:0.9736
100%|██████████| 1875/1875 [00:37<00:00, 50.18it/s]
Epoch:5	 Accuracy:0.9769
100%|██████████| 1875/1875 [00:37<00:00, 50.16it/s]
Epoch:6	 Accuracy:0.9794
=================================================================
Layer (type:depth-idx)                   Param #
=================================================================
StudentModel                             --
├─ReLU: 1-1                              --
├─Linear: 1-2                            15,700
├─Linear: 1-3                            420
├─Linear: 1-4                            210
=================================================================
Total params: 16,330
Trainable params: 16,330
Non-trainable params: 0
=================================================================
=================================================================
Layer (type:depth-idx)                   Param #
=================================================================
StudentModel                             --
├─ReLU: 1-1                              --
├─Linear: 1-2                            15,700
├─Linear: 1-3                            420
├─Linear: 1-4                            210
=================================================================
Total params: 16,330
Trainable params: 16,330
Non-trainable params: 0
=================================================================
100%|██████████| 1875/1875 [00:05<00:00, 371.57it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]Epoch:1	 Accuracy:0.8526
100%|██████████| 1875/1875 [00:05<00:00, 369.26it/s]
Epoch:2	 Accuracy:0.8853
100%|██████████| 1875/1875 [00:05<00:00, 368.62it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]Epoch:3	 Accuracy:0.8971
100%|██████████| 1875/1875 [00:08<00:00, 209.59it/s]
Epoch:1	 Accuracy:0.8410
100%|██████████| 1875/1875 [00:09<00:00, 208.31it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]Epoch:2	 Accuracy:0.8814
100%|██████████| 1875/1875 [00:08<00:00, 208.54it/s]
Epoch:3	 Accuracy:0.8921

0 人点赞