联邦学习+拆分学习组成的新范式:SplitFed


论文名称:SplitFed: When Federated Learning Meets Split Learning(AAAI 2022)
论文网址:https://arxiv.org/pdf/2004.12088.pdf

这篇论文是澳大利亚联邦科学与工业研究组织发表在AAAI2022上的论文,比较神奇的是这篇论文在2022年4月份就已经挂到arxiv上了。这篇论文也第一次让我了解了拆分学习(SplitFed这种范式。)

Split Learning是个什么东西?

Split Learning我个人的理解就是一种FL的新范式。他把深度学习网络$W$分为2部分,$W_C$和$W_S$,分别称为客户端网络和服务器端网络。存留数据的客户端只需要提交到客户端网络,服务器端只提交到服务器端网络。这种网络的正向传播和反向传播的流程如下:
正向传播:客户端利用原始数据将网络训练到切割层,并将切割层的activation(破碎数据,smashed data)发送给服务器。然后服务器使用从客户端接收到的破碎数据对剩余层进行训练。一次完成一次正向传播。
反向传播:服务器进行向上的反向传播,传播至切割层,然后将破碎数据的梯度发送给客户端。使用梯度,客户端在剩余的网络(网络的第一层)上执行它的反向传播。以此完成客户端和服务器端之间的反向传播的一次传递。

重复正向传播和反向传播,直到网络的所有可用客户机的训练并达到收敛。

论文方法



上图中,右侧体现了客户端和服务器如何执行网络训练。

与传统的Split Learning不同,所有客户端都在clinet-side模型上并行的执行前向传播,然后将smashed data传递给主服务器。

假设服务器有足够的计算资源,在其服务器端模型上与每个客户端的smashed data一起并行处理正向传播和反向传播。

main-server将破碎数据的梯度发送回各自的客户端进行反向传播。服务器通过在每个客户机的破碎数据上反向传播期间计算的梯度的加权平均来更新其模型(其实就是执行了FedAvg)。

在客户端,每个客户端接收到smashed data的梯度后,对其客户端本地模型进行反向传播,并计算其梯度。并且使用DP(差分隐私)将这些梯度设置为私有,将其发送到联邦服务器(Fed Server)。Fed Server再执行客户端本地更新的FedAvg,并将其发送回所有参与的客户端。
算法代码:



实验分析

作者的网络结构选择是根据每个客户端的硬件性能不同来择取的。使用的网络体系结构包括但不限于:LeNet,AlexNet,ResNet等。
使用的数据集包括:HAM10000、MNIST、FMNIST、CIFAR10。

作者设定的每个网络的Split Layer:LeNet第二层(2DMaxPool层之后)、第二层AlexNet第二层(2DMaxPool层之后)、VGG16第四层(2D池层之后)和ResNet18第三层(2D批规范化层之后)。





局限性

我认为这个算法一些设置还是比较启发式,比如每个网络的Split Layer定在第几层?应该都是作者自己一点点试出来的。

声明:奋斗小刘|版权所有,违者必究|如未注明,均为原创|本网站采用BY-NC-SA协议进行授权

转载:转载请注明原文链接 - 联邦学习+拆分学习组成的新范式:SplitFed


Make Everyday Count