联邦元学习算法Per-FedAvg的PyTorch实现

2022-11-09 14:54:08 浏览数 (2)

I. 前言

Per-FedAvg的原理请见:arXiv | Per-FedAvg:一种联邦元学习方法。

II. 数据介绍

联邦学习中存在多个客户端,每个客户端都有自己的数据集,这个数据集他们是不愿意共享的。

数据集为某城市十个地区的风电功率,我们假设这10个地区的电力部门不愿意共享自己的数据,但是他们又想得到一个由所有数据统一训练得到的全局模型。

III. Per-FedAvg

Per-FedAvg算法伪代码:

1. 服务器端

服务器端和FedAvg一致,这里不再详细介绍了,可以看看前面几篇文章。

2. 客户端

对于每个客户端,我们定义它的元函数

为了在本地训练中对元函数进行更新,我们需要计算其梯度:

代码实现如下:

代码语言:javascript复制
def train(args, model):
    model.train()
    Dtr, Dte, m, n = nn_seq(model.name, args.B)
    model.len = len(Dtr)
    print('training...')
    data = [x for x in iter(Dtr)]
    for epoch in range(args.E):
        model = one_step(args, data, model, lr=args.alpha)
        model = one_step(args, data, model, lr=args.beta)

    return model


def one_step(args, data, model, lr):
    ind = np.random.randint(0, high=len(data), size=None, dtype=int)
    seq, label = data[ind]
    seq = seq.to(args.device)
    label = label.to(args.device)
    y_pred = model(seq)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_function = nn.MSELoss().to(args.device)
    loss = loss_function(y_pred, label)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    return model

3. 本地梯度下降

得到初始模型后,需要在本地进行1轮迭代更新:

代码语言:javascript复制
def local_adaptation(args, model):
    model.train()
    Dtr, Dte = nn_seq_wind(model.name, 50)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.alpha)
    loss_function = nn.MSELoss().to(args.device)
    loss = 0
    for epoch in range(1):
        for seq, label in Dtr:
            seq, label = seq.to(args.device), label.to(args.device)
            y_pred = model(seq)
            loss = loss_function(y_pred, label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # print('local_adaptation loss', loss.item())

    return model

IV. 完整代码

完整代码及数据:https://github.com/ki-ljl/Per-FedAvg,点击阅读原文即可跳转至代码下载界面。

0 人点赞