联邦学习基本算法FedAvg的代码实现

2022-11-01 16:04:13 浏览数 (2)

I. 前言

联邦学习(Federated Learning) 是人工智能的一个新的分支,这项技术是谷歌2016年于论文Communication-Efficient Learning of Deep Networks from Decentralized Data中首次提出。

在我的另一篇公众号文章联邦学习的提出 | 从分散数据通信高效学习深度网络中详细解读了该篇论文,而本篇文章的目的就是利用这篇解读文章对原始论文中的FedAvg方法进行复现。

因此,阅读本文前建议先阅读联邦学习的提出 | 从分散数据通信高效学习深度网络。

II. 数据介绍

联邦学习中存在多个客户端,每个客户端都有自己的数据集,这个数据集他们是不愿意共享的。

本文选用的数据集为中国北方某城市九个区/县从2016年到2019年三年的真实用电负荷数据,采集时间间隔为1小时,即每一天都有24个负荷值。

我们假设这9个地区的电力部门不愿意共享自己的数据,但是他们又想得到一个由所有数据统一训练得到的全局模型。

特征构造

用某一时刻前24个时刻的负荷值以及该时刻的相关气象数据(如温度、湿度、压强等)来预测该时刻的负荷值。

各个地区应该就如何制定特征集达成一致意见。本文使用的各个地区上的数据的特征是一致的,因此可以直接使用。

III. 联邦学习

1 整体框架

原始论文中提出的FedAvg的框架为:

由于本文中需要利用各个客户端的模型参数来对服务器端的模型参数进行更新,因此本文决定采用numpy搭建一个四层的神经网络模型。模型的具体搭建过程可以参考上一篇文章:从矩阵链式求导的角度来深入理解BP算法(原理 代码),在这篇文章里面我详细介绍了神经网络参数的更新过程,这将有助于理解本文中的模型参数的更新过程。

神经网络由1个输入层、3个隐藏层以及1个输出层组成,激活函数全部采用Sigmoid函数。

神经网络各层间的运算关系,也就是前向传播过程如下所示:

因此,客户端参数更新实际上就是更新四个

2 服务器端

服务器端执行以下步骤:

  • 初始化参数
  • 对第

轮训练来说:首先计算出

,然后随机选择

个客户端,对这

个客户端做如下操作(所有客户端并行执行):更新本地的

得到

。所有客户端更新结束后,将

传到服务器,服务器整合所有

得到最新的全局参数

  • 服务器将最新的

分发给所有客户端,然后进行下一轮的更新。

简单来说,每一轮通信时都只是选择部分客户端,这些客户端利用本地的数据进行参数更新,然后传给服务器,服务器汇总所有客户端的参数形成自己的参数,然后将汇总的参数再次分发给所有客户端,然后进行下一轮更新。

3 客户端

客户端没什么可说的,就是利用本地数据对神经网络模型的参数进行更新。

4 代码实现

4.1 代初始化

FedAvg中的参数一共有五个:

  • K:客户端数量,本文为9个,也就是9个地区。
  • C:选择率,每一轮通信时都只是选择C * K个客户端。
  • E:客户端更新本地模型的参数时,在本地数据集上训练E轮。
  • B:客户端更新本地模型的参数时,本地数据集batch大小为B。
  • r:服务器端和客户端一共进行r轮通信。

代码实现:

代码语言:javascript复制
class FDL:
    def __init__(self, C, E, B, K, r):
        self.C = C
        self.E = E
        self.B = B
        self.K = K
        self.r = r
        self.w1 = 2 * np.random.random((30, 20)) - 1  # limit to (-1, 1)
        self.w2 = 2 * np.random.random((20, 20)) - 1
        self.w3 = 2 * np.random.random((20, 20)) - 1
        self.w4 = 2 * np.random.random((20, 1)) - 1
        self.nns = []
        for i in range(K):
            nn = BP(clients[i], self.B, self.E)
            self.nns.append(nn)

其中

是服务器端初始化的全局参数,由于服务器端不需要进行反向传播更新参数,因此不需要定义各个隐层以及输出。

中保存的是9个客户端的初始神经网络模型,初始化该模型时需要三个参数:

  • clients[i]:客户端名字,实际上就是csv文件的名字。
  • self.B:本地模型的B。
  • self.E:本地模型的E。

4.2 服务器端

服务器端执行如下代码:

代码语言:javascript复制
def server(self):
    for t in range(self.r):
        print('第', t   1, '轮通信:')
        m = np.max([int(self.C * self.K), 1])
        index = random.sample(range(0, self.K), m)
        for k in index:
            self.client_update(self.nns[k])
        # update w
        s = 0
        
        for j in range(K):
            s  = self.nns[j].len
        w1 = np.zeros((30, 20))
        w2 = np.zeros((20, 20))
        w3 = np.zeros((20, 20))
        w4 = np.zeros((20, 1))
        
        for j in range(K):
            w1  = self.nns[j].w1 * (self.nns[j].len / s)
            w2  = self.nns[j].w2 * (self.nns[j].len / s)
            w3  = self.nns[j].w3 * (self.nns[j].len / s)
            w4  = self.nns[j].w4 * (self.nns[j].len / s)
            
        self.w1, self.w2, self.w3, self.w4 = w1, w2, w3, w4
        # distribute
        for nn in self.nns:
            nn.w1, nn.w2, nn.w3, nn.w4 = self.w1, self.w2, self.w3, self.w4
    nn = BP(clients[0], self.B, self.E)
    nn.w1, nn.w2, nn.w3, nn.w4 = self.w1, self.w2, self.w3, self.w4

    return nn

下面对重要代码进行分析。

  • 客户端的选择
代码语言:javascript复制
m = np.max([int(self.C * self.K), 1])
index = random.sample(range(0, self.K), m)

index中存储着m个0~9间的整数,表示被选中的客户端的序号。

  • 客户端的更新
代码语言:javascript复制
for k in index:
    self.client_update(self.nns[k])
  • 服务器端汇总客户端模型的参数
代码语言:javascript复制
s = 0

for j in range(len(index)):
    s  = self.nns[index[j]].len
w1 = np.zeros((30, 20))
w2 = np.zeros((20, 20))
w3 = np.zeros((20, 20))
w4 = np.zeros((20, 1))

for j in range(K):
    w1  = self.nns[j].w1 * (self.nns[j].len / s)
    w2  = self.nns[j].w2 * (self.nns[j].len / s)
    w3  = self.nns[j].w3 * (self.nns[j].len / s)
    w4  = self.nns[j].w4 * (self.nns[j].len / s)
    
self.w1, self.w2, self.w3, self.w4 = w1, w2, w3, w4

服务器端汇总客户端模型参数的具体方式为:

其中

表示第

个客户端的数据量。也就是说,一个客户端的数据越多,它的模型在最终汇总时对全局模型的影响就越大。

当然,这只是一种很简单的汇总方式,还有一些其他类型的汇总方式。论文Electricity Consumer Characteristics Identification: A Federated Learning Approach中总结了三种汇总方式:

  • normal:原始论文中的方式,即根据样本数量来决定客户端参数在最终组合时所占的比例。
  • LA:根据客户端模型的损失占所有客户端损失和的比重来决定最终组合时参数所占的比例。
  • LS:根据损失与样本数量的乘积所占的比重来决定。

值得注意的是,虽然服务器端每次只选择K个客户端中的m个来进行更新,但在最终却需要汇总所有客户端模型的参数。

  • 将更新后的参数分发给客户端
代码语言:javascript复制
for nn in self.nns:
    nn.w1, nn.w2, nn.w3, nn.w4 = self.w1, self.w2, self.w3, self.w4
  • 模型返回
代码语言:javascript复制
nn = BP(clients[0], self.B, self.E)
nn.w1, nn.w2, nn.w3, nn.w4 = self.w1, self.w2, self.w3, self.w4

return nn

在进行了

轮通信后,重新初始化一个nn模型,将

赋给该nn,作为最终的全局模型。

4.3 客户端

客户端只需要利用本地数据来更新本地模型的参数:

代码语言:javascript复制
@staticmethod
def client_update(nn):  # update nn
    train(nn)
    return nn

其中train为:

代码语言:javascript复制
def train(nn):
    print('training...')
    train_x, train_y, test_x, test_y = nn_seq(nn.file_name, nn.B)
    nn.len = len(train_x)
    batch_size = nn.B
    epochs = nn.E
    batch = int(len(train_x) / batch_size)
    for epoch in range(epochs):
        for i in range(batch):
            start = i * batch_size
            end = start   batch_size
            nn.forward_prop(train_x[start:end], train_y[start:end])
            nn.backward_prop(train_y[start:end])
        print('当前epoch:', epoch, ' error:', np.mean(nn.loss))
    return nn

IV. 实验及结果

本次实验的参数选择如下表所示:

K

C

E

B

r

9

0.5

100

100

5

各个客户端单独训练(训练100轮,batch大小为100)后在本地的测试集上的表现为:

客户端编号

1

2

3

4

5

6

7

8

9

MAPE/%

5.26

4.81

6.09

4.47

3.81

3.71

6.92

4.71

2.99

可以看到,由于各个客户端的数据量都十分充足,所以每个客户端自己训练的本地模型的预测精度已经很高了。

服务器与客户端通信5轮后,服务器上的全局模型在9个客户端测试集上的表现如下所示:

客户端编号

1

2

3

4

5

6

7

8

9

MAPE/%

35.29

32.82

26.73

44.59

33.88

32.77

38.89

37.32

23.42

可以看到,经过联邦学习得到全局模型在各个客户端上表现很差。

分析:原始论文中动辄都是上千次的通信轮数,而本文设定的通信轮数为5,所以效果差是自然的。

为了提升模型精度,将通信轮数 从5变为50(大概训练了一个小时),得到的结果如下所示:

客户端编号

1

2

3

4

5

6

7

8

9

MAPE/%

15.11

19.00

17.84

15.34

24.71

11.08

17.46

21.55

8.52

可以看到,通信轮数 增加后,全局模型在9个客户端测试集上的预测精度有了明显提升。

当然,为了提升精度,我们可以继续增加通信轮数。不过通信轮数越多,模型训练的时间就越长。由于时间关系,这里不再做进一步讨论,有兴趣的可以自己尝试。

0 人点赞