半监督学习的概述与思考,及其在联邦场景下的应用(续)

2022-03-30 11:17:01 浏览数 (1)

上一文(【科普】半监督学习的概述与思考,及其在联邦学习场景下的应用)中,我们主要介绍了半监督学习相关的基础知识、方法以及一些SOTA论文,接下来我们将从联邦学习结合半监督学习的角度来进行探讨。

以往的联邦学习工作大多专注于监督学习任务的研究,即要求所有的数据都必须包含相对应的标签,但是在现实场景中本地客户端(数据拥有方)所包含的数据常常大部分甚至全部都是没有相应的标签的。因此,联邦半监督学习有如下两个角度:

1)基于联邦学习的系统配置下,训练半监督学习模型;

2)结合半监督学习等其它技术,解决联邦学习中存在的一些问题

接下来我们基于若干论文来进行联邦半监督学习的解读。

《Federated Semi-Supervised Learning with Inter-Client Consistency & Disjoint Learning》

本论文(链接1)主要是对联邦半监督学习的应用场景进行分类。具体来说主要分为两个应用场景

1)Labels-at-Client Scenario:参与模型训练的带标签数据和无标签数据均存放于本地客户端,即本地客户端执行标准的半监督学习训练,但是本地数据需要标注,因此开销可能过大;

2)Labels-at-Server Scenario:参与模型训练的带标签数据被存放在服务端中,而大量的无标签数据却存放在本地客户端中,即带标签的监督学习过程和无标签的无监督学习过程将分别在服务器端和客户端分开进行,但是服务器端数据的标注需要专业人士参与。具体如下图所示。

图1:联邦半监督学习场景分类(链接1

《Benchmarking Semi-supervised Federated Learning》

本论文(链接2)主要基于客户端包含无标签数据,服务器端包含有限数量的带标签数据的场景下进行讨论,其主要贡献在于:

1)提出SSFL架构,并提出一种研究Non-IID数据分布的原则方法:一种衡量客户端间类分布差异的指标 Metric R for non-iid level;

2)消融实验:研究各种因子对于SSFL性能的影响;

3)验证了BN和GN这两种normalization方法的优劣性;

4)提出一种grouping-based model averaging方法来加快联邦学习全局模型的收敛速率。

无标签数据处理:主要基于半监督学习中的自洽正则化思想,具体而言是对无标签数据进行数据增强前后模型预测结果应当一致,从而充分利用客户端的无标签数据。同时,还可以对无标签数据进行弱增强,然后模型预测结果作为伪标签,再对原数据进行强增广作为数据并对数据和标签进行交叉熵损失函数计算。具体如下图所示。

Metric R指标:提出一种研究Non-IID数据分布的原则方法,即一种衡量客户端间类分布差异的指标 Metric R for Non-IID level,具体而言论文是通过计算不同类之间的L1距离从而进行衡量。具体如下所示。

模型聚合方法:作者考虑到原先的聚合方式(服务器端平均聚合所有被采样的客户端),不同用户之间较大的模型差异性会极大的减小模型训练的速率,因此作者提出grouping-based model averaging方法来加快联邦学习全局模型的收敛速率。其思路是在客户端和服务器之间加入若干个组,先对客户端聚合参数到组,然后组聚合参数到服务器。具体如下所示:

Normalization影响:由于BN可能在不同客户端之间存在统计上的差异,同时最新研究结果表明GN优于BN在Non-IID情况下的监督任务的联邦学习,因此考虑使用GN。针对不同的Normalization方式如下图所示:

SSFL总体架构:SSFL总体架构是基于利用无标签数据、分组聚合、以及数据处理等共同构建而成,如下图所示:

图2:SSFL总体架构(链接2

SSFL总结:

1)关于数据增广,是否还存在不同的数据增广方法来帮助无标签数据进行训练呢,以及是否会对联邦学习模型产生影响?我觉得这是一个可以继续研究的点;

2)关于模型聚合方面,各个组之间还是采用的平均聚合方法,因此可以探讨一下各个分组之间的模型聚合方式;

3)关于BN和GN方法,是否可以随机交替使用GN和BN比单一使用GN或者BN拥有更好的效果?

《Distillation-Based Semi-Supervised Federated Learning for Communication-Efficient Collaborative Training with Non-IID Private Data》

该论文(链接3)的主要目标是:提高通讯效率,减少通讯开销,同时保证模型性能尽可能接近或高于联邦学习基准。作者主要设计了两个算法

1)DS-FL算法:使用一个全局共享的无标签数据集,相当于利用数据扩增的效果来提高模型性能。具体而言是将本地输出记为local logit,服务器端聚合输出记为global logit,然后利用服务器端global logit对本地local logit进行蒸馏;

2)ERA算法:考虑异构性数据集会导致样本信息的模糊性和较慢的收敛速率,论文提出ERA聚合算法,使得聚合的输出更加尖锐化,这样带来的好处有很多:例如在Non-IID配置下,可以起到加速收敛和模型稳定性作用、ERA算法的另一个附加功能是可以增加模型对于恶意用户所发起的攻击的鲁棒性(熵增相对较低)。

进一步而言,我们可以考虑算法背后的思想。DS-FL算法主要采用知识蒸馏思想:将本地模型看作学生,将local logit在服务器端聚合所得global logit作为教师的知识进行传递。ERA算法采用最小化熵的思想:因为Non-IID 数据分布会导致global logit产生更高的熵,从而导致其表达的信息更加模糊,造成模型收敛减速等问题,所以采用降低熵的做法来优化。

DS-FL算法:具体的DS-FL算法如下图所示,下图(a)代表最基本的联邦学习架构;下图(b)则代表联邦蒸馏算法,具体有如下几步:

1)将本地模型视为student,将除自己外所有客户端输出的聚合看作teacher;

2)本地需要计算每个标签的local loait,服务器要对所有客户端的local logit进行聚合形成global logit;

3)关于logit可以看作为一种统计所得的软标签信息。下图(c)则是代表DS-FL算法,主要就是联邦蒸馏 无标签共享数据集,首先联邦蒸馏操作步骤,利用本地有标签数据先得到模型,然后再对共享的unlabeled data(无标签干扰)进行蒸馏相关操作,提升模型性能。利用无标签共享数据集可以去除标签干扰从而对模型进行微调。

图3:DS-FL算法(链接3

ERA算法:ERA(Entropy Reduction Aggergation)算法目的在于加速收敛,增加模型的稳定性以及对恶意攻击的鲁棒性,其思想是有意的减少熵(或者说减少干扰造成的熵增幅度),具体如下图所示,通过对softmax函数增加温度因子T以锐化logits:

DS-FL总结:

1)关于蒸馏思想,由于最后蒸馏的目标是logit,怎样降低Non-IID产生的影响,即需要降低各个客户端之间的特征差异性;

2)关于标签,考虑到各个设备的特性,未来的工作将包括开发如何对于全局模型输出的标签进行聚合的方法,local logit在服务器端聚合所得global logit,例如可以增强可靠或高性能客户端上传的模型logit对global logit的影响。

《Federated Semi-Supervised Learning with Inter-Client Consistency & Disjoint Learning》

以往的Federated Learning仅研究监督学习,即所有的样本都包含标签,但真实场景是本地客户端中大部分或所有数据都是没有相应标签的。标签缺乏的原因在于:

1)较高的标注开销,用户不可能有时间和精力帮助服务者标记大量的数据;

2)标注者缺乏相关专业领域知识,难以完成标注任务:如纠正坐姿的APP等。

本论文(链接4)主要贡献在于:

1)系统地提出了联邦半监督学习的应用场景;

2)提出了应对FSSL应用场景的联邦学习框架FedMatch;

3)通过最大化各客户端模型之间的共识从无标签信息中进行学习;

4)执行模型参数的分解(分别训练带标签和无标签数据),以减少有监督和无监督任务之间的干扰以及通信成本。

FedMatch的核心思想在于一致性正则化(Consistency Regularization)技术,一种在半监督学习配置下从无监督数据中学习知识的技术,核心思想:对于一个输入,即使受到微小的干扰,模型的预测结果应该是一致的。因此,作者提出:本地模型输出结果和服务器选出的模型(大家共识的)的结果尽可能相似。对于如何选择大家共识的模型(客户端)作者提出基于可靠性的聚合方法(Reliability-based Aggregation)。

一种考虑本地模型可靠性(是否包含可靠的知识)的聚合方法,该可靠性还用于服务器进行helper agents的选择,即共识。传统的联邦学习算法一般采用FedAvg模型所设计的聚合方法来对模型进行聚合,FedAvg所采用的模型聚合方式就是按照各个客户端所具备的数据量占总训练数据量的比例来对各个参与聚合的本地模型进行加权平均。

基于此,FedMatch算法设计了一种考虑本地模型可靠性的聚合方法来对各个本地模型进行聚合,此处的可靠性指的是模型从数据中所学到知识对于解决相关任务的可靠性程度,具体公式如下所示,准确性越大则权重越高:

在服务器端的模型可靠性计算也为之前所讲解的一致性正则化损失函数提供了helper agent的选择机制,即helper agent是每轮各个本地模型的集合中可靠性最大的H个模型的集合。此处的可靠性衡量其实就是一种各个客户端间所达成的共识机制

进一步的作者还分别对客户端标签场景和服务器端标签场景设计算法。

图4:客户端标签场景(链接4

图5:服务器端标签场景(链接4

总结

从上面几篇论文我们可以看到,联邦无监督方法目前任然处于起步阶段,理论上面的研究相对匮乏,大多论文是借助半监督学习来用于联邦学习场景下的客户端或服务器端,因此可以从如何利用无标签数据、改进联邦学习模型聚合算法以及模型对无标签数据的影响(下游任务微调等)进一步来开展研究

参考链接

[1] Jeong W, Yoon J, Yang E, et al. Federated semi-supervised learning with inter-client consistency & disjoint learning[J]. arXiv preprint arXiv:2006.12097, 2020.

[2] Zhang, Z., Yao, Z., Yang, Y., Yan, Y., Gonzalez, J. E., & Mahoney, M. W. (2020). Benchmarking semi-supervised federated learning. arXiv preprint arXiv:2008.11364, 17.

[3] Itahara S, Nishio T, Koda Y, et al. Distillation-based semi-supervised federated learning for communication-efficient collaborative training with non-iid private data[J]. arXiv preprint arXiv:2008.06180, 2020.

[4] Jeong, W., Yoon, J., Yang, E., & Hwang, S. J. (2020). Federated semi-supervised learning with inter-client consistency & disjoint learning. arXiv preprint arXiv:2006.12097.

END

0 人点赞