手把手教你估算深度神经网络的最优学习率(附代码&教程)

简介:

学习率如何影响训练?

深度学习模型通常由随机梯度下降算法进行训练。随机梯度下降算法有许多变形:例如 Adam、RMSProp、Adagrad 等等。这些算法都需要你设置学习率。学习率决定了在一个小批量(mini-batch)中权重在梯度方向要移动多远。

如果学习率很低,训练会变得更加可靠,但是优化会耗费较长的时间,因为朝向损失函数最小值的每个步长很小。

如果学习率很高,训练可能根本不会收敛,甚至会发散。权重的改变量可能非常大,使得优化越过最小值,使得损失函数变得更糟。

2eda41ef20632039573ef5858eae32c03b9d97fe

学习率很小(上图)和学习率很大(下图)的梯度下降。来源:Cousera 机器学习课程(吴恩达)

训练应当从相对较大的学习率开始。这是因为在开始时,初始的随机权重远离最优值。在训练过程中,学习率应当下降,以允许细粒度的权重更新。

有很多方式可以为学习率设置初始值。一个简单的方案就是尝试一些不同的值,看看哪个值能够让损失函数最优,且不损失训练速度。我们可以从 0.1 这样的值开始,然后再指数下降学习率,比如 0.01,0.001 等等。当我们以一个很大的学习率开始训练时,在起初的几次迭代训练过程中损失函数可能不会改善,甚至会增大。当我们以一个较小的学习率进行训练时,损失函数的值会在最初的几次迭代中从某一时刻开始下降。这个学习率就是我们能用的最大值,任何更大的值都不能让训练收敛。不过,这个初始学习率也过大了:它不足以训练多个 epoch,因为随着时间的推移网络将需要更加细粒度的权重更新。因此,开始训练的合理学习率可能需要降低 1-2 个数量级。

一定有更好的方法

Leslie N. Smith 在 2015 年的论文「Cyclical Learning Rates for Training Neural Networks」的第 3.3 节,描述了一种为神经网络选择一系列学习率的强大方法。

诀窍就是从一个低学习率开始训练网络,并在每个批次中指数提高学习率。

8599914cf00435ec018a2cef4ef02eb7b50ef011

在每个小批量处理后提升学习率

为每批样本记录学习率和训练损失。然后,根据损失和学习率画图。典型情况如下:

254c1d38b9222e7129252e692238401ca9a210de

一开始,损失下降,然后训练过程开始发散

首先,学习率较低,损失函数值缓慢改善,然后训练加速,直到学习速度变得过高导致损失函数值增加:训练过程发散。

我们需要在图中找到一个损失函数值降低得最快的点。在这个例子中,当学习率在 0.001 和 0.01 之间,损失函数快速下降。

另一个方式是观察计算损失函数变化率(也就是损失函数关于迭代次数的导数),然后以学习率为 x 轴,以变化率为 y 轴画图。

5e3c14a94e3b0df18e8d1f17312cdc76ef4323e1

损失函数的变化率


上图看起来噪声太大,让我们使用简单移动平均线(SMA)来做平缓化处理。

0ccab93fcadb638b5ebece17b7f442ff4c18df0c

使用 SMA 平缓化处理后的损失函数变化率

这样看起来就好多了。在这个图中,我们需要找到最小值位置。看起来,它接近于学习率为 0.01 这个位置。

实现

Jeremy Howard 和他在 USF 数据研究所的团队开发了 fast.ai。这是一个基于 PyTorch 的高级抽象的深度学习库。fast.ai 是一个简单而强大的工具集,可以用于训练最先进的深度学习模型。Jeremy 在他最新的深度学习课程(http://www.fast.ai/)中使用了这个库。


fast.ai 提供了学习率搜索器的一个实现。你只需要写几行代码就能绘制模型的损失函数-学习率的图像(来自 GitHub:plot_loss.py):


# learn is an instance of 
Learner
class
 or one of derived classes like 
ConvLearner
learn.lr_find()
learn.sched.plot_lr()

库中并没有提供代码绘制损失函数变化率的图像,但计算起来非常简单(plot_change_loss.py):


def plot_loss_change(sched, sma=
1
, n_skip=
20
, y_lim=(-
0.01
,
0.01
)):

 
"""
 
Plots
 rate of change of the loss 
function
.
 
Parameters
:
 sched - learning rate scheduler, an instance of LR_Finder 
class
.
 sma - number of batches 
for
 simple moving average to smooth out the curve.
 n_skip - number of batches to skip on the left.
 y_lim - limits 
for
 the y axis.
 
"""

 derivatives = [
0
] * (sma + 
1
)
 
for
 i in range(
1
 + sma, len(learn.sched.lrs)):
 derivative = (learn.sched.losses[i] - learn.sched.losses[i - sma]) / sma
 derivatives.append(derivative)
 plt.ylabel(
"d/loss"
)
 plt.xlabel(
"learning rate (log scale)"
)
 plt.plot(learn.sched.lrs[n_skip:], derivatives[n_skip:])
 plt.xscale(
'log'
)
 plt.ylim(y_lim)
plot_loss_change(learn.sched, sma=
20
)

请注意:只在训练之前选择一次学习率是不够的。训练过程中,最优学习率会随着时间推移而下降。你可以定期重新运行相同的学习率搜索程序,以便在训练的稍后时间查找学习率。

使用其他库实现本方案

我还没有准备好将这种学习率搜索方法应用到诸如 Keras 等其他库中,但这应该不是什么难事。只需要做到:

  • 多次运行训练,每次只训练一个小批量;
  • 在每次分批训练之后通过乘以一个小的常数的方式增加学习率;
  • 当损失函数值高于先前观察到的最佳值时,停止程序。(例如,可以将终止条件设置为「当前损失 > *4 最佳损失」)

学习计划

选择学习率的初始值只是问题的一部分。另一个需要优化的是学习计划(learning schedule):如何在训练过程中改变学习率。传统的观点是,随着时间推移学习率要越来越低,而且有许多方法进行设置:例如损失函数停止改善时逐步进行学习率退火、指数学习率衰退、余弦退火等。

我上面引用的论文描述了一种循环改变学习率的新方法,它能提升卷积神经网络在各种图像分类任务上的性能表现。


原文发布时间为:2017-11-24

本文作者:Pavel Surmenok

本文来自云栖社区合作伙伴“数据派THU”,了解相关信息可以关注“数据派THU”微信公众号

相关文章
|
5月前
|
安全 网络协议 算法
Nmap网络扫描工具详细使用教程
Nmap 是一款强大的网络发现与安全审计工具,具备主机发现、端口扫描、服务识别、操作系统检测及脚本扩展等功能。它支持多种扫描技术,如 SYN 扫描、ARP 扫描和全端口扫描,并可通过内置脚本(NSE)进行漏洞检测与服务深度枚举。Nmap 还提供防火墙规避与流量伪装能力,适用于网络管理、渗透测试和安全研究。
866 1
|
6月前
|
机器学习/深度学习 算法 数据挖掘
没发论文的注意啦!重磅更新!GWO-BP-AdaBoost预测!灰狼优化、人工神经网络与AdaBoost集成学习算法预测研究(Matlab代码实现)
没发论文的注意啦!重磅更新!GWO-BP-AdaBoost预测!灰狼优化、人工神经网络与AdaBoost集成学习算法预测研究(Matlab代码实现)
214 0
|
7月前
|
JSON 监控 API
在线网络PING接口检测服务器连通状态免费API教程
接口盒子提供免费PING检测API,可测试域名或IP的连通性与响应速度,支持指定地域节点,适用于服务器运维和网络监控。
917 0
|
5月前
|
JavaScript Java 大数据
基于python的网络课程在线学习交流系统
本研究聚焦网络课程在线学习交流系统,从社会、技术、教育三方面探讨其发展背景与意义。系统借助Java、Spring Boot、MySQL、Vue等技术实现,融合云计算、大数据与人工智能,推动教育公平与教学模式创新,具有重要理论价值与实践意义。
|
10月前
|
数据采集 存储 监控
Python 原生爬虫教程:网络爬虫的基本概念和认知
网络爬虫是一种自动抓取互联网信息的程序,广泛应用于搜索引擎、数据采集、新闻聚合和价格监控等领域。其工作流程包括 URL 调度、HTTP 请求、页面下载、解析、数据存储及新 URL 发现。Python 因其丰富的库(如 requests、BeautifulSoup、Scrapy)和简洁语法成为爬虫开发的首选语言。然而,在使用爬虫时需注意法律与道德问题,例如遵守 robots.txt 规则、控制请求频率以及合法使用数据,以确保爬虫技术健康有序发展。
1410 31
|
10月前
|
域名解析 API PHP
VM虚拟机全版本网盘+免费本地网络穿透端口映射实时同步动态家庭IP教程
本文介绍了如何通过网络穿透技术让公网直接访问家庭电脑,充分发挥本地硬件性能。相比第三方服务受限于转发带宽,此方法利用自家宽带实现更高效率。文章详细讲解了端口映射教程,包括不同网络环境(仅光猫、光猫+路由器)下的设置步骤,并提供实时同步动态IP的两种方案:自建服务器或使用三方API接口。最后附上VM虚拟机全版本下载链接,便于用户在穿透后将服务运行于虚拟环境中,提升安全性与适用性。
646 7
|
监控 Linux PHP
【02】客户端服务端C语言-go语言-web端PHP语言整合内容发布-优雅草网络设备监控系统-2月12日优雅草简化Centos stream8安装zabbix7教程-本搭建教程非docker搭建教程-优雅草solution
【02】客户端服务端C语言-go语言-web端PHP语言整合内容发布-优雅草网络设备监控系统-2月12日优雅草简化Centos stream8安装zabbix7教程-本搭建教程非docker搭建教程-优雅草solution
462 20
|
12月前
|
人工智能 网络协议 IDE
使用通义灵码AI高效学习muduo网络库开发指南
Muduo 是一个基于 C++11 的高性能网络库,支持多线程和事件驱动,适用于构建高效的服务器和应用程序。它提供 TCP/IP 协议支持、异步非阻塞 I/O、定时器、异步日志等功能,并具备跨平台特性。通过 Git 克隆 muduo 仓库并切换至 C++17 分支可开始使用。借助 AI 工具如 Deepseak-v3,用户可以更便捷地学习和理解 Muduo 的核心模块及编写测试用例,提升开发效率。
|
前端开发 小程序 Java
uniapp-网络数据请求全教程
这篇文档介绍了如何在uni-app项目中使用第三方包发起网络请求
924 3
|
网络协议 安全 NoSQL
网络空间安全之一个WH的超前沿全栈技术深入学习之路(8-2):scapy 定制 ARP 协议 、使用 nmap 进行僵尸扫描-实战演练、就怕你学成黑客啦!
scapy 定制 ARP 协议 、使用 nmap 进行僵尸扫描-实战演练等具体操作详解步骤;精典图示举例说明、注意点及常见报错问题所对应的解决方法IKUN和I原们你这要是学不会我直接退出江湖;好吧!!!
网络空间安全之一个WH的超前沿全栈技术深入学习之路(8-2):scapy 定制 ARP 协议 、使用 nmap 进行僵尸扫描-实战演练、就怕你学成黑客啦!