题目:Personalized Federated Learning: A Meta-Learning Approach
论文地址:https://arxiv.org/abs/2002.07948
元学习是当下比较热门的一个研究方向,本篇文章将联邦学习和一种模型不可知元学习方法MAML结合起来,提出了一种新的个性化联邦技术Per-FedAvg。
所谓元学习,就是学会学习。利用元学习得到的模型,当我们在面临一个新的任务时,经过很少的训练步骤就可以得到一个比较好的模型,而不必像经典机器学习一样,需要在一个数据集上进行大量训练。Per-FedAvg的思想类似,我们利用所有客户端的数据得到一个初始模型,然后各个客户端使用该初始模型在本地进行几次梯度下降就能得到最终模型。
1. 引言
联邦学习框架中,假设一共
个客户端,那么优化函数为:
也就是最小化各个客户端损失的均值。对每个客户端来讲,损失函数可以定义为:
传统算法的局限显而易见:在用户数据分布不完全相同的异质环境中,通过最小化平均损失得到的全局模型一旦应用于每个用户的本地数据集,可能会表现得比较糟糕。
为了应对数据的统计异质性和非IID分布所带来的挑战,需要对全局模型进行个性化处理。前面已经讲过一个比较简单的联邦个性化算法FedPer,FedPer中每个客户端都有自己的模型,所有客户端模型共享神经网络的基础层,而个性化层通过自己本地数据进行训练。如果数据量不足,通过基础层的共享可以获得一个训练比较充分的底层模型,而顶端个性化层又可以保证模型对本地数据具有较好的适应性。
与FedPer不同,Per-FedAvg的目的是获得一个初始模型,然后使用该初始模型在各个客户端的数据上进行少数几轮训练就可以得到一个较好的本地模型。通过这种方式,虽然初始模型是在所有用户上以分布式方式导出的,但每个用户的最终模型都与其他客户端的模型不同,这一点与FedPer一致。
2. Per-FedAvg
下面详细介绍Per-FedAvg的具体原理。
2.1 初始模型
我们假设所有用户的都得到了自己的初始模型,然后在本地数据上使用少数几次(比如一次)梯度下降,就可以得到自己需要的模型,那么优化目标可以定义为:
这里
为学习率。可以发现,上式与FedAvg的差别在于,Per-FedAvg中客户端需要优化的函数是在FedAvg函数的基础上进行一次梯度下降后得到的。这样,上述公式不仅可以保持联邦学习的优势(联合所有客户端数据),也可以捕捉不同用户间的差异:客户端可以根据自己的数据修改初始模型,进而得到自己的模型。
对于每个客户端,我们定义它的元函数
:
Per-FedAvg的伪代码描述如下:
为了在本地训练中对
进行更新,我们需要计算其梯度:
可以观察到,由于
的表达式中有
的梯度
,所以在计算
时我们需要计算参数的Hessian矩阵
。
计算
的代价很大,因此论文中的计算方式为:在客户端本地选取一批数据
,然后利用这批数据来得到
的一个无偏估计。即:
也就是
中所有梯度求均值。
类似地,对于
,我们同样可以取一批数据得到其无偏估计。
与FedAvg类似,Per-FedAvg中第
轮通信时,服务器将模型发送给选中的客户端,然后每个客户端执行
轮本地梯度更新:
由于
实际上是损失函数
进行一步梯度更新之后得到的,即:
那么我们首先需要选取一批数据,然后对原始的损失函数求梯度,然后更新,进而得到
的参数:
这个时候我们得到了参数
,这其中
为客户端编号,
表示当前的全局轮数,
表示本地更新的轮数。
得到
后,然后观察元函数的真实梯度计算公式:
可以发现,我们还需要对
进行求导,也就是上式右边那部分,这部分同样选取一批数据求无偏梯度。上式左边部分的二阶梯度就是对原始损失函数求二阶梯度,比较简单。
然后元函数梯度可以表示为:
这里的
,
以及
是在本地选取的三批独立的数据。
此时我们就可以对客户端的本地模型进行参数进行更新了:
更新完毕后将最新的参数传到服务器进行聚合:
然后重复上述步骤。
简单总结下Per-FedAvg:
1. 服务器初始化模型。
2. 服务器选择一部分客户端发送模型。
3. 对被选中的客户端来讲,需要进行
轮本地更新,在每一轮本地更新中:首先选择一批数据计算损失函数
的梯度,然后进行一步梯度下降得到元函数
;然后再选择一批数据对
进行梯度下降得到更新后的元函数。
4. 客户端将更新好的元函数上传到服务器进行聚合。
5. 服务器将更新后的模型发往被选中的客户端,然后重复上述步骤。
经过多轮通信后,我们得到一个初始模型
。
2.2 本地自适应
经过2.1,我们得到了一个初始模型
,然后每个客户端利用该模型在自己本地训练很少的轮数,就可以得到一个表现比较好的模型。
本文的重点应该是有关Per-FedAvg收敛性的推导,看着华丽的数学推导,我只能感叹人与人之间的差异之大。
3. 总结
所谓元学习,就是学会如何学习。利用元学习我们可以得到一个初始模型,该初始模型在一批新的数据上进行少数几轮迭代后就能快速收敛,得到一个不错的个性化模型。Per-FedAvg借鉴了这一思想,设计了一个新的优化函数,该优化函数是所有客户端元函数的平均,而客户端元函数则是本地损失函数进行一步梯度下降后得到的。对新的优化函数进行优化后,我们得到的初始模型就能对客户端的本地数据进行快速自适应。
代码实现较为简单,放在下一篇文章!