论文简介
- 论文名称:Dual Class-Aware Contrastive Federated Semi-Supervised Learning
- 作者:Qi Guo,西安交通大学
- 论文链接:https://arxiv.org/abs/2211.08914
- 论文来源:CVPR 2022
论文简介
FSSL场景设置:客户端要么有完全标记的数据,要么有完全无标记的数据。
论文动机:现有的FSSL方法大多基于伪标签和一致性正则的方法来利用未标记的数据。但是这种方法有两个限制:
- 客户端的本地模型和未标记客户端的本地模型之间有很大的偏差。
- 噪声伪标签会造成确认误差。
作者认为,误差是由标记客户端和未标记客户端的本地模型训练目标存在显著差异造成的。因此,每个客户的训练过程中应该让所有local models朝着一个共同的目标学习。
但因为FSSL的数据分布是non-iid的,所以如果只考虑global data class-aware 分布的时候,他就没法匹配每个特定客户端潜在的特征空间,因而损失客户端在训练过程中在客户端本地的性能。
论文方法
提出了DCCSSL,来缓解本地模型对训练目标的差异导致的误差,以及由噪声伪标签引起的确认误差。
框架流程:
- client在收到global model以后作为初始化进行本地训练,然后把更新后的模型和原型传回服务器
- 服务器采用认证重加权模型聚合(AMA)和认证重加权原型聚合(APA),得到server端的global model和global class prototypes,再发给选定的客户端。
本地训练
局部基础训练模块分别对labeled client进行弱增强的SL,对unlabeled clients进行基于伪标签的一致性学习(用的fixmatch)。
对偶类感知对比(Dual class-aware contrastive)模块同时考虑局部类感知分布和全局类感知分布。对于局部类感知分布,认定同一类别的表示$z$为正对,其余的视为负对。对于全局类感知分布,将来自同一类的全局类prototype视为正对,其他全局类原型作为特定样本的负对。
标记客户端
标记客户端用CE作为监督损失:

另外对于对偶类感知对比模块,定义局部类感知对比损失$L^l_{lcc}$和全局类感知对比损失$L^l_{gcc}$。
$L^l_{lcc}$的计算如下:

第一个sum函数之所以有$2N$项,是因为每个样本要经过两次强增强获得2个特征。$z$是每个图像强增强以后的特征,$\tau$表示温度参数,$S(i)$表示同一类的其他图像的类别索引,$|S(i)|$表示基数,$|S(i)|$表示所有正对。
$L^l_{gcc}$的计算如下:

$C$表示类的总数,$\bar{z}_j$表示server发送的属于第$j$个类的样本的本地潜在特征对齐的第$j$类的全局原型。
标记客户端中的对偶类感知对比损失$L^l_{DCC}$的表述为:

$\lambda_{lcc}$和$\lambda_{gcc}$分别是控制局部类感知和全局类感知的影响的系数因子。
模型总损失如下:

未标记客户端
未标记客户端的训练模块采用的是基于伪标签的一致性正则来作为无监督损失$L^u_{basic}$的基本模块。$L^u_{basic}$公式如下:

其中$T_{thr}$是置信度阈值,H是交叉熵。
未标记客户端跟标记客户端差不多,都是既包含局部类感知对比损失,也包含全局类感知对比损失。未标记客户端的思路就是,既然我们没有样本标签,那我们就假设图像的表示有很大概率是可靠的,那就把同一类的样本拉的更近,把其他类推得更远。
本地类感知对比损失表述为:

$S(i)$表示来自同一类的其他置信度$p>T_{thr}$图像的索引。$|S(i)|$表示其基数,$|S(i)|+1$表示所有的正类样本。
全局类感知对比损失表述为:

$C$表示类的总数,$\bar{z}_j$是第j类的prototype。
双类感知对比损失$L^u_{DCC}$如下式:

unlabeled client的总损失为:

认证-重加权模型聚合(AMA)
为了缓解模型的不均匀性,作者提出了Authentication-reweighted Model Aggregation方法。简单来说就是根据不同client上的认证样本(正确分类的标记客户端/具有高置信度伪标签的未标记客户端的样本)数量,来调整局部模型的权重。
每轮训练过后的聚合公式如下:

$K$是选择的客户端的数量,$A_{i,i\in {1,2,...,K}}$是第i个客户端上的认证样本。
认证-重加权原型聚合(APA)
这步聚合的目的是增强全局类原型的鲁棒性。每次完成本地培训以后,每个客户端使用本地数据(只用验证数据),通过新的本地模型来生成本地类原型。
全局类原型如下:

APA更新了$O$之后,server将新的全局原型发给下一轮选定的客户端。
$v_i=(v^i_0,v^i_1,...,v^i_{C-1})$的认证样本数向量。
算法流程:

实验结果
数据集:CIFAR-10、CIFAR-100、SVHN
backbone:Wide ResNet WRN-16-2

跟一些常见的SL+FL组合做了对比,不过他还是没有比FedSGD的组合方式。比较惊奇的是作者的模型比全监督时候的upperbound还要好,这应该是因为SSL做了数据增广和一致性正则,而upperbound没有做增广。
消融

对lcc和gcc做了消融,证明了两者结合使得模型精度提高最大。lcc的提升效果也不明显。
总结
这篇论文应该是对FedMatch那篇论文的一个思路的延伸,他们都是基于MixMatch来进行变式的。关于GAN相关模块我还没太读懂每个小模型的真正作用,需要再读一遍消化一下;另外这篇论文应该是针对稀缺数据的情况来的,数据越稀缺他所展现的性能就越优越,消融实验也证明了这一点。
Comments | NOTHING