I. 前言
Per-FedAvg的原理请见:arXiv | Per-FedAvg:一种联邦元学习方法。
II. 数据介绍
联邦学习中存在多个客户端,每个客户端都有自己的数据集,这个数据集他们是不愿意共享的。
数据集为某城市十个地区的风电功率,我们假设这10个地区的电力部门不愿意共享自己的数据,但是他们又想得到一个由所有数据统一训练得到的全局模型。
III. Per-FedAvg
Per-FedAvg算法伪代码:
1. 服务器端
服务器端和FedAvg一致,这里不再详细介绍了,可以看看前面几篇文章。
2. 客户端
对于每个客户端,我们定义它的元函数
:
为了在本地训练中对元函数进行更新,我们需要计算其梯度:
代码实现如下:
代码语言:javascript复制def train(args, model):
model.train()
Dtr, Dte, m, n = nn_seq(model.name, args.B)
model.len = len(Dtr)
print('training...')
data = [x for x in iter(Dtr)]
for epoch in range(args.E):
model = one_step(args, data, model, lr=args.alpha)
model = one_step(args, data, model, lr=args.beta)
return model
def one_step(args, data, model, lr):
ind = np.random.randint(0, high=len(data), size=None, dtype=int)
seq, label = data[ind]
seq = seq.to(args.device)
label = label.to(args.device)
y_pred = model(seq)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
loss_function = nn.MSELoss().to(args.device)
loss = loss_function(y_pred, label)
optimizer.zero_grad()
loss.backward()
optimizer.step()
return model
3. 本地梯度下降
得到初始模型后,需要在本地进行1轮迭代更新:
代码语言:javascript复制def local_adaptation(args, model):
model.train()
Dtr, Dte = nn_seq_wind(model.name, 50)
optimizer = torch.optim.Adam(model.parameters(), lr=args.alpha)
loss_function = nn.MSELoss().to(args.device)
loss = 0
for epoch in range(1):
for seq, label in Dtr:
seq, label = seq.to(args.device), label.to(args.device)
y_pred = model(seq)
loss = loss_function(y_pred, label)
optimizer.zero_grad()
loss.backward()
optimizer.step()
# print('local_adaptation loss', loss.item())
return model
IV. 完整代码
完整代码及数据:https://github.com/ki-ljl/Per-FedAvg,点击阅读原文即可跳转至代码下载界面。