半监督学习论文阅读:MixMatch


论文简介

  • 论文名称:MixMatch: A Holistic Approach to Semi-Supervised Learning
  • 作者:David Berthelot,谷歌
  • 论文链接:https://arxiv.org/abs/1905.02249
  • 论文来源:NIPS 2020

半监督学习的3种思路

方案1:Consistency Regulariztion

对未标记的数据进行增广(比如对图片进行平移、缩放、旋转、裁剪、改亮度加噪声之类的),产生的新数据输入到模型里,其输出的结果应该保持consistency。这个规则加入到损失函数中:

$$ \| \pmb{p}_{model} (y| Augment(x);\theta) - \pmb{p}_{model}(y|Augment(x);\theta) \|^2_2 $$

$x$是未标记数据,$Augment(x)$表示对数据增广后产生的新数据(随机操作,两个$Augment(x)$的输出不是相同的),$\theta$是模型参数,$y$是预测结果。这样求得一个L2损失项来约束模型,对同一个图像做增广得到的所有新数据,其预测都应该是一致的。

方案2:Entropy Minimization

半监督方法的一个共识就是强迫分类器对未标记数据做出低熵预测,以保证分类器的分类边界不应该穿过边际分布的高密度区域。

方案3:Traditional Regularization

MixMatch使用Adam优化器+weight decay来提高模型的泛化能力。

Mixmatch方案

  1. 对一个batch的标记数据$\chi$ 和一个batch的未标记数据$u$做数据增广,分别得到一个batch的增广数据$\chi'$和$K$个batch的$u'$。

    $$ \chi' , u' = MixMatch=(\chi, u, T, K, \alpha) $$

    其中$T, K, \alpha$是超参数。

具体的算法流程如下:



`
对于标记标签,做一次增广,标签不变;对于未标记数据,做K次增广,然后输入到分类器,得到平均分类概率,然后输入到temperature sharpening算法中,得到未标记数据的猜测标签q。此时$\hat{\chi}$增广了一个batch,$\hat{u}$增广了$k$个batch。然后把$\hat{\chi}$和$\hat{u}$打乱重排,得到数据集$W$。最后输出两部分:

  • $\chi'$:将$\hat{\chi}$和$W$做Mixup()的一个batch的标记数据
  • $u'$:$\hat{u}$和$W$做Mixup()以后的K个batch的无标记增广数据$u'$。

声明:奋斗小刘|版权所有,违者必究|如未注明,均为原创|本网站采用BY-NC-SA协议进行授权

转载:转载请注明原文链接 - 半监督学习论文阅读:MixMatch


Make Everyday Count