0. 引言
联邦学习(Federated Learning)允许用户在将数据保留在本地端不共享的前提下形成一个联合体训练得到全局模型,从而有效解决数据隐私和安全保护问题。同时,还可以有效应用联合体各方用户所掌握的标注数据,解决标注数据缺乏的问题。在联邦学习架构的每一轮学习过程中,中央服务器在当前全部客户端中选定一些客户端子集并将全局模型下发给这些客户端子集。然后,这些客户端子集在本地运行随机梯度下降(SGD)等优化处理步骤后生成本地模型。最后,客户端子集将本地模型发送回中央服务器。反复执行训练过程直到模型收敛,生成最终的全局模型。
目前,联邦学习的应用面临四个主要问题:通信开销问题、隐私保护问题、客户端无状态问题和客户端中数据非独立同分布问题。其中,通信开销问题主要是由客户端和中央服务器之间经由网络连接和传输数据(模型、参数)所造成的。隐私保护问题主要是指经由网络传输时用户信息、模型信息的隐私和安全保护问题。客户端无状态问题是指一般情况下在多轮训练期间,没有一个客户端会参与超过一次的训练。客户端中数据非独立同分布问题则是指不同客户端,特别是边缘设备,所收集到的数据通常不是独立的,也不具备相同的数据分布特性。本文重点关注通信开销问题的最新研究进展。通信带宽是联邦学习的主要瓶颈,因为大量的设备都将其本地更新发送到中央服务器中。因此,对于一个通信效率高的联邦学习算法来说,这种更新必须以压缩和不频繁的方式发送。
在实际场景中,特别是在所需的全局模型规模较大的情况下,网络带宽限制和工作节点数量可能会加剧联邦学习的通信瓶颈,从而造成客户端设备掉队 / 退出的问题。在经典的联邦学习框架下,系统会将一些网络带宽受限或访问受限的客户端排除在训练的轮次之外,即不将全局模型发送给这些客户端进行本地优化。这种简单的处理方式会大大影响这些客户端所提供的服务,进而影响用户的使用体验。
针对通信开销问题最简单直接的解决方案是以牺牲模型准确度为代价、在联邦学习的整体框架中仅训练占用通信空间较小的低容量模型。从这个角度出发,来自 Google 的研究人员 Koneˇcný et al. 提出了一种降低上行通信成本(Client-to-Server FL Communication)的方法 [1]:客户端只将本地计算得到的模型更新传递到中央服务器,而不是完整的本地模型。很显然,这种方法虽然能够降低通信成本,但是并不能满足复杂场景下业务应用需要。在此基础上,来自同一研究小组的 Caldas et al. 提出了一种能够有效降低下行通信成本(Server to Client Communication)同时与已有降低上行通信成本方法无缝集成的方法[2],具体包括在服务器到客户端的全局模型上使用有损压缩,以及允许用户在全局模型的较小子集上高效完成本地训练的、能够同时减少客户端到服务器的通信成本和本地计算的 Federated Dropout。
与这两篇文章的思路类似,Rothchild el al. 提出了一种使用 Count Sketch 对客户端模型更新进行压缩处理的方法 FedSGD [3]。由于 Count Sketch 是线性的,可以通过 Sketch 计算动量和误差累积,从而将动量和误差累积的计算任务从客户端转移到中央服务器,克服了稀疏客户端参与更新的问题,同时保持高压缩率和良好的收敛性。Reisizadeh et al. 提出了一种周期平均和量化的处理方法 FedPAQ [6],量化处理本身也是压缩的一种方式。FedPAQ 允许网络中的客户端在与中央服务器同步之前执行本地训练,仅将活跃客户端的更新发送回中央服务器,且发回的仅为本地信息的量化版本。与上述从压缩角度出发的工作较为不同的是,Hamer el al. 提出了一种主要解决下行通信成本问题的集成方法:FedBoost [4]。集成方法是机器学习中的一种通用技术,用于组合多个基本预测因子或专家来创建一个更精确的模型。FedBoost 主要通过学习一组预先训练好的基本预测因子(Base Predictors)实现联邦集成(Federated Ensembles)。此外,Malinovsky el al. 提出了一种高效通信的分布式定点优化方法(Fixed-point optimization)[5],从解决优化问题或寻找凸凹函数的鞍点的角度出发限制客户端本地计算,从而解决联邦学习通信开销瓶颈问题。
1. 解决通信开销问题的研究进展
1.1 通过压缩方法解决通信开销问题
通过压缩处理减少联邦学习框架中上行、下行传递的数据量是最直接的解决通信开销问题的方法。我们首先来看一看这一类方法的研究进展情况。
1.1.1 模型更新上传方法 [1]
联邦学习的目标是从存储在大量客户端的数据中学习包含在真实矩阵 W 中参数的模型。在第 t 轮训练过程中,中央服务器将当前模型 W_t 分发给总共 n_t 个客户端的子集 S_t。这些客户端子集根据其本地数据独立训练并更新本地模型。具体介绍,第 i 个客户端的更新过程如下式所示:
这些更新可以是在客户端上计算得到的单个梯度值,也可以是使用更复杂的计算方式得到的结果,例如,在客户端的本地数据集上执行多个随机梯度下降(Stochastic Gradient Descent,SGD)处理。选定的客户端会将更新发送回中央服务器,中央服务器通过聚合所有客户端更新来计算得到全局模型:
其中,η_t 表示中央服务器中的学习速率。
作者介绍两种部分模型更新上传至中央服务器的方法。值得注意的是,本文只关注了客户端 - 中央服务器这一段的上行通信成本。第一种方法为结构更新(Structure Update)。令更新(H_t)^i 为预先定义的结构。本文具体考虑两种结构:低秩(Low-Rank)和随机掩模(Random Mask)。在低秩情况下,令(H_t)^i=(A_t)^i ・(B_t)^i。在后续计算过程中,随机生成(A_t)^i 并在本地训练过程中将其考虑为一个常数,只优化(B_t)^i。在实际实现中,可以将(A_t)^i 压缩成随机种子的形式,客户端只需要向中央服务器发送训练后的(B_t)^i。在每一轮次的更新中为每个客户端重新生成(A_t)^i 矩阵。在随机掩模的情况下,令(H_t)^i 为一个遵循预定义的随机稀疏模式(即随机掩模)的稀疏矩阵。在每一轮次的更新中每个客户端独立生成新的掩模。与低秩方法类似,稀疏掩模可以通过一个随机种子生成,因此只需要将(H_t)^i 的非零项值与种子一起发送。
第二种方法为草图更新(Sketched Update)。首先,在客户端完整的计算模型更新(H_t)^i。在发送至中央服务器之前,以有损压缩进行编码处理。中央服务器收到编码后解码再进行聚合处理。生成 Sketch 的方法包括:(1)下采样(Subsampling),采样更新的平均值是真实平均值的无偏估计值 E[H^_t]=H_t。与随机掩模的结构化更新类似,下采样掩码在每轮中也是在每个客户端独立随机采样的,并且掩码本身可以存储为同步种子。(2)概率量化(Probabilistic quantization),即将模型的权重(weights)量化处理。对于更新(H_t)^i,令 h=(h_1, ..., h_(d1x2d))=vec((H_t)^i),h_max=max_j(h_j),h_min=min_j(h_j)。h 的压缩更新为:
对于 4 个字节的浮点数,这种方法实现了 32 倍的压缩。对于一个比特的量化处理来说,首先均匀区分 [h_min, h_max] 为 2^b 个间隔。h_i 落在 h’和 h’’限定的区间内。量化操作将上述方程的 h_min 和 h_max 分别替换为 h’和 h’’。设定参数 b 以平衡准确度和通信成本,并通过随机化旋转改进量化。当尺度在不同维度上近似相等时,上述 1 比特和多比特量化方法效果最好。当 max=1,min=-1,且大部分值为 0 时,1 比特量化将会导致较大误差,一般可以通过在量化前对 h 进行随机旋转(h 乘以随机正交矩阵)来解决这个问题。在解码阶段,中央服务器需要在聚合所有更新之前执行反向旋转。
作者在实验中对比的是结构更新和草图更新两种方法的模型效果,如图 1 CIFAR 库和图 2 REDDIT 库中的结果。CIFAR 库中使用的模型有 9 个卷积层,其中第一层和最后一层的参数明显少于其他层。在压缩处理过程中只压缩内部 7 层,每个层都具有相同的参数。图 1 中使用关键字「mode」表示这种方法。对于低秩更新,「mode=25%」表示更新的秩被设置为全层变换的秩的 1/4,对于随机掩模(Random Mask)或草图绘制(Sketching),「mode=25%」的意思是对 25% 以外的所有参数进行置零处理。由图 1,结构化的随机掩模方法效果更优。
图 1. 在 CIFAR 数据库中结构化随机掩码更新和无量化 Sketched 更新的对比结果。
基于 REDDIT 数据库,作者训练了一个 LSTM 单词预测模型。该模型被训练成给定当前单词和前一时间传递的状态向量来预测下一个单词。为了减少更新的规模,除一些内存消耗小于 0.01% 的微小变量(例如 bias)外对所有模型变量进行 Sketching 处理。图 2 中使用 AccuracyTop1 进行评估,即模型赋予最高概率的单词是正确的预测结果。
图 2. Sketching 比较,在 Reddit 数据上训练一个 LSTM 模型,每轮随机抽样 50 个客户端。
1.1.2 有损压缩方法 [2]
在 1.1.1 节中提出的模型更新传递方法也是一种有损压缩策略,它主要解决的是客户端 - 中央服务器的上行通信开销问题,本节中的方法主要聚焦中央服务器 - 客户端的下行通信开销,同时还能与处理上行通信开销的方法进行无缝集成。
图 3. 有损压缩方法整体思路。
该有损压缩方法的整体思路如图 3 所示。(1)通过 Federated Dropout 构造子模型,(2)对生成的对象进行无损压缩,来减小通信模型的大小。然后将这个压缩模型发送给客户端,(3)客户端使用本地数据对其进行解压缩和训练,(4)压缩最终的本地更新。将本地更新发送回中央服务器,(5)中央服务器执行解压缩,(6)中央服务器聚合生成全局模型。
在方法第(2)步中所提到的「无损压缩」借鉴的是节 1.1.1 中阐述的方法,包括基本变换、下采样、概率量化等。只不过本文将这些压缩应用于中央服务器到客户端的交换中,这也意味着,这种方法不能利用在中央服务器端通过无损压缩平均噪声解压缩处理(Averaging the Noisy Decompressions)的改进。关于下采样和概率量化处理本节不再详述。关于基本变换的处理,作者 Koneˇcný et al. 在文献 [1] 中并未详述。实际上文献 [1] 中的基本变换为随机 Hadamard 变换(HD),目的是均匀地将向量信息分布在各个维度上。而本文中除随机 Hadamard 变换外,还考虑了 Kashin 表征方法(K)从而尽可能在每个维度上传播向量的信息
。
为了进一步降低通信成本,本文引入 Federated Dropout 的有损压缩方式,即每个客户端不需要局部训练全局模型的更新,只是训练一个更小的子模型的更新。在传统的 Dropout 方法中,使用一个随机的二进制掩码乘以隐藏单元,以便在每次训练经由网络传输时丢弃一部分期望的神经元。因为每个过程中的掩码都会发生变化,所以每一个过程都需要计算相对于不同子模型的梯度。这些子模型可以有不同的大小(结构),具体取决于每层中丢弃多少个神经元。在本文处理中,为了满足节省联邦学习通信开销的要求,在每个全连接层上都将固定数量的激活归零,这样使所有子模型都具有相同的简化体系结构。如图 4 所示。
图 4. 应用于两个全连接层的 Federated Dropout。
中央服务器可以将必要的值映射到这个简化的架构中,这意味着只将必要的系数传输到客户端,重新打包成更小的密集矩阵。客户端(可能完全不知道原始模型的架构)训练其子模型并发送其更新,然后中央服务器将其映射回全局模型。对于卷积层来说,将激活归零不会实现任何空间的节省,因此,作者去掉了一定比例的滤波器。
除了节省服务器到客户端的通信开销外,Federated Dropout 还带来了另外两个好处。首先,客户端到中央服务器更新的规模也减少了。其次,本地训练过程现在只需要运行较少的梯度更新。当前所有的矩阵乘法都是较小维度的(相对于全连接层来说),或者是只需要使用较少的滤波器(对于卷积层来说)。因此,使用 Federated Dropout 进一步降低了联邦学习中的本地计算成本。
在实验部分,作者首先分别基于 CIFAR-10 和 EMNIST 库验证了压缩处理对本文方法的影响。由图 5 的实验结果,(1)对于每一个模型,都能找到一组至少与基线方法(no compression)相匹敌的压缩参数设置;(2)Kashin 表征方法对量化处理最为有用;(3)在中央服务器到客户端的通信过程中进行下采样处理的效果并不好。
图 5. 有损压缩参数对 CIFAR-10 和 EMNIST 的影响。
然后,作者通过实验验证 Federated Dropout 方法对全局模型准确度的影响。图 6 给出了模型在不同的 Federated Dropout 率下的收敛情况。在本文模型中,将每层保留的神经元(或卷积层的过滤器)的百分比定义为 Federated Dropout 率。将每个实验重复 10 次,最终报告 10 次重复实验结果的平均值。由图 6,对于每个模型来说都存在小于 1.0 的 Federated Dropout 率与基线方法效果相当。在某些情况下,甚至可以提高模型的最终准确度。此外,Federated Dropout 率为 0.75 的情况下不同模型都可以获得较好的效果。这种 Dropout 率相当于丢弃掉 25% 的行和全连接层的权重矩阵的列(即,相当于减少了 43%),并减少相应百分比的滤波器。如果 Federated Dropout 率过高,往往会减慢模型的收敛速度,尽管较高的 Federated Dropout 率有时可能获得更高的准确度。
图 6. Federated Dropout 的结果,改变每层保留的神经元百分比。
最后,作者进行了实验以验证将两种策略(有损压缩和 Federated Dropout)与客户端到服务器压缩方案(节 1.1.1 中文献 [1] 提出的方法)相结合时本文方法的效果。作者评估了本文模型在 3 种不同的压缩方案(Aggressive、Moderate 和 Conservative)和 4 种不同的 Federated Dropout 率(0.5、0.625、0.75 和 0.875)的表现。图 7 给出了 CIFAR-10 和 EMNIST 中的实验结果。重复每个实验 5 次,最终报告 5 次重复实验结果的平均值。对于所有三个模型,除了 Aggressive 压缩方案外,Federated Dropout 率为 0.75 时在所有压缩方案下都不会造成模型准确度下降。对于 MNIST 和 EMNIST,中央服务器到客户端的通信成本节省了 14 倍,客户端到中央服务器的通信成本节省了 28 倍,本地计算量减少了 1.7 倍,所有这些通信开销成本的节约都不会降低最终全局模型的准确性(有时甚至还能够提高准确性)。
图 7. 压缩和 Federated Dropout 对 CIFAR-10 和 EMNIST 的影响。
1.1.3 Count Sketch 压缩处理 [3]
这篇文章中提出的 FetchSGD 使用 Count Sketch 来压缩模型更新,然后利用 Sketches 的可合并性来将不同客户端的模型更新进行合并。FetchSGD 设计中的一个关键问题是,由于 Count Sketch 是线性的,动量和误差累积都可以在 Count Sketch 中进行。这使得该方法能够将动量和误差累积从客户端转移到中央服务器中,从而在克服稀疏客户端参与挑战的同时,确保高压缩率和良好的收敛性。FetchSGD 的完整方法见图 8:(1)在客户端本地计算梯度,(2)将梯度的 sketches 发送到中央服务器中,中央服务器聚合梯度(3)sketches(4)动量和误差累积(5),(6)提取近似的 top-k 值,(7)中央服务器将稀疏值更新到参与下一轮训练的客户端设备中。
图 8. FetchSGD 完整方法图示。
在 FetchSGD 的每次迭代中,第 i 个参与的客户端使用部分(或全部)本地数据计算随机梯度 (g_i)^t,然后使用称为 Count Sketch 的数据结构压缩(g_i)^t。每个客户端将 Sketch S((g_i)^t) 作为其模型更新发送到聚合器(中央服务器)。Count Sketch 是一种随机的数据结构,它可以通过将向量多次随机投影到低维空间来压缩向量,以便后续近似地恢复高幅值(High-Magnitude Elements)数据。如下式:
由于 Count Sketch 具备线性特性,中央服务器可以在给定 S((g_i)^t)的情况下,计算出每个小批量梯度 g^t 的 sketch:
此外,Count Sketch 的另一个有用特性是,对于 sketching 操作符 S(),有一个对应的解压操作符 U()返回原始向量的无偏估计,从而实现对向量高幅值元素的近似:
相比之下,其他有偏梯度压缩方法在压缩梯度时会给客户端带来偏差,因此客户端本身必须保持单独的误差累积向量。这在联邦学习中是很难保证的,这是由于客户端仅能参与一次更新,这样就没有机会在下一轮中重新引入错误。从另一个角度看,由于 S()是线性的,并且误差累积只包含线性操作,因此在 S_e 的中央服务器上进行误差累积相当于在每个客户端上进行误差累积,并将结果 Sketch 上传到中央服务器。更进一步,我们注意到动量也只包含线性操作,因此动量可以等价地在客户端或中央服务器上执行。推广上述方程可以得到:
完整的 FetchSGD 计算流程如下:
本文实验主要在小型本地数据集和 Non-IID 数据上完成,因为作者认为这是联邦学习中一个重要且相对未解决的问题。经典的梯度稀疏化方法是将每个客户端的局部 top-k 梯度元素聚合在一起训练全局模型的,当各个客户端的本地数据集非常小或相互之间差异非常大时,这种方法在近似全局梯度的真正 top-k 梯度元素时的表现就会非常差。在这种情况下,与经典方法相比 FetchSGD 有一个关键的优势:FetchSGD 的压缩算子是线性的。在 FetchSGD 中「只使用具有 N 个数据的单个客户端执行一个步骤相当于使用 N 个客户端执行一个步骤」,因此,每个客户端只贡献一次数据,小型客户端本地数据集不会带来任何问题。独立同分布的问题也可以通过随机选择客户端得到解决,FetchSGD 将参与训练的各个客户端的数据聚合起来,因此可以得到更完整的数据分布样本。
实验中,作者以上传和下载的总字节数表征相对于未压缩的 SGD 所实现的压缩效果。这些数据中没有考虑到的一个重要因素是,在 FedAvg 中,客户端必须在参与之前立即下载一个完整的模型,因为每个模型的权重在每一轮中都会得到更新。相比之下,局部 Top-k 和 FetchSGD 每轮只更新有限数量的参数,因此未参与的客户端可以相对更新当前模型,从而减少了必须在参与之前立即下载的新参数的数量。这使得本地 Top-k 和 FetchSGD 的上传压缩比下载压缩更重要。下载压缩对于这三种方法(FedAvg、Top-k、FetchSGD)也不那么重要,因为目前边缘设备的互联网连接的下载速度往往远远高于上传速度。作者给出了整体压缩(包括上传和下载)的结果,在图 9 中把这些图分成单独的上传和下载部分进行展示。FetchSGD 在不同的数据集和上传、下载和整体压缩的任务中,表现都较优。
图 9. CIFAR10(左)和 CIFAR100(右)的上传(顶部)、下载(中间)和整体(底部)压缩效果比对。为了提高可读性,每个图只显示该图中显示的压缩类型的运行的帕
累托边界。
1.1.4 周期平均和量化的压缩处理方法 [6]
这篇文章介绍了一种周期平均和量化处理的联邦学习方法(a Communication-efficient Federated learning algorithm with Periodic Averaging and Quantization, FedPAQ)。其中的量化处理也可以看作是一种压缩方式。
在经典联邦学习框架中,为了利用客户端节点上所有可用的数据样本,参与训练的客户端在每次训练迭代中通过中央服务器同步其模型,因此,客户端和中央服务器之间要进行多次通信,从而导致网络上的通信争用造成较大通信开销。本文提出的周期平均和量化处理方法令客户端进行本地更新并定期通过中央服务器进行同步。一旦某客户端节点从中央服务器中获取到更新的模型,客户端节点每隔 τ 次本地 SGD 就向中央服务器发送更新信息以更新聚合模型。这种周期平均方案减少了中央服务器和客户端之间的通信次数,从而降低了训练模型的总体通信成本。如果在客户端中运行 T 次 SGD 迭代,则客户端需要与中央服务器进行 K=T/τ 轮通信,从而将总通信成本降低为原成本的 1/τ。这就是「周期平均」的处理思路。
从 K=T/τ 的计算公式可以直观看出,选择较大的 τ 值可以减少固定迭代次数 T 的通信次数,进而降低通信成本。但是这种降低是以牺牲模型准确度为代价的。增大 τ,会增加系统的噪声,进而客户端中的局部模型会逐渐收敛到局部最优解,而不是全局最优解。因此,作者考虑运行更多次迭代 T 来使模型达到特定的准确度。事实上,我们需要解决的一个关键问题是找到最优 τ,以使整个过程通信成本最小化。
在联邦学习网络中,通常有大量的设备(如智能电话)与中央服务器(基站)进行通信。但是,基站本身的下载带宽有限,因此只有少数设备能够同时将其消息上载到基站。由于这一限制,从客户端设备发送的消息将在中央服务器基站中进行流水线式的串行传输,这导致训练速度大大减慢。另一方面,如果让所有的客户端设备都参与到整个训练过程中,将会造成巨大的、昂贵的网络通信开销。此外,在实际应用中,并不是所有的客户端都在每一轮训练的过程中发挥作用的。有很多因素决定客户端是否参与当前的训练过程:设备需在当前状态下处于中央服务器基站通信可达的范围内、客户端设备在当前状态下为空闲可用状态、客户端设备通电且联网等等。
考虑到上述因素,FedPAQ 假设,在总共 n 个客户端设备中每轮训练中只有 r 个节点(r ≤ n)可用,且这 r 个可用设备在网络上随机且均匀地分布。在第 k 个训练周期内,中央服务器将其当前模型 x_k 发送给本轮选定的参与训练的 S_k 个客户端子集中的 r 个客户端节点,r 个客户端节点在该子集的总共 n 个客户端节点之间随机均匀分布。另一方面,联邦学习中的设备上行链路带宽有限,这使得从客户端到中央服务器的通信缓慢且昂贵,这也是前面几节中各种压缩方法所考虑的主要问题。本文所提出的方法是在传输信息中使用量化算子,通过交换量化更新来降低网络通信开销。
在进行 τ 轮本地 SGD 更新后,每个客户端 i∈S_k 中拥有本地模型 (x_k,τ)^i,其中 x_k 为最新从中央服务器中获取的全局模型。每个客户端对 x_k 和(x_k,τ)^i 的差值应用量化算子 Q(),并将量化结果 Q((x_k,τ)^i-x_k) 发送至中央服务器。中央服务器接收到量化结果后进行反量化解码处理,并基于处理结果生成新的全局模型 x_k+1。本文使用的量化算子为:
完整的 FedPAQ 方法如下:
总的来说,FedPAQ 通过使用三个模块来降低通信负载:周期平均、部分客户端参与和量化处理。然而,这种通信减少带来了收敛准确度降低的问题,因此需要更多次的训练迭代。作者在原文中还进行了 FedPAQ 的收敛性分析,并给出了 FedPAQ 强凸和非凸损失函数的近似最优理论证明。我们在这里不再详述。
作者在实验中对比了通信开销和收敛性的 tradeoff 的结果分析。实验以总的训练时间为代价目标,包括通讯时间和计算时间。首先,定义网络带宽(Bandwidth,BW),每轮的通讯时间为上传的总比特数除以 BW。每轮的总比特数计算为 r ・ | Q(p,s)|,其中 | Q(p,s)| 表示根据具有 s 级(s levels)的特定量化算子对 p 维向量进行量化编码所需的比特数。在本文实验中使用的量化算子中,假定它需要 pF 位来表示长度为 p 的未量化向量,其中 F 为 32 位。
然后,利用梯度计算时间的位移指数模型(Shifted-Exponential Model)确定计算时间。假设对于任何客户端,计算一个周期内的 τ 次迭代和批量大小为 B 的梯度需要确定位移 τ ・ B ・ scale^(-1) ,其中 shift 和 scale 分别是位移指数模型的位移和尺度参数,实验中 B 确定为 10。每轮的总计算时间就是 r 个贡献客户端节点中最大的本地计算时间。最后,计算通信 - 计算比为:
在图 10 中,前四个图展示了在 MNIST 数据集('0'和'8'位)上,T=100 次迭代的正则化逻辑回归问题的训练时间。联邦学习网络中共有 n=50 个客户端节点,每个节点加载 200 个样本。设置 C_comm/C_comp=100/1 来捕获通信瓶颈。第一列图中曲线显示了在中央服务器上每轮的训练时间与所获得的训练损失之间的关系。第二列图中曲线显示了参与更新的客户端数量 n 的影响。第三列论证了周期长度 τ 在通信 - 计算 tradeoff 中的作用。最后一列图比较了 FedPAQ 与其他两个基线方法 FedAvg 和 QSGD 的训练时间。后四个图为 CIFAR-10 数据库中神经网络的实验结果。具体图例与 MNIST 中结果相同。
图 10. 训练损失与训练时间:MNIST 的 Logistic 回归分析(上),CIFAR-10 的神经网络结果(下)。
1.2 其它处理方法
1.2.1 集成方法 [4]
针对联邦学习的通信开销问题,这篇文章提出了利用集成方法(Ensemble method)的思路。集成方法是机器学习中的一种通用技术,用于组合多个基本预测因子(Base predictors)或专家(Experts)来创建一个更精确的模型。作者认为,联邦学习中的通信开销问题是由每轮从中央服务器发送到客户端(下行)和从客户端发送到中央服务器(上行)的参数数量引起的。在每轮训练过程中,中央服务器将当前模型的迭代状态发送给全部参与的客户端,直接将集成方法应用于这种联邦学习框架中会由于每轮都需要传递预测值而导致通信爆炸。本文提出的 FedBoost 能够在降低通信成本的同时实现计算加速、收敛保证和隐私保护。这种方法可以通过联邦学习使用客户端设备上的数据来训练一个原本可能超过客户端的通信带宽和存储容量的模型。此外,FedBoost 能够同时降低中央服务器到客户端(下行)和客户端到中央服务器(上行)的通信成本。
集成方法通过联邦学习的框架在中央服务器端只需要学习混合权重,所需要经由客户端发送给中央服务器的数据量非常小。因此,作者认为在集成方法中客户端到中央服务器(上行)的通信成本可以忽略不计。本文提出了标准(Standard)和任务不可知(Agnostic)的联邦学习集成方法,以解决中央服务器到客户端(下行)的通信瓶颈问题。
首先介绍标准联邦学习集成方法。给定一组预先训练的基本预测因子或假设:H≜{h_1, ..., h_q}。在标准集成方法中,将全部的假设都发送给每个参与的客户端。然而,在实践中,由于中央服务器和客户端之间的通信带宽以及客户端的内存和计算能力的限制,这种处理方式是不可行的。作者提出了一种抽样方法,只将其中一部分假设发送给客户端。这虽然可以降低通信复杂度,但同时会带来整体梯度偏差的问题以及集成收敛性的不确定性问题。经典的集成方法为:
设 C 为每轮发送给客户端的最大基本预测因子数,即 C 能够表征通信效率。中央服务器端的目标函数是学习一组针对预先训练的基本估计量 h_k 的系数 α:
FedBoost 的完整算法流程如下:
在每轮训练中,FedBoost 在中央服务器上抽取两个子集:一个预训练假设子集,其中每一个子集以概率 γ_k,t 抽取得到(用 H_t 表示);N 个客户端的随机子集(用 S_t 表示)。定义以下 Bernoulli 指标:
其中,L_k(α) 为标准联邦学习中域 k 的经验损失。针对 L(α,λ)的优化问题为一个两人博弈问题,找到最小化目标函数和对手的 α,同时使用λ最大化目标函数。最终目标是找到给定 α_opt 的极小极大博弈的均衡,它使得对于混合权重λ_opt 的损失最小化。l 为凸函数,可以使用一般镜像下降(generic mirror descent)或其他基于梯度的算法来优化解决这个问题。
作者提出了任务无关的 AFLBoost 方法优化上述目标函数。AFLBoost 可以看作是 FedBoost 和针对任务无关损失函数的随机镜像下降算法 [7] 的结合。AFLBoost 的详细算法流程如下:
作者在实验中证明了 FedBoost 在不同通信成本下对密度估计(Density estimation)任务的有效性。具体包括三种方法:(1)无通信效率处理(无采样):γ_k,t=1。(2)均匀抽样:γ_k,t=C/q。(3)加权随机抽样:γ_k,t ∝ α_k,t·C。
作者首先创建了一个 p=100 的合成数据集,其中每个 h_k 是单个元素上的点质分布(Point-mass distribution),初始化每个 α_k 为 1/p,混合权重λ遵循幂律分布(Power law distribution)。实验结果见图 11。由图 11 左,加权抽样的性能优于均匀抽样,两种方法的损失都在稳步下降。由图 11 中,在通信预算为 64 的情况下,不考虑通信效率均匀抽样和加权抽样的性能与 FedBoost 相同。此外,作者还在经典 Shakespeare 数据集中进行实验。如图 11 右,加权随机抽样比均匀抽样的表现更好,在这个库中加权抽样的表现甚至优于 FedBoost,具有更好的收敛性能。
图 11. 实验对比图。左:合成数据集的损失曲线比较;中:合成数据集中取样方法的比较;右:Shakespeare 联邦学习数据集的损失曲线比较。
1.2.2 分布式不动点优化方法 [5]
针对联邦学习的通信开销问题,一些研究人员的解决思路是利用客户端的本地计算。也就是说,在通信和模型聚合处理之前,在每个客户端设备中进行更多的本地计算,从而减少获得全局有意义的解决方案所需的通信总轮数。在这种思路下,研究人员集中考虑了一些局部梯度下降算法以改进本地计算的效果。本文就是这种思路的工作之一。但是本文的工作并不局限于通过梯度下降来最小化目标函数,作者引入了计算 M 个操作算子平均值的不动点的方法。实际上,大多数迭代方法都属于不动点方法(Fix-Point method),其目的是寻找某个算子的不动点。
为了从不动点方法角度进行分析,首先介绍经典的分布式不动点优化模型。令分布式系统中共包括 M 个并行计算节点(客户端)。每个节点中处理的变量可看作是欧氏空间中的向量。令Τ_i 表示欧氏空间的操作算子,平均算子为:
本文方法的核心是找到Τ的不动点,即找到向量 x* 满足 Τ(x*)=x*。可以通过在每个节点重复应用Τ_i,同时进行平均化处理以达到共识来最终获得目标解 x*。本文作者考虑,经过多次迭代后,每个客户端节点将其变量同步传递到远程中央服务器中。然后中央服务器计算所接收到的向量的平均值并将其广播到所有节点。
本文考虑两种不动点优化策略:第一种策略是对于每个客户端计算节点,迭代执行若干次某个操作序列(称之为局部步骤,local steps)。第二种策略是减少通信步骤的数量,即仅以一定的较低概率到中央服务器中共享信息,并且只在中间过程进行局部计算。
将 T 定义为欧氏空间中的一个操作算子,令 Fix(T) 表示 T 的不动点集合,对于欧氏空间中的每个 x 和 y,如果 T 满足下式,则称 T 是χ-Lipschitz 连续的:
此外,如果 T 是 1-Lipschitz 连续和χ- 收缩的,则称 T 是非扩张的。如果 T 是收缩的,则它具有存在且唯一的不动点。对于任意 α∈(0,1],如果对于一些非扩张算子 T’ ,如果存在 T=αT’+Id,则称 T 是 α 平均的。如果 T 是 1/2 平均的,则称 T 是坚决不扩张的。
令(t_n)_n∈N 表示通信过程中的整数序列。在每次迭代过程中,操作算子Τ_i 应用于节点 i,并利用参数λ进行松弛。对于一定数量的迭代,M 个计算节点全部将其本地向量传输给中央服务器的主节点,主节点计算平均值并将平均值广播给全部节点。全部节点在新的一轮迭代开始时拥有相同的变量 (x^)^k。该算法是局部梯度下降法(Federated Averaging)的推广。
我们把一系列本地迭代称为一个 epoch,然后求取平均值。也就是说,第 n 个 epoch 是指数 k+1=t_(n-1)+1,...,t_n 的迭代序列。假设在两个聚合计算平均值步骤之间的每个 epoch 的迭代次数由某个整数 H 限定。则对于每个 n≥1,有 1≤t_n - t_(n-1) ≤H。具体,第一种策略的详细算法流程如下:
接下来,作者提出了第二种策略。第一种策略中的本地处理步骤可以看作是两个通信步骤之间的内环(Inner Loop),将内环用概率聚合(Probabilistic Aggregation)来代替,即可得到第二种策略。在以下意义上,它是通信 - 有效的:在第一种策略中,通信轮数除以 H(或是非均匀情况下 t_n - t_(n-1)的平均值),而在第二种策略中该值乘以概率 p≤1。第二种策略的详细算法流程如下:
在原文中,作者对两种策略算法的收敛性能进行了充分的论证,我们在这里不再详述。
作者选择经典分类问题的逻辑回归进行实验。相应的目标函数如下:
其中,a_i∈R^d,b_i∈{-1,+1} 为数据样本。使用 LIBSVM 库中的「a9a」和「a4a」数据集,并将 k 设为 L/n,其中 n 是数据集的大小,L 是ᐁf 第一部分的 Lipschitz 常数,且没有经过正则化处理。
本文实验中考虑梯度下降(Gradient Descent,GD)作为操作算子。也就是说,我们考虑最小化有限和问题:
使用第一种策略的算法 1 和使用第二种策略的算法 2 的实验结果见图 12 和图 13。参数 H 和λ的值越大,初始收敛速度越快,但邻域半径越大。就计算时间而言,该算法没有太大的优势,因为实验是在一台机器上进行的,通信时间可以忽略不计。但是在通信速度较慢的分布式环境中,本文的算法就有明显的优势。我们也可以观察到图中实验结果曲线没有出现振荡。因此,当只需要达到有限的准确度时,本文提出的本地方法具有明显的优势。
图 12. 具有梯度下降步长的算法 1 在均匀通信时间 t_n=nH 下的收敛性,(a)不同 H 值的通讯轮数,λ=0.5,(b)不同 H 值的计算时间,λ=0.5,(c)不同λ值的计算时间,H=4。
图 13. 梯度下降步长算法 2 的收敛性,λ=0.5,(a)梯度步长相同情况下,不同 p 值的通讯轮次数量,(b)梯度步长相同情况下,不同 p 值的计算时间,(c)梯度步长与 p 是成比例的,不同 p 值的通讯轮次数量。
2. 总结
我们在这篇文章重点关注了联邦学习框架中的通信开销研究进展。目前,大多数文章都从压缩的角度出发解决通信开销问题,这种方法的思路很直观:压缩后需要上行、下行传递的数据量就会减小,从而减轻通信开销。当然,压缩的方法有很多,例如有损压缩、提取 sketch、量化等等。此外,我们也分析了两篇非压缩思路的文章,作者分别使用了集成方法和加强本地计算的方法。在不同的文章中,作者对比和分析的实验指标各不相同,这说明目前还没有标准化、统一化、权威性的衡量联邦学习通信开销的指标,毕竟通信开销和计算效率是一对 tradeoff 的指标。单纯用通信时间或通信数据量去衡量方法的优劣并不客观。
目前,随着 5G 技术的发展,5G 网络中通信速率问题变得不再是问题。依托 5G,使用边缘设备的应用场景也越来越多,例如校园安全监控、明厨亮灶监控、移动执法等等。在这种情况下,是否能够缓解联邦学习中的通信开销问题,进而推动联邦学习更快的发展和应用?让我们拭目以待吧!
参考文献
[1] Jakub Koneˇcný, H Brendan McMahan, Felix X Yu, Peter Richtárik, Ananda Theertha Suresh, and Dave Bacon.Federated learning: Strategies for improving communication efficiency. arXiv preprint arXiv:1610.05492, 2016b. https://arxiv.org/pdf/1610.05492.pdf.[2] Caldas, S., Koneˇcny, J., McMahan, H. B., and Talwalkar, A. Expanding the reach of federated learning by reducingclient resource requirements. arXiv preprint arXiv:1812.07210, 2018.[3] Rothchild, D. , Panda, A. , Ullah, E. , Ivkin, N. , Stoica, I. , & Braverman, V. , et al. (2020). FetchSGD: communication-efficient federated learning with sketching, ICML 2020, http://arxiv.org/abs/2007.07682v1.[4] Jenny Hamer, Mehryar Mohri , Ananda Theertha Suresh, FedBoost: A Communication-Efficient Algorithm for Federated Learning,ICML 2020.[5] Malinovsky, G. , Kovalev, D. , Gasanov, E. , Condat, L. , & Richtarik, P. . (2020). From local sgd to local fixed point methods for federated learning, ICML 2020, https://arxiv.org/abs/2004.01442.[6] Reisizadeh A , Mokhtari A , Hassani H , et al. FedPAQ: A Communication-Efficient Federated Learning Method with Periodic Averaging and Quantization[J]. arXiv, 2019.https://arxiv.org/pdf/1909.13014.pdf.[7] Mohri, M., Sivek, G., and Suresh, A. T. Agnostic federated learning. In International Conference on Machine Learning, pp. 4615–4625, 2019.