联邦学习论文阅读:FedSEAL: Semi-Supervised Federated Learning with Self-Ensemble Learning and Negative Learning


论文简介

  • 论文名称:FedSEAL: Semi-Supervised Federated Learning with Self-Ensemble Learning and Negative Learning
  • 作者:Jieming Bian,迈阿密大学
  • 论文链接:https://arxiv.org/abs/2110.07829v1
  • 论文来源:Arxiv

论文设定

假定只有server端有数据,client没数据,设定跟FedMatch的情形之一和SemiFL是一样的。
作者针对的是联邦半监督学习,前几轮冷启动的问题。作者将问题转化成了,用client的无标签数据协助server端进行模型训练,提出了FedSEAL框架,主要是基于自集成学习负样本学习。自集成学习用的还是一致性正则+伪标签那一套,用来训练client端;负样本学习是为了扩展client端的无监督数据集,来克服训练初期的冷启动问题。

$p(x,w)$指的是模型参数$w$下输入$x$的预测类概率分布(置信分数向量),是个大小为$M$的向量。$p_m(x,w)$表示在模型参数$w$下输入$x$属于类$m$的预测概率(或置信分数)。
然后通过下式将置信分数向量转化成预测的hard label:




文章用CE来衡量分类模型性能:


论文方法

文章提出的FedSEAL在每轮迭代过程中经历4个步骤:

  1. server端进行全监督学习
  2. 把server端的全监督模型传给client
  3. 抽取客户端进行无监督学习
  4. 上传本地模型



server端全监督学习

第$t$轮FL的时候,服务器用简单的FedAvg来聚合第$t-1$轮抽样的各client的本地模型:



模型聚合以后,将服务器的labeled data弱增广出$\alpha (x)$,服务器端的损失函数如下:




是CE。

服务器除了更新全局模型以外,还给每个类别都计算了一个置信阈值,帮助client在无监督学习过程中过滤数据。设$\tau^t = \tau_1^t, ..., \tau_M^t$为置信阈值向量,$\tau_m^t \in [0,1]$是第$t$轮中第$m$类的置信阈值。阈值计算公式如下:




分母上是客户端验证集上第m类实例的数目,分子是server端验证集m类的所有数据实例的置信度的分数。其意义为:
对于一个新的数据$x$,如果$f(x, w_t)=m$,并且其相关的置信度$p_m (x;w_t )$大于$\tau^t_m$,则期望分类结果正确的概率是很高的。这里随着全局模型每一轮更新,$\tau^t$在学习过程中是发生变化的。置信度阈值向量$\tau_t$随着全局模型一起发给client。

这里作者还阐述了自己的优点:server端的随机初始模型$w_0$并非随机生成的,而是由server端的数据集训练出来的,这样有助于摆脱训练初期伪标签效果不佳的问题,我寻思这不是非常常规的思路吗,也算优点了吗。。。而且万一server端的数据相比于client端是极端non-iid呢,那就不止是没用的问题了。。。

客户端无监督学习

客户端无监督学习的结构如下:




每个客户端通过自集成学习的数据滤波生成一个正数据集,然后通过互补的负样本学习生成一个负的数据集。本地模型在2个数据集上都进行更新。

基于自集成学习的数据过滤
这个思路就是从未标注数据集中选出一个子数据集,使得模型对子数据集中每一个数据生成的伪标签的置信度都非常高。这个步骤称为数据过滤。假设过滤以后的数据集中的伪标签大多是准确的,在这个数据集上进行训练很可能会提高模型的性能。

原始的集中式SSL设置的损失函数如下:




Loss$L(w)$是监督损失,Loss$U(w)$是无监督损失,$\alpha$是权重参数。一开始模型还不可靠,伪标签质量差,模型就用小的$\alpha$,更多依赖标记数据集;后面模型的性能上来了,就加大$\alpha$,更多的依赖进行数据过滤后的无监督数据集。
而SSFL没有$\alpha$来调节权重,只要有一轮模型训练不佳,就会导致下一轮伪标签质量下降,进而影响后面的全局模型性能。
因此,作者想出了一个思路:模型一开始训练的时候,伪标签质量太差,这个时候不能用当前轮的单一全局模型,而是使用前几轮的多个历史模型来生成伪标签。这样每轮全局模型的巨变就能被削弱。这种方法命名为自集成学习(self-ensemble learning),因为这些模型虽然来自不同的轮数,但是学习过程是相同的。

在这篇论文之前,还有论文采用的相似的思路,即使用历史模型进行集成的方法:

Yanbei Chen, Xiatian Zhu, Wei Li, and Shaogang Gong. Semi-supervised learning under class distribution mismatch. In Proceedings of the AAAI Conference on Artificial Intelligence, volume 34, pages 3569–3576, 2020.

具体的实现方案是:客户端$k$使用历史模型${w^0, w^1, ..., w^t}$计算每个实例$i \in D_k$的平均置信分数向量:



而客户端不需要把所有的历史模型都存起来,因为$\bar{a}^t(x_i)$可以进行迭代更新,公式如下:



生成$x_i$的伪标签$y^+_i$为:




其相关置信得分为$\bar{p}^t_{y^+_i}(x_i)$,客户端使用服务器计算的置信阈值向量$\tau$来过滤本地的数据集:


互补负样本学习
前面的自集成学习有一个问题,那就是初始几轮因为没有历史模型,模型的精度较低,导致有大量的错误伪标签的实例被过滤到子数据集$D^{t,+}_k$中。为了给客户的监督学习提供更多有用的信息(特别是在第一轮训练的时候),作者引入互补负样本学习(complementary negative learning),这个方法一开始是针对噪声标签提出的。基本原理是:尽管很难将一个实例分类成正确的类,但是可以很容易把实例从错误的类中排除。因此可以通过为实例分配互补的标签(除真类以外的类),引入更多的信息来更新本地模型。最重要的是这个方法可以纠正子数据集$D_k^{t,+}$中包含具有错误伪标签的实例的负面影响,来提高局部训练性能。

因此作者构建了另一个子数据集$D_k^{t,+}$,包含带有互补标签的数据实例。上标“+”表示(伪)标签是真正的类,“-”表示(互补的)标签应该是除了真类以外的任何类。具体来说,对于每个实例$i \in D_k$,客户端k找到一组类,其相应的置信度得分通过历史平均(公式8)低于一个预先确定的小阈值$\theta$,然后从这个低于小阈值$\theta$的类集合中随机选择一个作为实例$i$的互补标签,即:




由此而到互补数据集:



的条件是实例在$D_k^{t,+}$中,不在$D_k^{t,-}$中。

本地模型更新:
作者用正损失分量和负损失分量组成的客户端损失函数来更新局部模型。




$\lambda>0$是权重参数。正损失分量$L_k^+(w)$是基于伪标签和在输入数据上的强数据增强$A(·)$计算出来的,从而将一致性正则跟伪标签相结合。作者使用的强增强方法是随机增强。
在前几轮训练中,互不标记的正确率比未标记要高得多,所以选择权重$\lambda$较小,来减少由于伪标签标记错误而造成的风险。在后面的几轮里,当历史集成以更高的准确率生成伪标签时,$\lambda$逐渐变大,从而使客户端损失函数更加强调正损失分量。

完整的算法如下:



实验验证

数据集:Fashion-MNIST、CIFAR-10、SVHN。
Baseline:

  • 只有server端数据的SL:视为lower bound。
  • FedAvg-SL:假设所有客户端的数据都被标记,用标准FedAvg训练模型,视为upper bound。(怎么感觉跟论文的server有标签的设定没有可比性啊?)
  • FedAvg+FixMatch
  • FedAvg+UDA
  • FedMatch
  • FedRGD

(没有跟SemiFL进行对比哈哈哈,一开始写的时候应该还没有SemiFL,v2的时候有了SemiFL但没有添加实验)

与FL+SSL方法对比




不明白他这个non-iid是怎么设置的,这他也没讲

与SSFL对比

跟FedMatch和FedRGD进行了对比,在CIFAR-10上进行。
因为FedMatch那篇文章的实验部分参数设置有错误(这里作者跟FedMatch作者核查过了),FedMatch的实际性能比FedMatch论文里阐述的性能还要好,所以改正过的FedMatch用FedMatch (corrected)表示。
另外因为FedMatch没有在Server端进行数据增强,因此为了公平起见,作者也删除了Server-SL的数据增强,记为Server-SL w/o DA。




结果是FedMatch和FedRGD都比baseline还低。

negative learning的影响

这张图展示了正过滤数据集$D_k^{t,+}$的正确伪标签的比例和负过滤数据集的正确互补标签的比例:




可以看到伪标签的正确率明显低于互补标签的准确率。因此如果监督学习只使用正过滤的数据集,其性能可能会显著降低。

关于模块的有效性,这里作者并没有做消融实验,而是逐个模块进行分析,不知道是不是有猫腻。。。

类级动态置信度阈值的影响

FedSEAL的数据过滤的置信阈值$\tau$每一轮都在更新。而在现有工作中,类似的阈值都是被设置成一个常量,而且所有类都是相同的。为了验证文章置信阈值的有效性,作者进行了实验,下图是FedSEAL得到的10个类的置信分数:




可以看到,如果阈值设置成0.9,那类1,3更有可能通过过滤器,而类246就被过滤器刷下去了,因此经过常数阈值过滤后的正数据集$D_k^{t,+}$会变得非常不平衡,从而降低模型的性能。另一方面,如果把阈值设置成0.5,那所有的类都能过滤到$D_k^{t,+}$中,但像135789这些类就很可能包含错误的实例。

自集成学习的影响

为了验证自集成学习的有效性,作者这里把多个历史模型得到的平均置信度向量$\bar{p}^t_{y^+_i}(x_i)$替换成当前模型得到的置信度向量来进行数据过滤,结果如下图:




这张图是两种方法在Fashion-MNIST上的测试精度收敛曲线。跑了不到20轮往后,FedSEAL的性能都是优于没有自集成学习的算法。这是因为随着集成规模的增加,自集成学习有助于过滤更多具有更高质量伪标签的未标记数据到过滤后的数据集$D_k^{t,+}$中。

论文总结

这篇论文关注的问题是联邦半监督下前几轮的模型伪标签质量不佳的问题。这个问题在client端无标签的情况下会更严重,因为client端将会完全仰仗server端的全局模型而且会一直迭代,但凡全局模型有一轮生成的伪标签效果不好,就会导致后面的模型质量越来越差。
为了解决这个问题,作者可谓是无所不用其极,把联邦、半监督两个领域缓解训练初期模型质量问题的方法都用上了,包括但不限于:

  1. 用server端的数据来训练一个SL模型,作为初始模型,避免了初始模型$w_0$的随机化。
  2. 用多轮历史模型的聚合来得到当前轮数的全局模型,避免了全局模型“一轮变差,全局完蛋”的情况。
  3. 由于(2)需要历史模型,考虑到前几轮连历史模型都没有,作者又用类的置信阈值进行筛选,在每个client端的数据集上把高置信度的样本选出来,组成正向的子数据集,进行训练。
  4. 为了辅助(3),增加数据集的信息,作者又构建了一个负样本数据集来进行负样本学习。

叠了这么多方法,文章都能当做SSFL初始化问题的综述了。。

但是这篇论文一直挂在Arxiv上没有投出去也是有原因的,实验设定交代的比较模糊,比如Non-iid的数据集是怎么分割的,堆叠了这么多模块但为什么不做消融实验,都感觉有点猫腻。但是作为学习者的角度,看完这篇论文收获还是很大的,知道了有那么多解决训练初期伪标签训练问题的方法。

声明:奋斗小刘|版权所有,违者必究|如未注明,均为原创|本网站采用BY-NC-SA协议进行授权

转载:转载请注明原文链接 - 联邦学习论文阅读:FedSEAL: Semi-Supervised Federated Learning with Self-Ensemble Learning and Negative Learning


Make Everyday Count