联邦学习论文阅读:神经正切核+联邦学习


论文简介

  • 论文名称:TCT: Convexifying Federated Learning using Bootstrapped Neural Tangent Kernels
  • 作者:Yaodong Yu, 加州大学伯克利分校
  • 论文链接:https://arxiv.org/abs/2207.06343
  • 论文来源:nips 2022

论文贡献

作者分析了联邦学习性能不佳的原因:网络的隐藏层其实学到了有用的特征,但非凸问题的联邦优化使得最终层没有将特征利用好。于是作者进行了改进,通过神经正切核把优化变为凸化问题。

作者提出了一个tarin-convexify-train算法:

  1. 先用现成的联邦算法(如FedAvg)来提取特征;
  2. 使用empirical Neural Tangent Kernel(经验神经正切核,eNTK)来计算模型的凸近似,然后通过梯度校正算法(如SCAFFOLD)来训练最终模型。
  3. 第二阶段训练的时候,冻结第一阶段学习到的特征,并且拟合一个线性模型。

相关论文

SCAFFOLD、FedDyn:通过控制变量来纠正FedAvg对单个用户的偏差。
针对“对非凸模型取平均”造成的性能损失,有如下方案:

  • 在取平均值前先学习客户端模型权重之间的映射:
    Model fusion via optimal transport.2020
    Fed2: Feature-aligned federated learning.2021
  • 用知识蒸馏来代替平均
    Towards model agnostic federated learning using knowledge distillation.2021
  • 尝试对齐客户端模型的内部表征
    Adaptive federated learning in resource constrained edge computing systems.2019b
    Model-contrastive federated learning.2021b
    Fedproto: Federated prototype learning over heterogeneous devices.2021
  • 通过对非凸模型取平均来改善性能:
    Averaging weights leads to wider optima and better generalization.2018
    Model soups: averaging weights of multiple fine-tuned models improves accuracy without increasing inference time.2022

关于Neural Tangent Kernels(NTK,神经正切核)

NTK一开始是研究无限宽网络用的。NTK+MSE很接近预训练网络的微调。因此NTK没法学习到特征,但能学习“模型如何比中层/最终层的激活函数更好的利用学习到的特征”。

论文动机:模型非凸性造成的影响

以两种高度异构的方式分割CIFAR-10数据集:

  • 每个客户端含有2个类,记为$\#C=2$
  • 以$\alpha=0.1$的狄利克雷分布来分割样本,记为$\alpha=0.1$

梯度校正算法的不充分性

有几篇论文认为数据异构导致模型收敛速度变慢,因为每个客户端的局部最优都不一样,没有模型能满足所有客户端的局部最优,因此客户端之间的更新就会相互竞争。导致收敛速度变慢。

相关论文
Scaffold: Stochastic controlled averaging for federated learning.
Adaptive federated optimization. In International Conference on Learning Representations, 2021.
On the unreasonable effectiveness of federated averaging with heterogeneous data.2022

这种关系可以通过客户端之间更新的方差来发现(即客户端梯度异质性)。

A field guide to federated optimization. arXiv preprint arXiv:2107.06917, 2021.

梯度校正办法指出,凸损失和非凸损失都不受梯度异质性的影响。

Scaffold: Stochastic controlled averaging for federated learning. 2020b.
Federated learning based on dynamic regularization. 2021.

分别使用SCAFFOLD和FedAVG,在存在数据异构和不存在数据异构的情况下,使用线性模型(损失情况为凸)和ResNet-18(损失情况为非凸)进行对比:




左图可以看到:凸优化情况下,SCAFFOLD相比于FedAvg来说受数据异构的影响较小。
而对比右图:非凸情况下,SCAFFOLD和FedAvg面对数据异构,表现得都不咋地。

而从右图可以看到,训练精度和测试精度非常接近,即问题没有出在泛化上(没有对训练数据过拟合),而是出现在了优化上。

冻结层数,深入探究

FedAvg+ResNet的baseline是56.9%,集中学习的baseline是91.9%。作者为了研究这35%的精度损失都去了哪,作者做了如下实验:
对于一个用FedAvg训练好的$l$层模型,我们把前$l-1$层都冻结,然后用集中式学习来重训练输出层(相当于对最后一层的激活做了凸逻辑回归)。就这样一个简单的操作准确率直接从56.9%提升到77.9%!!
因此,在FedAvg和集中式模型之间的35%的精度差距中,有差不多21%是因为线性输出层的优化不佳。

作者又进行了若干实验如下图:




分别使用参数随机初始化和FedAvg来初始化一个模型,然后对最后的$l$层进行再训练,冻结住前$7-l$层。FedAvg和参数随机初始化之间的性能差异,即为FedAvg在前$7-l$层学到的信息内容。

  • 从$l=1$的情况来看(冻结前6层,重训练最后一层),精度有42.6%的差异,足以说明模型的初始层已经学到了有用的特征。
  • 从$l=6$的情况来看(冻结第1层,重训练后面6层),精度依然有差异,这就说明模型从低阶层就开始学到特征了。但因为低阶都是边缘检测的一些特征,所以特征相关性不是很强。

论文方法

先学好特征

从前面的动机中就有了一个Idea:既然FedAvg学到的特征是好的,那就可以把FedAvg学到的好特征和凸联邦优化结合起来。
在解决优化问题的凸化情况之前,先训练上几轮,引导FedAvg先把特征学好。

计算每个客户端的神经正切核(eNTK)

为了避免神经网络的非凸性带来的挑战,首先通过“线性化”来近似一个神经网络。

对于神经网络$f(\cdot; \theta_0)$,可以用empirical neural tangent kernel(经验神经正切核,eNTK)来近似$\theta_0$处的神经网络函数值:




近似完之后,$f(x;\theta)$就变成了关于特征向量$(f(x;\theta_0),\frac{\partial}{\partial \theta} f(x;\theta_0))$ 线性函数。神经网络的非凸优化问题就变成了关于特征向量的凸线性回归。

将非凸问题转化为凸优化问题的步骤:

  1. 重初始化$\theta_0$的最后一层。
  2. 对于每个数据点$x$,计算梯度$\phi_{eNTK} (x;\theta_0):= \frac{\partial}{\partial \theta} f_1(x;theta_0)$。
  3. 对于每个数据$x$,将其$\phi_{eNTK} (x;\theta_0)$进行自采样,得到eNTK的降维表示。

用凸线性回归来逼近非凸优化。

上面的eNTK推导过后,我们就可以用FedAvg训练的模型中提取每个客户端输入的eNTK表示。把每个客户端的eNTK表示以联邦学习的方法来拟合一个过度参数化的线性模型。

参数说明:

  • 用$z_i^k$表示第$k$个客户端的第$i$个样本的eNTK特征。
  • $K$表示客户端数量。
  • $Y_i^k$表示独热编码标签。
  • $n_k$表示第$k$个客户端的数据点数量。
  • $n := \sum_{k \in [K]} n_k$表示所有客户端的数据总数。
  • $p_k := n_k/n$表示每个客户端数据点的持有比例。

由此,可以利用凸线性回归来逼近神经网络的非凸优化:



Train-Convexify-Train

  • Stage 1:从FedAvg训练的模型上提取eNTK特征。先让FedAvg训练上$T_1$轮,此时的模型权重为$\theta_{T_1}$。然后在每个客户端上计算eNTK的子采样特征:


  • Stage 2:从FedAvg训练的模型上提取eNTK特征。给定客户端$k$上的样本${(z_i^k, Y_i^k)}_{i=1}^{n_k}$,首先用一轮通信给所有客户端的eNTK进行归一化,然后使用SCAFFOLD算法解决线性回归问题。


实验结果

数据集采用的FMNIST,CIFAR-10,CIFAR-100。

设计了2种数据异构的分布情况:

  • 以参数$\alpha$影响的对称狄利克雷分布。$\alpha$越小,异质性越高。
  • 客户从类的固定子集中获得样本。$\# C=2 $表示每个客户端都有来自2个类的样本。

实验选择了FedAvg,FedProx,SCAFFOLD作为baseline,baseline共训练200轮,而作者的方法在前期训练和后期训练中各训练100轮。

实验结果如下:




可以看到TCT的提升效果非常惊人!!而且在高数据异构的的情况下TCT的性能并没有下降很多,狄利克雷系数$\alpha$从0.5降到0.001时,准确率才下降了1.5%不到。而且在每个client只含有1个类的极端情况下,TCT的表现依然很好。

下面这张图是FedAvg、SCAFFOLD和TCT在CIFAR100数据集上面对不同数据异构情形的表现:




FedAvg和SCAFFOLD在数据异质性较高的情况下精度越来越低,但TCT表现得非常稳定。

通信效率实验



在所有3个数据集上,作者的算法始终优于全批量梯度下降,且算法中的局部步骤可以加速收敛。如算法在CIFAR100上的表现,算法在20轮内完成收敛。因此,算法可以很大程度上利用局部计算,提高通信效率。

声明:奋斗小刘|版权所有,违者必究|如未注明,均为原创|本网站采用BY-NC-SA协议进行授权

转载:转载请注明原文链接 - 联邦学习论文阅读:神经正切核+联邦学习


Make Everyday Count