题目: 更快更好的联邦学习:一种特征融合方法
会议: IEEE ICIP 2019
论文地址:https://ieeexplore.ieee.org/abstract/document/8803001
本文将解读清华大学孙立峰教授团队在2019 IEEE International Conference on Image Processing (ICIP)上发表的论文《Towards Faster and Better Federated Learning: A Feature Fusion Approach》。该论文提出了一种特征融合方法来减少联邦学习中通讯的成本,并提升了模型性能:通过聚合来自本地和全局模型的特征,以更少的通信成本实现了更高的精度。此外,特征融合模块为新来的客户端提供更好的初始化,从而加快收敛过程。
Abstract
联邦学习能够在由大量现代智能设备(如智能手机和物联网设备)组成的分布式网络上进行模型训练。然而,FedAvg算法通常需要很大的通信成本,并且性能也是一个很大的挑战,特别是当本地数据以非IID方式分布时。
因此,本文提出了一种特殊的特征融合机制来解决上述问题:通过聚合来自本地和全局模型的特征,以更少的通信成本实现了更高的精度。此外,特征融合模块为新来的客户端提供更好的初始化,从而加快收敛过程。
1.Introduction
为了充分利用设备上的数据,传统的机器学习策略需要从客户端收集数据,然后在服务器上集中训练模型,然后将模型分发给客户端,这给通信网络带来了沉重的负担并且暴露于高隐私风险(所有客户端需要暴露自己的数据)。
2016年,谷歌提出了联邦学习(Federated Learning)的概念,并首次提出了FedAvg算法,它使用本地数据对客户端执行分布式培训,并将这些模型汇总到中央服务器中以避免数据共享。 通过这种方式,减轻了隐私暴露问题。然而,进一步的研究指出,与其他因素相比,通信成本仍然是FL的主要制约因素,例如计算成本,如果模型接受非IID数据训练,则FedAvg的准确性将显着下降。
在本文中,提出了一种新的具有特征融合机制(FedFusion)的FL算法来解决上述问题。通过引入特征融合模块,在特征提取阶段之后聚合来自局部和全局模型的特征,而几乎没有额外的计算成本。这些模块使每个客户端的训练过程更加高效,并且更有针对性地处理非IID数据,因为每个客户端将为自己学习最合适的特征融合模块。
本文贡献:
•首次将特征融合机制引入联邦学习。•所提出的特征融合模块以高效和个性化的方式聚合来自本地和全局模型的特征。•实验表明本文所提出的方法在精度和泛化能力方面均优于baseline,并且将通信轮数减少了60%以上。
2.Related Work
考虑到通信成本是限制FL的主要因素,目前已经有一些学者做了相关的研究工作。比如Konecny等人在客户端到服务器通信的背景下提出了结构化和草图更新;Yao等人对设备上的培训程序引入了额外的限制,旨在拟合本地数据的同时整合来自其他客户的更多知识;Caldas等人提出federated dropout来训练客户端的子集,并将有损压缩扩展到服务器到客户端的通信。
3.Methods
在本节中,首先介绍所提出的特征融合模块,然后给出具有特征融合机制(FedFusion)的联邦学习算法。
3.1 Feature Fusion Modules
如下图所示:
其中蓝色的部分表示local模型提取的两通道特征,灰色部分表示global模型提取到的两通道特征。图1给出了三种特征融合方式:Conv, Multi和Single。特征的提取在CNN中可以理解为经过卷积和池化操作后得到的图片信息。
每一个输入的图像 都会分别被局部特征提取器 和全局特征提取器 映射到 。
随后,特征融合算子 将两个特征提取器提取到的特征进行融合: ,两个特征融合后被映射到 。
3.1.1 Conv operator
其中 表示shape为 的可学习的权重矩阵。具体操作就是将global特征和local特征进行concat(||)后进行卷积操作。
关于卷积中通道C、高度H以及宽度W的解释可见:一文读懂卷积神经网络(CNN)
3.1.2 Multi operator
Multi算子:用一个 权重矩阵来对local和global进行一个加权求和。
3.1.3 Single operator
Single算子:用一个标量 来对local和global进行一个加权求和。
经过上述操作后,global特征提取器提取到的特征和local特征提取器提取到的特征将融合成为一个新的特征,特征shape为 。
3.2 Federated Learning with Feature Fusion Mechanism
本节讲述带有特征融合机制的联邦学习策略!
本文所提出的FedFusion的典型训练迭代如下图所示:
具体来讲:
客户端在第 轮训练时,将会保留服务器发来的全局的特征提取器 ,在本地分类器进行迭代更新时,会考虑将 和 进行融合。
在训练期间, 被冻结并且引入了3.1中描述的附加特征融合模块。
在客户端上进行训练后,将与特征融合模块结合的本地模型发送到中央服务器进行模型聚合,这里使用指数移动平均策略来平滑更新。
算法伪代码:
对中央服务器:
1.初始化全局参数 2.对第r轮更新:随机选择m个客户端,然后对这m个客户端做如下操作:将全局参数 传递给客户端,算出每一个客户端返回的梯度。最后,根据这些梯度进行指数移动平均,合成新的全局参数 。
对客户端t的第r轮训练来说:
1.局部参数 ,也就是说局部模型是一个分类器,其中 是本地特征提取器(是需要通过数据来进行学习的),提取后经过F特征融合,最后再进行分类。2.对每一个bach内的数据,计算 模型的梯度,然后反向传播更新参数。注意这里的模型,实际上就是本文的创新点所在,本地训练时,模型的特征并不只是简单的本地特征,而是将上一轮的全局模型的特征提取器提取到的特征与本地特征进行融合,融合后再进行训练。3.训练结束后将最新的局部参数传递给服务器,由服务器进行指数移动平均,聚合形成新的全局参数。
4.Experiments
4.1 Experimental Setup
在实验中使用MNIST和CIFAR10作为基本数据集。
对于MNIST数字识别任务,使用与FedAvg相同的模型:具有两个5×5卷积层的CNN(第一个具有32个通道,第二个具有64个通道,每个之后是ReLU激活和2×2最大池化),512个节点的完全连接层(ReLU Random Dropout),softmax输出层。
对于CIFAR10,使用具有两个5×5卷积层的CNN(均具有64个通道,每个通道之后是ReLU激活和3×3最大池化,stride为2),两个完全连接层(第一个具有384个节点,第二个具有192个节点,每个之后是ReLU Random Dropout)和最终的softmax输出层。
数据分割方式:
1.Artificial Non-IID Partition:每个节点仅包含两种类别。2.User Specific Non-IID Partition:每个节点包含相似的类别,但是采用不同的分布,类似multi task学习。3.IID分布。
4.2 Artificial Non-IID Partition
a和b表述了在人工形成的非IID场景下, FedFusion和FedAvg的收敛图。可以看到,在相同的通讯轮数下,不进行特征融合,也就是FedAvg的表现是最差的,其精度最低。
(图有些看不清),具体的数据如表1所示:
可以看到进行特征融合后(无论哪一种特征融合),模型的精度都有所提升。
Multi融合方式的效果最好,Conv融合方式次之。
4.3 User Specific Non-IID Partition
为了模拟用户特定的非IID分区,对每个客户端的MNIST应用不同的排列,这就是之前几项研究中所谓的置换MNIST。
表2列出了达到某些精度(此处为94%和95%)的通信轮数以及通信轮数相对于FedAvg的减少:
从上表可以看出,FedFusion Conv实现了通讯轮数最大幅度的降低。
值得注意的是,用户特定的“非IID分区更接近现实的FL场景,因此在这种情况下改进更有意义。
4.4 IID Partition
如下图所示:
在IID场景下,使用Multi和Conv进行融合可以以较低的通信成本实现更高的精度。
对特征融合算子做出如下简要概括:
1.Multi算子在局部和全局特征映射之间提供灵活的选择,并且更易于解释。 权重向量 考虑了相应通道中全局特征映射的比例。当客户端数据类别存在差距时,Multi算子将学习选择最有用的特征映射。2.Conv算子更擅长整合全球和本地模型的知识。 如果客户端的数据具有相似的类别但遵循不同的分布,Conv算子的性能要好得多。3.实验表明,Single算子几乎没有改进,不推荐使用。
5. Conclusion
联邦学习巨大的通讯成本是一个需要解决的紧急问题。 在本文中,尝试从减少沟通轮次的角度进行一些改进:提出了一种新的具有特征融合模块的FL算法,并在当前较为流行的FL设置中对其进行评估。实验结果表明,该方法具有较高的精度,同时将通信轮次减少了60%以上。
未来的工作可能包括将目前的算法扩展到更复杂的模型和场景,以及将通信轮次减少策略与其他类型的方法(例如梯度估计和压缩)相结合。