提升神经网络架构搜索稳定性,UCLA提出新型NAS算法

本文涉及的产品
文件存储 NAS,50GB 3个月
简介: 可微网络架构搜索能够大幅缩短搜索时间,但是稳定性不足。为此,UCLA 基于随机平滑(random smoothing)和对抗训练(adversarial training),提出新型 NAS 算法。


可微网络架构搜索(DARTS)能够大幅缩短搜索时间,但是其稳定性受到质疑。随着搜索进行,DARTS 生成的网络架构性能会逐渐变差。最终生成的结构甚至全是跳过连接(skip connection),没有任何卷积操作。在 ICML 2020 中,UCLA 基于随机平滑(random smoothing)和对抗训练(adversarial training),提出了两种正则化方法,大幅提升了可微架构搜索算法的鲁棒性。

微信图片_20211204091124.jpg



近期,可微架构搜索算法将 NAS 搜索时间缩短至数天,因而备受关注。然而,其稳定生成高性能神经网络的能力受到广泛质疑。许多研究者发现随着搜索进行,DARTS 生成的网络架构反而越来越差,最终甚至会完全变为跳过连接(skip connection)。为了支持梯度下降,DARTS 对于搜索空间做了连续化近似,并始终在优化一组连续可微的框架权重 A。但是在生成最终框架时,需要将这个权重离散化。

本研究作者观察到这组连续框架权重 A 在验证集上的损失函数非常不平滑,DARTS 总是会收敛到一个非常尖锐的区域。因此对于 A 轻微的扰动都会让验证集性能大幅下降,更不用说最终的离散化过程了。这样尖锐的损失函数还会损害搜索算法在架构空间中的探索能力。

于是,本文作者提出了新型 NAS 框架 SmoothDARTS(SDARTS),使得 A 在验证集上的损失函数变得十分平滑。

该工作的主要贡献包括:

  • 提出 SDARTS,大幅提升了可微架构搜索算法的鲁棒性和泛化性。SDARTS 在搜索时优化 A 整个邻域的网络权重,而不仅仅像传统可微 NAS 那样只基于当前这一组参数。第一种方法优化邻域内损失函数的期望,没有提升搜索时间却非常有效。第二种方法基于整个邻域内的最差损失函数(worst-case loss),取得了更强的稳定性和搜索性能。
  • 在数学上,尖锐的损失函数意味着其 Hessian 矩阵范数非常大。作者发现随着搜索进行,这一范数极速扩大,导致了 DARTS 的不稳定性。而本文提出的两种框架都有数学保障可以一直降低 Hessian 范数,这也在理论上解释了其有效性。
  • 最后,本文提出的方法可以广泛应用于各种可微架构算法。在各种数据集和搜索空间上,作者发现 SDARTS 可以一贯地取得性能提升。


微信图片_20211204091130.jpg


具体方法

传统 DARTS 使用一组连续的框架权重 A,但是 A 最终却要被投射到离散空间以获得最终架构。这一步离散化会导致网络性能大幅下降,一个高性能的连续框架并不意味着能生成一个高性能的离散框架。因此,尽管 DARTS 可以始终减少连续框架在验证集上的损失函数,投射后的损失函数通常非常不稳定,甚至会突变得非常大。

因此作者希望最终获得的连续框架在大幅扰动,例如离散化的情况下,仍然能保持高性能。这也意味了损失函数需要尽可能平滑,并保持很小的 Hessian 范数。因此本文提出在搜索过程中即对 A 进行扰动,这便会让搜索算法关注在平滑区域。

微信图片_20211204091139.jpg


SDARTS-RS 基本随机平滑(random smoothing),优化 A 邻域内损失函数的期望。该研究在均匀分布中采样了随机噪声,并在对网络权重 w 进行优化前加到连续框架权重 A 之上。

这一方法非常简单,只增加了一行代码并且不增加计算量,可作者发现其有效地平滑了在验证集上的损失函数。

SDARTS-ADV 基于对抗训练(adversarial training),优化邻域内最差的损失函数,这一方法希望最终搜索到连续框架权重 A 可以抵御最强的攻击,包括生成最终架构的离散化过程。在这里,我们使用 PGD (projected gradient descent)迭代获得当前最强扰动。

微信图片_20211204091152.jpg


整个优化过程遵循可微 NAS 的通用范式,交替优化框架权重 A 和网络权重 w。

微信图片_20211204091154.jpg


理论分析

微信图片_20211204091210.jpg


对 SDARTS-RS 的目标函数进行泰勒展开,作者发现这在搜索过程中,Hessian 矩阵的 trace norm 也在被一直减小。如果 Hessian 矩阵近似 PSD,那么近似于一直在减小 Hessian 的正特征值。相似地,在通常的范数选择下(2 范数和无穷范数),SDARTS-ADV 目标函数中第二项近似于被 Hessian 范数 bound 住。因此它也可以随着搜索降低范数。

微信图片_20211204091221.jpg


这些理论分析进一步解释了为何 SDARTS 可以获得平滑的损失函数,在扰动下保持鲁棒性与泛化性。

实验结果

NAS-Benchmark-1Shot1 实验
这个 benchmark 含有 3 个不同大小的搜索空间,并且可以直接获得架构的性能,不需要任何训练过程。这也使本文可以跟踪搜索算法任意时刻得到架构的精确度,并比较他们的稳定性。
如图 4 所示,DARTS 随着搜索进行生成的框架不断变差,甚至在最后的性能直接突变得很差。近期提出的一些新的改进算法,例如 NASP 与 PC-DARTS 也难以始终保持高稳定性。与之相比,SDARTS-RS 与 SDARTS-ADV 大幅提升了搜索稳定性。得益于平滑的损失函数,该研究提出的两种方法还具有更强的探索能力,甚至在搜索迭代了 80 轮之后仍能持续发现精度更高的架构。
另外,作者还在图 5 中跟踪了 Hessian 范数的变化情况,所有 baseline 方法的范数都扩大了 10 倍之多,而本文提出的方法一直在降低该范数,这与上文的理论分析一致。

微信图片_20211204091227.jpg


CIFAR-10 实验
作者在通用的基于 cell 的空间上进行搜索,这里需要对获得架构进行 retrain 以获得其精度。值得注意的是,除了 DARTS,本文提出的方法可以普遍适用于可微 NAS 下的许多方法,例如 PC-DARTS 和 P-DARTS。如表 1 所示,作者将原本 DARTS 的 test error 从 3.00% 减少至 2.61%,将 PC-DARTS 从 2.57% 减少至 2.49%,将 P-DARTS 从 2.50% 减少至 2.48%。搜索结果的方差也由于稳定性的提升而减小。

微信图片_20211204091234.jpg


ImageNet 实验
为了测试在大数据集上的性能,作者将搜索的架构迁移到 ImageNet 上。在表 2 中,作者获得了 24.2% 的 top1 test error,超过了所有相比较的方法。

微信图片_20211204091239.jpg


与其他正则项方法比较

作者还在另外 4 个搜索空间 S1-S4 和 3 个数据集上做实验。这四个空间与 CIFAR-10 上的搜索空间类似,只是包含了更少的操作,例如 S2 只包含 3x3 卷积和跳过连接,S4 只包括 3x3 卷积和噪声。在这些简化的空间上实验能进一步验证 SDARTS 的有效性。

微信图片_20211204091244.jpg


如表 4 所示,SDARTS 在这 12 个任务中的 9 个中包揽了前两名,SDARTS-ADV 分别平均超过 DARTS、R-DARTS (L2)、DARTS-ES、R-DARTS (DP) 和 PC-DARTS 31.1%、11.5%、11.4%、10.9% 和 5.3%。

相关实践学习
基于ECS和NAS搭建个人网盘
本场景主要介绍如何基于ECS和NAS快速搭建个人网盘。
阿里云文件存储 NAS 使用教程
阿里云文件存储(Network Attached Storage,简称NAS)是面向阿里云ECS实例、HPC和Docker的文件存储服务,提供标准的文件访问协议,用户无需对现有应用做任何修改,即可使用具备无限容量及性能扩展、单一命名空间、多共享、高可靠和高可用等特性的分布式文件系统。 产品详情:https://www.aliyun.com/product/nas
相关文章
|
2月前
|
机器学习/深度学习 算法 TensorFlow
动物识别系统Python+卷积神经网络算法+TensorFlow+人工智能+图像识别+计算机毕业设计项目
动物识别系统。本项目以Python作为主要编程语言,并基于TensorFlow搭建ResNet50卷积神经网络算法模型,通过收集4种常见的动物图像数据集(猫、狗、鸡、马)然后进行模型训练,得到一个识别精度较高的模型文件,然后保存为本地格式的H5格式文件。再基于Django开发Web网页端操作界面,实现用户上传一张动物图片,识别其名称。
89 1
动物识别系统Python+卷积神经网络算法+TensorFlow+人工智能+图像识别+计算机毕业设计项目
|
2月前
|
机器学习/深度学习 人工智能 算法
深度学习入门:理解神经网络与反向传播算法
【9月更文挑战第20天】本文将深入浅出地介绍深度学习中的基石—神经网络,以及背后的魔法—反向传播算法。我们将通过直观的例子和简单的数学公式,带你领略这一技术的魅力。无论你是编程新手,还是有一定基础的开发者,这篇文章都将为你打开深度学习的大门,让你对神经网络的工作原理有一个清晰的认识。
|
14天前
|
机器学习/深度学习 人工智能 算法
【车辆车型识别】Python+卷积神经网络算法+深度学习+人工智能+TensorFlow+算法模型
车辆车型识别,使用Python作为主要编程语言,通过收集多种车辆车型图像数据集,然后基于TensorFlow搭建卷积网络算法模型,并对数据集进行训练,最后得到一个识别精度较高的模型文件。再基于Django搭建web网页端操作界面,实现用户上传一张车辆图片识别其类型。
55 0
【车辆车型识别】Python+卷积神经网络算法+深度学习+人工智能+TensorFlow+算法模型
|
2月前
|
机器学习/深度学习 人工智能 算法
植物病害识别系统Python+卷积神经网络算法+图像识别+人工智能项目+深度学习项目+计算机课设项目+Django网页界面
植物病害识别系统。本系统使用Python作为主要编程语言,通过收集水稻常见的四种叶片病害图片('细菌性叶枯病', '稻瘟病', '褐斑病', '稻瘟条纹病毒病')作为后面模型训练用到的数据集。然后使用TensorFlow搭建卷积神经网络算法模型,并进行多轮迭代训练,最后得到一个识别精度较高的算法模型,然后将其保存为h5格式的本地模型文件。再使用Django搭建Web网页平台操作界面,实现用户上传一张测试图片识别其名称。
116 22
植物病害识别系统Python+卷积神经网络算法+图像识别+人工智能项目+深度学习项目+计算机课设项目+Django网页界面
|
2月前
|
机器学习/深度学习 人工智能 算法
鸟类识别系统Python+卷积神经网络算法+深度学习+人工智能+TensorFlow+ResNet50算法模型+图像识别
鸟类识别系统。本系统采用Python作为主要开发语言,通过使用加利福利亚大学开源的200种鸟类图像作为数据集。使用TensorFlow搭建ResNet50卷积神经网络算法模型,然后进行模型的迭代训练,得到一个识别精度较高的模型,然后在保存为本地的H5格式文件。在使用Django开发Web网页端操作界面,实现用户上传一张鸟类图像,识别其名称。
104 12
鸟类识别系统Python+卷积神经网络算法+深度学习+人工智能+TensorFlow+ResNet50算法模型+图像识别
|
27天前
|
机器学习/深度学习 人工智能 算法
【玉米病害识别】Python+卷积神经网络算法+人工智能+深度学习+计算机课设项目+TensorFlow+模型训练
玉米病害识别系统,本系统使用Python作为主要开发语言,通过收集了8种常见的玉米叶部病害图片数据集('矮花叶病', '健康', '灰斑病一般', '灰斑病严重', '锈病一般', '锈病严重', '叶斑病一般', '叶斑病严重'),然后基于TensorFlow搭建卷积神经网络算法模型,通过对数据集进行多轮迭代训练,最后得到一个识别精度较高的模型文件。再使用Django搭建Web网页操作平台,实现用户上传一张玉米病害图片识别其名称。
50 0
【玉米病害识别】Python+卷积神经网络算法+人工智能+深度学习+计算机课设项目+TensorFlow+模型训练
|
29天前
|
机器学习/深度学习 算法 5G
基于BP神经网络的CoSaMP信道估计算法matlab性能仿真,对比LS,OMP,MOMP,CoSaMP
本文介绍了基于Matlab 2022a的几种信道估计算法仿真,包括LS、OMP、NOMP、CoSaMP及改进的BP神经网络CoSaMP算法。各算法针对毫米波MIMO信道进行了性能评估,通过对比不同信噪比下的均方误差(MSE),展示了各自的优势与局限性。其中,BP神经网络改进的CoSaMP算法在低信噪比条件下表现尤为突出,能够有效提高信道估计精度。
35 2
|
2月前
|
机器学习/深度学习 算法 TensorFlow
交通标志识别系统Python+卷积神经网络算法+深度学习人工智能+TensorFlow模型训练+计算机课设项目+Django网页界面
交通标志识别系统。本系统使用Python作为主要编程语言,在交通标志图像识别功能实现中,基于TensorFlow搭建卷积神经网络算法模型,通过对收集到的58种常见的交通标志图像作为数据集,进行迭代训练最后得到一个识别精度较高的模型文件,然后保存为本地的h5格式文件。再使用Django开发Web网页端操作界面,实现用户上传一张交通标志图片,识别其名称。
96 6
交通标志识别系统Python+卷积神经网络算法+深度学习人工智能+TensorFlow模型训练+计算机课设项目+Django网页界面
|
2月前
|
机器学习/深度学习 人工智能 算法
【新闻文本分类识别系统】Python+卷积神经网络算法+人工智能+深度学习+计算机毕设项目+Django网页界面平台
文本分类识别系统。本系统使用Python作为主要开发语言,首先收集了10种中文文本数据集("体育类", "财经类", "房产类", "家居类", "教育类", "科技类", "时尚类", "时政类", "游戏类", "娱乐类"),然后基于TensorFlow搭建CNN卷积神经网络算法模型。通过对数据集进行多轮迭代训练,最后得到一个识别精度较高的模型,并保存为本地的h5格式。然后使用Django开发Web网页端操作界面,实现用户上传一段文本识别其所属的类别。
87 1
【新闻文本分类识别系统】Python+卷积神经网络算法+人工智能+深度学习+计算机毕设项目+Django网页界面平台
|
21天前
|
机器学习/深度学习 算法 数据安全/隐私保护
基于贝叶斯优化卷积神经网络(Bayes-CNN)的多因子数据分类识别算法matlab仿真
本项目展示了贝叶斯优化在CNN中的应用,包括优化过程、训练与识别效果对比,以及标准CNN的识别结果。使用Matlab2022a开发,提供完整代码及视频教程。贝叶斯优化通过构建代理模型指导超参数优化,显著提升模型性能,适用于复杂数据分类任务。
下一篇
无影云桌面