KDD 2019 | 不用反向传播就能训练DL模型,ADMM效果可超梯度下降

简介: 随机梯度下降 (SGD) 是深度学习的标准算法,但是它存在着梯度消失和病态条件等问题。本文探索与反向传播(BP)完全不同的方向来优化深度学习模型,即非梯度优化算法,提出了「反向前向的交替方向乘子法」的深度模型优化算法,即 dlADMM。该方法解决了随机梯度下降存在的问题,在多个标准数据集上达到并超过梯度下降算法的效果,并且第一次给出了全局收敛的数学证明。同时增强了算法的可扩展性,为解决一些当前重要的瓶颈问题提供了全新视角,比如复杂不可导问题以及非常深的神经网络的高性能计算问题。目前,该论文已被数据挖掘领域顶会 KDD 2019 接收。

 论文:ADMM for Efficient Deep Learning with Global Convergence

微信图片_20211201210838.jpg


本文提出了一种基于交替方向乘子法的深度学习优化算法 dlADMM。该方法可以避免随机梯度下降算法的梯度消失和病态条件等问题,弥补了此前工作的不足。此外,该研究提出了先后向再前向的迭代次序加快了算法的收敛速度,并且对于大部分子问题采用二次近似的方式进行求解,避免了矩阵求逆的耗时操作。在基准数据集的实验结果表明,dlADMM 击败了大部分现有的优化算法,进一步证明了它的有效性和高效。


背景


深度学习已经在机器学习的各个领域受到广泛的应用,因为 深度学习模型可以表征非线性特征的多层嵌套组合,所以相比传统的机器学习模型,它的表达性更丰富。由于深度学习通常用在大数据的应用场景中,所以需要一种优化算法可以在有限的时间内得到一个可用的解。随机梯度下降算法 (SGD) 和它的许多变体 如 ADAM 是深度学习领域广泛使用的优化算法,但是它存在着如梯度消失 (gradient vanishing) 和病态条件 (poor conditioning) 等问题;另一方面, 作为近年非常热门的优化框架,交替方向乘子算法 (ADMM) 可以解决 SGD 存在的问题: ADMM 的基本原理是把一个复杂的复合目标函数分解成若干个简单的子函数求解,这样不需要用链式法则求复合函数的导数,从而避免了梯度消失的问题,另外 ADMM 对输入不敏感,所以不存在病态条件的问题 [1]。除此之外 ADMM 还有诸多优点:它 可以解决非光滑函数的优化问题;它在很多大规模的深度学习应用中展现了强大的可扩展性 (scalability) 等等。


经典反向传播算法 (BP)


一个典型的神经网络问题如下所示:

微信图片_20211201210834.jpg


其中 W_l 和 b_l 是第 l 层的权重和截距, f_l (∙) 和 R(∙) 分别是激活函数和损失函数, L 是层数。经典的反向传播算法分前向传送梯度和后向更新参数两部分:每层的梯度按照链式法则向前传输,然后根据损失函数反向参数更新如下:

微信图片_20211201211115.jpg


尽管反向传播非常实用,然而它存在一些问题,因为前面层级的梯度必须等后面层级的梯度算完,BP 的速度比较慢;因为 BP 需要保存前向传播的激活值,所以显存占用比较高;最常见的问题就是对于深度神经网络存在梯度消失, 这是因为根据链式法则微信图片_20211201211027.jpg

如果微信图片_20211201211030.jpg, 那么随着层数的增加,梯度信号发生衰减直至消失。连提出 BP 的 Geoffrey Hinton 都对它充满了质疑。主要是因为反向传播机制实在是不像大脑。Hinton 在 2017 年的时候提出胶囊理论尝试替代标准神经网络。但是只要反向传播的机制不改变,梯度消失的问题就不会解决。因此目前有许多研究者尝试采用各种机制训练神经网络,早在 2016 年,Gavin Taylor[1] 等人就提出了 ADMM 的替代想法,ADMM 的原理如图 1 所示,一个神经网络按照不同的层被分解成若干个子问题,每个子问题可以并行求解,这样不需要求复合目标函数 R(∙) 的导数,解决了梯度下降的问题。按照 ADMM 的思路,神经网络问题被等价转换为如下问题 1:


问题 1:

微信图片_20211201210829.jpg


其中 a_l 是辅助变量。 

微信图片_20211201210825.jpg

图 1. 基于 ADMM 的深度神经网络子问题分解示意图


挑战


尽管 ADMM 具备很多优点,但是把它应用在深度学习的问题中的效果较当前最优算法入 SGD 和 ADAM 还有很大差距,很多技术和理论问题仍亟待解决:1) 收敛慢。即使最简单的目标函数,通常 ADMM 需要很多次迭代才能达到最优解。2) 对于特征维度的三次时间复杂度。在 Taylor 等人的实验中,他们使用了超过 9000 核 CPU 来让 ADMM 训练了仅仅 300 个神经元 [1]。其中 ADMM 最耗时的地方在于求解逆矩阵, 它的时间复杂度大概在 O(n^3 ), 其中 n 是神经元的个数。3) 缺乏收敛保证。尽管很多实验证明了 ADMM 在深度学习中是收敛的,然而它的理论收敛行为依然未知。主要原因是因为神经网络是线性和非线性映射和组合体,因而是高度非凸优化问题。基于这些问题,最新的一期 KDD2019 论文提出了 ADMM 的改进版本 dlADMM,第一次使基于 ADMM 的算法在多个标准数据集上达到当前最佳效果,并且在收敛性理论证明得到重要突破。


dlADMM 相比 ADMM 的优势:


  1. 加快收敛。文章提出了一种新的迭代方式加强了训练参数的信息交换,从而加快了 dlADMM 的收敛过程。
  2. 加快运行速度。作者通过二次近似的技术避免了求解逆矩阵,把时间复杂度从 O(n^3 ) 降低到 O(n^2 ),即与梯度下降相同的复杂度。从而大幅提高 ADMM 的运行速度。
  3. 具备收敛保证。本文第一次证明了 dlADMM 可以全局收敛到问题的一个驻点(该点导数为 0)。


下面具体讨论提出算法的这些优势:
1). 加快收敛
直接用 ADMM 解问题 1 并不能保证收敛,所以作者把问题 1 放松成如下的问题 2。在问题 2 中,当ν→+∞ 时,问题 2 无线逼近 问题 1。


问题 2:



在问题 2 中,ν>0 是一个参数, z_l 是一个辅助变量。在解问题 2 的过程中,通过增大ν可以使其理论上无限逼近问题 1。
问题 2 的增广拉格朗日 [2] 形式如下:

微信图片_20211201210822.jpg

其中微信图片_20211201210819.jpg是 dlADMM 中的一个超参数。

为了加快收敛过程,作者提出了一种新的迭代方式:先后向更新再前向更新,如图 2 所示。具体来讲,参数从最后一层开始更新,然后向前更新直到第一层,接着参数从第一层开始向后更新直到最后最后一层。这样更新的好处在于最后一层的参数信息可以层层传递到第一层,而第一层的参数信息可以层层传递到最后一层,加强参数信息交换,从而帮助参数更快地收敛。 

微信图片_20211201210816.jpg

图 2. dlADMM 原理图
2). 加快运行速度
对于求解 dlADMM 产生的子问题,大部分都需要耗时的矩阵求逆操作。为此,作者使用了二次近似的技术,如图 3 所示。在每一次迭代的时候对目标函数做二次近似函数展开,由于变量的二次项是一个常数,因此不需要求解逆矩阵,从而提高了算法的运行效率。 

微信图片_20211201210804.jpg

图 3. 二次近似


3.) 收敛保证
作者证明了无论参数 (W,b,z,a) 如何初始化,当ρ 足够大的时候,dlADMM 全局收敛于问题 2 的一个驻点上。
具体来说,是基于如下两条假设:
a. 求解 z 的子问题存在显式解。b. F 是强制的 (coercive),R(∙) 是莱布尼茨可导 (Lipschitz differentiable)。
对于假设 a, 常用的激活函数如 Relu 和 Leaky Relu 满足条件;对于假设 b,常用的交叉熵和最小二乘损失函数都满足条件。
在此基础之上, 本文证明了三条收敛性质:

  1. image.gif是有界的,L_ρ是有下界的。
  2. L_ρ是单调下降的。
  3. L_ρ的次梯度趋向于 0。


同时文章也证明了 dlADMM 的收敛率是 o(1/k).


实验结果


该论文在两个基准数据集 MNIST 和 Fashion MNIST 上进行了实验。


1. 收敛验证


作者画出了当ρ=1 和 ρ=〖10〗^(-6) 的收敛曲线,验证了当 ρ 足够大的时候,dlADMM 是收敛的(Figure 2),反之,dlADMM 不停地振荡(Figure 3)。  

微信图片_20211201210725.jpg微信图片_20211201210722.jpg


2. 效果比较


作者把 dlADMM 和当前公认的算法进行了比较。比较的方法包括:a. 随机梯度下降 (SGD). b. 自适应性梯度算法 (Adagrad). c. 自适应性学习率算法 (Adadelta).d. 自适应动量估计 (Adam). e. 交替方向乘子算法 (ADMM).Figure 4 和 Figure 5 展示了所有算法在 MNIST 和 Fashion MNIST 的训练集和测试集的正确率,可以看到开始的时候 dlADMM 上升最快,并且在二十次迭代之内迅速达到非常高的精度并且击败所有算法。在 80 次迭代之后虽被 ADAM 反超,但是仍然优于其他算法。文中提出的改进版 dlADMM 显著提高了 ADMM 的表现,使之充分发挥 ADMM 在迭代初期进展迅速的优点,同时在后期保证较高收敛率。 

微信图片_20211201210718.jpg

微信图片_20211201210712.jpg

image.gif


3. 时间分析


文章作者分析了 dlADMM 的运行时间和数据集数量,神经元个数以及和 ρ的选取上面的关系。如 table 2 和 table 3 所示,当数据集数量, 神经元个数和ρ的值越大的时候, dlADMM 的运行时间越长。具体来说, 运行时间和神经元个数和数据集数量成线性关系。这个结果体现了 dlADMM 强大的可扩展性 ( scalability)。 

微信图片_20211201210710.jpg微信图片_20211201210648.jpg



本文提出的 dlADMM 代码已经公开, 欢迎使用。链接如下: https://github.com/xianggebenben/dlADMM。欢迎邮件联系 jwang40@gmu.edu 或者 lzhao9@gmu.edu。


文献:

[1]. Taylor, Gavin, et al. "Training neural networks without gradients: A scalable admm approach". International conference on machine learning. 2016.[2]. Boyd, Stephen, et al. "Distributed optimization and statistical learning via the alternating direction method of multipliers." Foundations and Trends® in Machine learning 3.1 (2011): 1-122.


相关文章
|
2月前
|
机器学习/深度学习
小土堆-pytorch-神经网络-损失函数与反向传播_笔记
在使用损失函数时,关键在于匹配输入和输出形状。例如,在L1Loss中,输入形状中的N代表批量大小。以下是具体示例:对于相同形状的输入和目标张量,L1Loss默认计算差值并求平均;此外,均方误差(MSE)也是常用损失函数。实战中,损失函数用于计算模型输出与真实标签间的差距,并通过反向传播更新模型参数。
|
5月前
|
机器学习/深度学习 算法
BP反向传播神经网络的公式推导
BP反向传播神经网络的公式推导
35 1
|
机器学习/深度学习 存储 算法
【WOA-LSTM】基于WOA优化 LSTM神经网络预测研究(Matlab代码实现)
【WOA-LSTM】基于WOA优化 LSTM神经网络预测研究(Matlab代码实现)
159 0
|
机器学习/深度学习 人工智能 算法
【Pytorch神经网络理论篇】 24 神经网络中散度的应用:F散度+f-GAN的实现+互信息神经估计+GAN模型训练技巧
MINE方法中主要使用了两种技术:互信息转为神经网络模型技术和使用对偶KL散度计算损失技术。最有价值的是这两种技术的思想,利用互信息转为神经网络模型技术,可应用到更多的提示结构中,同时损失函数也可以根据具体的任务而使用不同的分布度量算法。
476 0
|
机器学习/深度学习 Python
神经网络中的损失函数正则化和 Dropout 并手写代码实现
神经网络中的损失函数正则化和 Dropout 并手写代码实现
185 0
神经网络中的损失函数正则化和 Dropout 并手写代码实现
|
机器学习/深度学习 人工智能 算法
【Pytorch神经网络理论篇】 23 对抗神经网络:概述流程 + WGAN模型 + WGAN-gp模型 + 条件GAN + WGAN-div + W散度
GAN的原理与条件变分自编码神经网络的原理一样。这种做法可以理解为给GAN增加一个条件,让网络学习图片分布时加入标签因素,这样可以按照标签的数值来生成指定的图片。
626 0
|
机器学习/深度学习 存储 人工智能
【Pytorch神经网络基础理论篇】 08 Softmax 回归 + 损失函数 + 图片分类数据集
【Pytorch神经网络基础理论篇】 08 Softmax 回归 + 损失函数 + 图片分类数据集
255 0
|
机器学习/深度学习 人工智能 PyTorch
【Pytorch神经网络理论篇】 08 Softmax函数(处理分类问题)
oftmax函数本质也为激活函数,主要用于多分类问题,且要求分类互斥,分类器最后的输出单元需要Softmax 函数进行数值处理。
312 0
|
机器学习/深度学习 传感器 资源调度
【FNN回归预测】基于Jaya优化前馈神经网络FNN实现数据回归预测附Matlab代码
【FNN回归预测】基于Jaya优化前馈神经网络FNN实现数据回归预测附Matlab代码
|
机器学习/深度学习 传感器 人工智能
【FNN预测】基于Jaya优化JAYA前馈神经网络FNN研究附Matlab代码
【FNN预测】基于Jaya优化JAYA前馈神经网络FNN研究附Matlab代码
下一篇
无影云桌面