联邦学习论文阅读:SemiFL: Semi-Supervised Federated Learning for Unlabeled Clients with Alternate Training


论文简介

  • 论文名称:SemiFL: Semi-Supervised Federated Learning for Unlabeled Clients with Alternate Training
  • 作者:Enmao Diao, 杜克大学
  • 论文链接:https://arxiv.org/abs/2106.01432
  • 论文来源:NIPS 2022

论文贡献

这篇论文相当于是半监督联邦学习领域中的SOTA。看这篇论文的意思,半监督联邦的通用情况都是服务器带标记数据、客户端不带标记数据。这篇论文通过强数据增强、带标记服务器和无标记客户端交替训练来提高性能。

论文动机

作者先进行了实验,把SSL的FixMatch方法和联邦学习的FedAvg和FedSGD组合起来,发现只有FixMatch+FedSGD组合起来能够work,而FedSGD需要批量的模型聚合,因此会有极大的communication cost。




而文章论证了Fedmatch效果不好的原因,是模型生成的伪标签质量越来越差。本文旨在解决这一问题。

强数据增强

strong data argumentation是SSL的sota方法Fixmatch的关键部分:将高质量的数据点(比如图片数据)映射成低级别的数据,把低质量数据和高置信度的标签看成一组标记数据进行模型训练,这样能给模型提供更多的训练机会。



交替训练过程

FedMatch的方法(缺点)




这是FedMatch的方法,即将SSL方法Fixmatch和通信高效的FL方法FedAvg简单组合在一起。(这个方法快被作者喷烂了,拉出来鞭尸了好多次= =)
这个方法有几个问题:

  • 客户端本地的数据在训练了几轮以后,生成的伪标签他不见得就会越来越好。
  • 伪标签越来越差 → 本地模型效果越来越差 → 聚合的全局模型效果又会变差 → 全局模型在下一轮会产生效果更差的伪标签,如此恶性循环

作者的方法




优点:

  • 每一轮都会用标记数据进行全局模型的微调,这样全局模型在下一轮会给活跃的客户端分配更好的伪标签。
  • 使用全局模型对客户端进行标签。

SemiFL

作者的方法其实直接看伪代码会清楚很多,论文罗列了将近一页的伪代码,大体分为一个主函数(系统执行程序)和两个子函数(客户端更新程序、服务器端更新程序)

主程序代码:




这里需要注意的是,每一个communication round的开始和结尾,都会用服务器端的标记数据对全局模型的参数进行一个修正。个人感觉是担心客户端本地训练的模型太差,把全局模型搞毁了以后生成的伪标签就会越来越差,导致Fedmatch的恶性循环。所以每一个round都会用标记数据跑一下作为兜底。

服务器端更新




没有什么好说的,就是用labeled data对模型进行更新

客户端更新




先用弱数据增强来生成伪标签,然后分别构建Fixmatch和Mixmatch的数据集。使用Fixmatch和Mixmatch生成的数据集来训练本地模型。

相关论文

SSL

[23] Mixmatch
[32] Fixmatch

联邦半监督

[26] 关于未标记数据对FL影响的一篇survey
[27] FedMatch,用带标记的服务器和未标记的客户端分别分割模型参数
[28] FedRGD,训练和聚合带标记服务器和未标记客户端的模型参数,然后并行组级重加权

一致性正则化

[20]

强数据增强

[21-24]
[25] 强数据增强应用于SSL

声明:奋斗小刘|版权所有,违者必究|如未注明,均为原创|本网站采用BY-NC-SA协议进行授权

转载:转载请注明原文链接 - 联邦学习论文阅读:SemiFL: Semi-Supervised Federated Learning for Unlabeled Clients with Alternate Training


Make Everyday Count