前文(【科普】联邦知识蒸馏概述与思考)提到知识蒸馏是一种模型压缩方法,通过利用复杂模型(Teacher Model)强大的表征学习能力帮助简单模型(Student Model)进行训练,主要分为两个步骤:
1)提取复杂模型的知识,在这里知识的定义有很多种,可以是预测的logits、模型中间层的输出feature map、也可以是模型中间层的attention map,主要就是反映了教师模型的学习能力,是一种表征的体现;
2)将知识迁移/蒸馏到学生模型中去,迁移的方式也有很多种,主要是各种loss function的实现,有L1 loss、L2 loss以及KL loss等手段。
知识蒸馏可以在保证模型的性能前提下,大幅度的降低模型训练过程中的通信开销和参数数量,知识蒸馏的目的是通过将知识从深度网络转移到一个小网络来压缩和改进模型。
这很适用于联邦学习,因为联邦学习是基于服务器-客户端的架构,需要确保及时性和低通信,因此最近也提出很多联邦知识蒸馏的相关论文与算法的研究,接下来我们基于算法解析联邦蒸馏学习。
▊ FL-FD 数据增强的联邦蒸馏算法【1】
在联邦学习(Federated Learning: FL)中,在每个设备端执行训练过程需要与模型大小成比例的通信开销,从而禁止使用大型模型,因此,作者寻求在非IID私有数据下可以实现通信高效的设备上ML方法。
作者提出联邦蒸馏(FD)算法,这是一种分布式在线知识蒸馏方法,其通信有效成本的大小不取决于模型大小,而取决于输出尺寸。在进行联邦蒸馏之前,我们通过联邦增强(FAug)来纠正非IID训练数据集。
这是一种使用生成对抗网络(GAN)进行的数据增强方案,该数据增强方案在隐私泄露和通信开销之间可以进行权衡取舍。经过训练的GAN可以使每个设备在本地生成所有设备的数据样本,从而使训练数据集成为IID分布。
联邦蒸馏(FD):在FD中,每台设备都将自己视为学生,并将其他所有设备的平均模型输出视为其老师的输出。每个模型输出是一组通过softmax函数归一化后的logit值,此后称为logit向量,其大小由标签数给出。
使用交叉熵来周期性地测量师生的输出差异,交叉熵成为学生的损失调整器,称为蒸馏调整器,从而在培训过程中获得其他设备的知识,具体损失是:KDLoss(Local_Logit,Global_Logit) CELoss(Local_Logit,Local_Lable)。FD中的每个设备都存储着本地每个标签的平均logit向量,并定期将这些本地平均logit向量上载到服务器。
服务器将从所有设备上载的本地平均Logit向量平均化,从而得出每个标签的全局平均Logit向量。所有标签的全局平均logit向量被下载到每个设备。然后,当每台设备进行蒸馏的时候,其教师的输出为与当前训练样本的标签具有相同标签的全局平均logit向量。具体如下图1所示。
图1:联邦蒸馏(FD)示意图【1】
联邦增强(FAvg):因为蒸馏最好在具有相同数据集的效果下进行,由于不同设备之间具有异质性所以在蒸馏前进行数据增强可以提升蒸馏效果。FAug中每个设备都可以识别数据样本中缺少的标签,称为目标标签,并通过无线链路将这些目标标签的少量种子数据样本上载到服务器。
服务器则会通过例如Google视觉数据图像搜索等方法对上传的种子数据样本进行超采样,并使用这些数据来训练一个GAN。
最后,下载经过训练的GAN生成器使每个设备补充目标标签,直到达到IID训练数据集为止。FAug的操作需要确保用户生成的数据的私密性。
实际上,每台设备的数据生成偏差(即目标标签)都可以轻松地显示其隐私敏感信息,为了使这些目标标签对服务器不公开,每个设备还将从目标标签以外的其他标签进行上载(冗余数据样本),由此减少了从每个设备到服务器的隐私泄漏。
事实上,模型的输出精度会随着训练的进行而增加,因此,在局部logit平均过程中,最好采用加权平均值随着局部计算时间的增加而增加,即当模型采用整体损失函数:a * KDLoss(Local_Logit,Global_Logit) CELoss(Local_Logit,Local_Lable) * (1-a),随着迭代次数的增加,a应该逐渐减小(模型的输出精度会随着训练的进行而增加,所以本地模型比重应该增大)。具体伪代码如下图2所示。
图2:FL-FD伪代码【1】
总结一下FL-FD算法的过程:
1)每个设备都把自己当作一个学生,并将所有其他设备的平均模型输出视为其老师的输出;
2)FD中的每个设备存储每个标签的平均logit向量,并定期将这些本地平均logit向量上传到服务器;
3)对于每个标签,对所有设备上传的本地平均logit向量进行平均,从而得到每个标签的全局平均logit向量;
4)所有标签的全局平均logit向量都被下载到每个设备上,进行蒸馏损失计算,其教师的输出被选择为与当前训练样本的标签相同的全局平均logit向量。
▊ DS-FL 基于蒸馏的半监督联邦【2】
以往的FL框架以及FedAvg算法都是通过传输加密后的梯度信息或者参数信息去进行聚合然后广播,最后让客户端更新,但是这大大加大了通信代价因为通信代价正比于上传的参数信息。所以在具有与FL相当的模型性能的同时,如何设计可根据模型大小在通信效率方面进行扩展的FL框架?
本文利用客户之间共享的无标签开放数据来增强模型性能,提出了一种基于蒸馏的半监督算法(DS-FL),该算法在客户端上传本地模型的输出,而不是本地模型的梯度或参数信息,即DS-FL的通信成本仅取决于模型输出的尺寸,而不会根据模型的参数信息等加大而扩展。
客户端首先训练本地数据,接着利用训练所得模型预测开放数据集,得到logit,所有客户端都向服务器上传logit,服务器会平均这些logits得到全局logit并且发放给每个客户端,每个客户端基于开放数据集以及服务器发放的对应logit(label)再次在开放数据集上蒸馏训练本地模型(由于数据的增强效应所以本地模型的性能也得到了增强)。
但是由于每个客户端之间的异质性导致了每个客户端对开放数据集的输出logit也具有异质性,针对此情况提出ERA方法(对于每个输出进行Softmax-T算法,即类似于蒸馏但是设置温度T=0.1,这样进一步锐化了每个客户端输出,保证经过softmax的logit(one-hot向量)某一列值足够大(即保证在服务器端平均的时候,不会出现好几个数值的可能性差不多的情况),从而减少模型输出异质性带来的影响)。
DS-FL算法:基于利用未标记开放数据集的想法,我们提出了一种基于蒸馏的半监督FL(DS-FL)算法,该算法在移动设备之间交换本地模型的输出,而不是典型框架所采用的模型参数交换。在提出的DS-FL中,通信成本仅依赖于模型的输出尺寸,而不根据模型大小进行扩展。
交换的模型输出用于标记开放数据集的每个示例,从而创建一个额外标记的数据集。利用新标记的数据集对局部模型进行进一步训练,由于数据增强效应,增强了模型的性能。具体如下图3所示。
图3:DS-FL算法框架图【2】
ERA算法:减少全局对数熵的动机是为了加速和稳定DS-FL,特别是在非IID数据分布中;因此,在简单的聚合方法中,很难使用如此不适当的高熵logit进行训练,因此,训练的成功需要减少熵。在所提出的DS-FL中,设备数据集的异质性导致了每个数据样本的模糊性,降低了训练的收敛性。
为了防止这种情况发生,作者提出了熵减少平均,其中聚合的模型输出被有意地锐化。ERA算法主要有以下两个优点:
1)锐化标签来加快收敛速度:针对联邦蒸馏中的平均标签聚合而言,ERA通过锐化每个logits,从而加快收敛速度;
2)抵御有害客户端的攻击:减少全局对数熵的另一个有利结果是增强了对破坏本地对数和通知开放数据的各种攻击的鲁棒性;在这些被攻击情况下,与非IID数据分布类似,简单聚合方法平均局部对数产生的全局对数熵较高,导致模型性能较差;降低全局logit的熵有望增强对客户端攻击的鲁棒性。具体如下图4所示。
图4:ERA算法【2】
▊ FedGEN 联邦无数据蒸馏【3】
最近出现了利用知识蒸馏来解决联邦学习中的用户异构性问题的想法,具体是通过使用来自异构用户的聚合知识来优化全局模型,而不是直接聚合用户的模型参数。然而这种方法依赖于代理数据集(proxy dataset),如果没有这proxy dataset,该方法便是不切实际的。此外,集成知识没有被充分利用来指导局部模型的训练,这可能反过来影响聚合模型的性能。
基于上述挑战,这篇文章提出了一种data-free知识蒸馏法来解决FL中的异构性问题,该方法称为FeDGen。其中服务器学习一个轻量级生成器,以data-free的方式集成用户信息,然后广播给用户,使用学习到的知识作为"归纳偏置"来调节局部训练。("归纳偏置"就是基于先验知识对目标模型的判断,将无限可能的目标函数约束在一个有限的假设类别之中)。
归纳偏置:机器学习试图去建造一个可以学习的算法,用来预测某个目标的结果。要达到此目的,要给于学习算法一些训练样本,样本说明输入与输出之间的预期关系。然后假设学习器在预测中逼近正确的结果,其中包括在训练中未出现的样本。既然未知状况可以是任意的结果,若没有其它额外的假设,这任务就无法解决。这种关于目标函数的必要假设就称为归纳偏置。
FedGen学习一个仅从用户模型的预测规则导出的生成模型(在给定目标标签的情况下,该模型可以产生与用户预测的集合一致的特征表示)。该生成器随后被广播给用户,用户从潜在空间(生成器产生的分布空间)采样得到的增广样本帮助模型训练(该潜在空间体现从其他对等用户提取的知识)。给定一个比输入空间小得多的潜在空间,FeDGen所学习的生成器可以是轻量级的,给当前的FL框架带来最小的开销。
FedGEN方法:FedGEN通过聚合所有客户端模型的知识(标签信息)用来得到一个生成器模型,生成器可以根据标签Y生成特征Z,服务器将生成器广播给所有客户端,客户端通过生成器生成增广样本用来帮助本地模型训练(增广样本具有归纳偏置信息),通过生成器可以提炼出全局分布数据的知识给客户端,从而实现无信息的知识蒸馏。具体来说,就是学习全体客户端的知识,然后通过生成器将这种知识发放给每个客户端。
图5:FedGEN结果展示【3】
上图图5是通过FedGEN方法得到的,G_w基于用户模型的预测规则进行学习,目的是融合来自用户模型的聚合信息来估计全局数据分布R(x|y),接着用户从G_w(x|y)进行采样,采样结果作为自身的归纳偏置,进而调整决策边界,如下图所示,在知识蒸馏KD之后,一个用户的准确率从81.2%提高到了98.4%。
FedGEN和先前研究的主要区别在于:知识被蒸馏至用户模型,而不是全局模型。因此,蒸馏出来的知识(向用户传递的归纳偏置)可以通过在潜在空间Z上进行分布匹配,直接调节用户的学习。
▊ 参考文献
【1】Jeong, E., Oh, S., Kim, H., Park, J., Bennis, M., & Kim, S. L. (2018). Communication-efficient on-device machine learning: Federated distillation and augmentation under non-iid private data. arXiv preprint arXiv:1811.11479.
【2】Itahara, S., Nishio, T., Koda, Y., Morikura, M., & Yamamoto, K. (2020). Distillation-based semi-supervised federated learning for communication-efficient collaborative training with non-iid private data. arXiv preprint arXiv:2008.06180.
【3】Zhu, Z., Hong, J., & Zhou, J. (2021, July). Data-free knowledge distillation for heterogeneous federated learning. In International Conference on Machine Learning (pp. 12878-12889). PMLR.