这是一篇最新ICLR2022论文,Acceleration of Federated Learning with Alleviated Forgetting in Local Training,作者通过从一个灾难性遗忘的角度分析联邦学习性能不佳的原因,并进行改进提升收敛速度与精度。
《Acceleration of Federated Learning with Alleviated Forgetting in Local Trainin》
论文:https://arxiv.org/abs/2203.02645
代码:https://github.com/Zoesgithub/FedReg
▊ 1 Abstract
作者观察到,现有方法收敛速度缓慢是由于每个客户端局部训练阶段的灾难性遗忘问题造成的,这导致其他客户的先前训练数据的损失函数大幅增加。
因此作者提出了一种FedReg算法,通过对生成的伪数据的损失来调整局部训练的参数,并对全局模型学习到的先前训练数据的知识进行编码,从而大大提高收敛速度,同时可以更好的保护隐私。
▊ 2 Introduction
一些FL算法被设计要通过减少异质性问题的差异来改进FedAvg,但是当采用深度神经网络架构时,这些方法的性能仍然远不能令人满意,另一方面,最近的文献工作表明训练后的模型参数的传输并不能保证对隐私的保护,虽然DP可以防止隐私泄露,但是当DP加入FL时模型的性能持续衰减。
作者观察到,当数据为non-i.i.d时在整个客户中,本地训练的模型严重忘记了其他客户对以前的训练数据的知识(即众所周知的灾难性遗忘问题),这可能是由于本地数据分布和全局数据分布之间的差异。这种遗忘问题导致客户端损失大幅增加,我们提出FedReg通过减轻局部训练阶段的灾难性遗忘问题来降低训练中的通信成本。
FedReg通过使用生成的伪数据对局部训练参数进行正则化来减少知识遗忘,这些伪数据是通过使用修改后的局部数据对全局模型学习到的先前训练数据的知识进行编码而获得的。伪数据与本地数据中的知识的潜在冲突通过使用扰动数据得到抑制,扰动数据是通过对本地数据进行小扰动而产生的,它们有助于确保其预测值。伪数据和扰动数据的生成只依赖于从服务器接收到的全局模型和当前客户端的本地数据。
作者证明,当跨客户端的数据是非独立同分布的时,本地训练阶段的灾难性遗忘是减慢 FL 训练过程的重要因素,因此提出了一种算法 FedReg,它通过使用生成的伪数据减轻灾难性遗忘来加速 FL。
灾难性遗忘:指的是人工智能系统,如深度学习模型,在学习新任务或适应新环境时,忘记或丧失了以前习得的一些能力。当神经网络在多个任务上按顺序训练时,就会发生灾难性遗忘,在这种情况下,当前任务的最佳参数可能在先前任务的目标上表现不佳。在深度神经网络学习不同任务的时候,相关权重的快速变化会损害先前任务的表现,造成人工智能系统在原有任务或环境性能大幅下降。
▊ 3 Method
主要挑战是如何减轻每个客户对先前学习知识的遗忘,而不必在本地培训阶段访问其他客户的数据。我们首先生成伪数据,然后通过使用伪数据上的损失对局部训练的参数进行正则化来缓解灾难性遗忘问题。
生成伪数据:fast gradient sign method,FGSM,是一种对抗样本生成方法,根据本轮全局模型梯度反方向生成对抗样本:
通过生成对抗样本,基于本地模型分类结果与数据标签相差较大的数据,可以对本地模型很好起到正则化效果。
尽管上面生成的伪数据放松了约束,但由于训练过程中全局模型的不准确性,可能会导致对抗样本与本地数据冲突,导致模型学习到一些错误信息,为了进一步消除这种相互冲突的信息,对本地数据进行轻微扰动,个人理解可以增强鲁棒性:
其中扰动程度非常小,即n_p<<n_s,以确保扰动数据比伪数据更接近本地数据。
接下来,使用生成的伪数据和扰动数据,进行正则化以减轻灾难性遗忘:
其中约束 (4) 缓解了灾难性的遗忘问题,约束 (5) 消除了 (4) 中引入的冲突信息,约束 (5) 还有助于提高结果模型的鲁棒性。
再通过求解以下约束优化问题来逼近最优参数 θ(t,i)∗:
更进一步,本地模型参数在每个训练步骤中更新为:
进一步注意到,在分类问题中,伪数据可用于修改梯度从而增强隐私保护,伪数据与真实数据相比可能是包含相似的语义信息但不同的分类信息,因此伪数据增强FL隐私保护能力而不会严重降低所得模型的性能,如下所示Di代表原数据、Di_s代表伪数据:
从而通过修正梯度来增强隐私保护。
关于FedReg方法流程图如下图1所示,通过对本地模型施加正则化,加快全局模型收敛速度并提升精度。
图1:FedReg方法
▊ 4 Experiments
数据集:MNIST、EMNIST、CIFAR-10、CIFAR-100、CT images related to COVID-19(一个基于COVID-19胸部CT图片)。
收敛速度比较:与基线方法相比,FedReg 只需要更少的通信轮次达到收敛,并获得更高的最终准确度,如下表1所示。
表1:收敛率比较
减轻灾难性遗忘:如下图2所示,为了证明 FedReg 确实减轻了灾难性遗忘,在 FedReg 和 FedAvg 之间比较了其他客户端先前训练数据的损失值的增加情况。FedReg 中损失的增加幅度明显低于 FedAvg,这表明虽然 FedAvg 和 FedReg 都忘记了一些学习知识,但 FedReg 的遗忘问题并不严重。
顶行表示FedReg与FedAvg的损失情况。底行关于信息相似性数据表明在伪数据上具有高性能的模型在先前的训练数据上表现良好的概率相对较高(分布较为均匀),使用伪数据对参数进行正则化有助于减少先前训练数据的性能下降并缓解知识遗忘问题。
图2:减轻灾难性遗忘
步长超参数:当生成伪数据的步长n_s太大时,生成的伪数据与本地数据的距离/差异性会太大,无法有效地对模型参数进行正则化。另一方面,当n_s太小时,来自不准确的全局模型的冲突信息会减慢收敛速度。当 ηs 太小时,需要更多的通信轮次才能达到相同的精度;随着 ηs 的增加,通信轮数减少到最小,然后反弹回来,表明此时在局部训练阶段正则化对模型参数不太有效。
隐私保护:如下表2所示,梯度反转攻击用于从分类问题中的更新梯度中恢复信息,从下表中可以看出,在基线方法中,使用 DP 保护隐私信息的性能下降幅度很大,而使用 FedReg 在保持相似的隐私保护级别时,性能下降幅度要小得多。相比较而言,FedReg 能够保护敏感的时间信息,但是模型性能只有轻微的衰减。
表2:隐私保护比较
▊ 5 Conclusion
三点总结:
- 在这项工作中,作者提出了一种新的算法 FedReg,通过减轻局部训练阶段的灾难性遗忘问题来加快 FL 的收敛速度。生成伪数据以携带有关全局模型学习的先前训练数据的知识,而不会产生额外的通信成本或访问其他客户端提供的数据。
- 生成的伪数据包含与其他客户端之前的训练数据相似信息,因此可以通过使用伪数据对本地训练的参数进行正则化来缓解遗忘问题。
- 伪数据还可以用于防御分类问题中的梯度反转攻击,与 DP 相比,结果模型的性能只有轻微的衰减。
▊ 6 补充: Fast Gradient Sign Method
我们说到作者生成伪数据是通过Fast Gradient Sign Method,这是一种对抗样本数据生成方法,如下图3所示,横坐标表示单维x输入值,纵坐标表示损失值,函数图像是损失函数,损失值越大表示越大概率分类错误,假设灰的线上方为分类错误,下方为分类正确;
以样本点x1为例,根据公式,此时的偏导函数为负,则黑色箭头方向为扰动方向,同理x2样本在取值为正时,也沿着黑色箭头方向变化,只要我们的取值合适,就能生成对抗样本,使得分类错误。总之扰动方向就是使得损失函数变大的方向,通过扰动使得样本被分类错误。
图3:FGSM
参考文献
Xu, C., Hong, Z., Huang, M., & Jiang, T. (2022). Acceleration of Federated Learning with Alleviated Forgetting in Local Training. International Conference on Learning Representations 2022. https://openreview.net/forum?id=541PxiEKN3F.