半监督论文阅读:FixMatch


论文简介

  • 论文名称:FixMatch: Simplifying Semi-Supervised Learning with Consistency and Confidence
  • 作者:Kihyuk Sohn, 谷歌
  • 论文链接:https://arxiv.org/abs/2001.07685
  • 论文来源:NIPS 2020

文章特点:这篇论文算是SSL领域的SOTA,在实验验证和消融实验当中把半监督领域有用没用的方法都试了一遍,然后组合出来了个FixMatch。整篇基本没有理论推导,全是实验,但实验做的非常完善,特别是消融实验的部分,各方面都考虑到了,对一个步骤都选取了不同方法进行了对比验证,值得学习一下。

论文概述

FixMatch方法的大致流程:首先基于模型对弱数据增强的无标签图片进行预测,来生成伪标签,然后只选用置信度高的样本(设置了一个阈值)。对于这批高置信度样本,基于一致性正则化思想,来强迫模型对强数据增强的同一批图片获得相同的预测。流程比较简单,但效果很好,在250个标注数据的CIFAR-10数据集准确率达到了94.93%,在40个标注数据下,准确率达到了88.61%。



Pseudo-labeling

低密度分离和熵正则化的假设下,Pseudo-labeling对模型预测进行门槛筛选和硬化,对无标记数据生成伪标签,然后微调模型去拟合所获得的伪标签。具体来说,伪标签方法只保留模型预测的最大概率高于阈值$\tau$的无标记样本,并将样本指定为最大概率的类别,即标签硬化

模型对无标记数据的预测值:

$$ q_b = p_m(y|u_b) $$

根据预测值,生成伪标签:

$$ \hat{q_b} = arg \max(q_b) $$

然后微调模型,迫使模型去拟合硬化后(也就是筛选出最大概率高于阈值$\tau$的样本)的伪标签,损失函数形式化如下:



摘自知乎:伪标签方法迫使模型预测硬化的这种思路,可以从熵最小化的角度来理解。
熵最小正则化是鼓励模型在无标签数据上给出低熵的类别概率分布,而Pseudo-Labeling的论文指出,熵最小正则化是符合低密度分离的假设的,可以在不对决策边界进行显式建模的情况下,鼓励低密度分离。

Consistency Regularization

在平滑假设下,一致性正则化方法认为模型对于经过扰动的数据,应当输出相似的预测结果,因此,定义一致性正则化的损失如下:




一致性正则中,对模型施加扰动的方法:

  • Mean Teacher
  • Temporal Ensemble
  • Noisy Student
    从两个角度理解:
  • 对模型施加扰动在效果上类似于对输出隐向量施加噪音,从而类似数据扰动
  • 魔性扰动而保持预测值一致,符合低密度分离假设。

即对决策边界进行轻微扰动的情况下,数据的整体预测结果保持不变。

FixMatch方法

FixMatch方法结合了Pseudo-Labeling和Consistency Regularization的思路:基于弱数据增强(e.g.,翻转和平移)的无标签数据来生成伪标签,然后把伪标签当做目标,使得模型对强数据增强(Cutout,CTAugment和RandAugment)的同一图片生成一致性的预测。

损失函数

FixMatch损失函数由有监督损失$l_s$和无监督损失$l_u$两部分组成。
$l_s$是在弱增强的有监督样本上的标准分类损失:




对于无监督样本,FixMatch对样本进行弱增强以后,获得伪标签$\hat{q_b}$,然后对同一图像的强增强样本施加一致性损失。这里作者使用的交叉熵来衡量一致性。



最终的损失函数为:

$$ l_s + \lambda_u l_u $$

超参数$\lambda_u$指定了无监督损失的权重。
作者提到,在现代SSL算法中,一般会在训练期间逐步增加无监督损失项的权重$\lambda_u$。而FixMatch不需要这样。这很可能是因为,在训练的初期,$max(q_b)$通常小于阈值$\tau$,因此模型没生成多少伪标签;而随着训练的进行,模型的预测越来越自信,伪标签生成的数量就会越来越多。

算法流程



数据增强策略

弱增强:翻转+平移
强增强:RandAugment、CTAugment两种策略+Cutout策略的结合

RandAugment和CTAugment是AutoAugment的两种变体,AutoAugment使用强化学习手段寻找最优的图像变换组合,但它要求有一定数量的带标签数据集,半监督学习场景下有标签样本量不能满足这个条件。

实验验证

作者对比了几种经典的一致性正则方法:$\Pi$-model,Mean Teacher,MixMatch,UDA和ReMixMatch;以及伪标签算法Pseudo-Labeling,实验结果如下:



可以看到,FixMatch在大部分数据集上获得了相近或更优的准确率。作者还强调,虽然在一些实验设置比如CIFAR-100数据及上,ReMixMatch稍微胜出,但是FixMatch方法更为简单。如果将FixMatch与ReMixMatch方法中的Distribution Alignment方法移植到FixMatch中,400个标签上能达到40.14%的准确率,优于ReMixMatch的44.28%。

单例监督学习

作者还实验了一种极端情况:在数据集的每个类上随机选择1个样本来进行模型训练。但作者在实验中发现结果的方差太大,因为每个样本的质量不一。于是作者参考了文献:

N. Carlini, Ú. Erlingsson, and N. Papernot. Distribution density, tails, and outliers in machine learning: Metrics and applications. arXiv preprint arXiv:1910.13427, 2019. 7, 16
选了每个类中最有代表性的样本,进行了4次训练,中位数为78%。

消融实验

Sharpening vs Thresholding

在主流半监督学习中,有两种Pseudo Labeling的构造方法。一种是FixMatch方法采用的argmax硬化算法,也就是Thresholding;另一种是UDA算法采用的Sharpening,也就是“锐化”。锐化算法是概率分布更加极端,如下:




在温度$T$趋近于0的情况下,Sharpening方法退化成argmax方法。两种方法都符合熵最小正则化的思路。

作者先比较了阈值硬化方法中阈值的影响:




随着阈值降低,错误率降低并逐渐饱和。
作者还实验了Sharpening和Thesholding的组合策略。在阈值为0($\tau=0$)的情况下,随着锐化温度$T$降低,模型的锐化强度上升,错误率逐渐降低。但当阈值较大($\tau \in {0.8, 0.95}$)的情况下,两者的共同影响难以预测。最终,FixMatch只采用了阈值硬化(Thresholding)的方法。

数据增强策略

数据增强是FixMatch算法中的重要步骤。作者先进行了CTAugment和Cutout的消融实验:




可以看到,Cutout和CTAugment缺少任何一个都会导致误差增加。

作者还实验了生成伪标签和一致性预测的时候,强弱增强的不同组合。实验表明:

  • 标签预测如果用弱增强取代强增强,模型的训练会变得不稳定(精度从45%降到12%);
  • 标签预测如果用不增强取代强增强,会导致模型对伪标签过拟合;
  • 生成伪标签如果用强增强取代弱增强,模型会在训练早期出现分歧。

伪标签的数量和质量之间的权衡

为了探讨FixMatch中阈值的影响,作者提出了两个测量指标:impurity和mask rate,具体形式如下:




impurity用于衡量阈值所规定的样本预测错误率(直观理解就是达到阈值的样本里的错误率),Mask Rate用于衡量符合阈值要求的样本的数量。
作者在测试集上计算了两个指标随阈值$\tau$的变化即算法的误差率如下:



随着$\tau$增加,impurity降低,mask rate降低,作者的解释是:阈值越高,预测更准确,但可使用的无标签样本的数目也越少。错误的伪标签会污染模型,这也就是confirmation bias问题。同时,我们也希望有足够多的无标签样本,来挖掘数据的分布规律。最终FixMatch所使用的阈值是$\tau=0.95$。

Optimizer




模型对动量(momentum)比较敏感,$\beta$太大时,模型不收敛;$\beta$较小时,模型效果也还不错,而且此时可以通过提高学习率$\eta$来提高性能,但是依然不如$\beta=0.9$时的性能好。Nesterov动量的错误率比标准SGD略低,但差异不明显。

声明:奋斗小刘|版权所有,违者必究|如未注明,均为原创|本网站采用BY-NC-SA协议进行授权

转载:转载请注明原文链接 - 半监督论文阅读:FixMatch


Make Everyday Count