详解DQN训练技巧!带你回到深度强化学习「梦开始的地方」

简介: 详解DQN训练技巧!带你回到深度强化学习「梦开始的地方」
【新智元导读】DeepMind开始称霸强化学习的DQN算法,都有哪些训练技巧?

过去十多年里,DeepMind在人工智能的发展中绝对有着重要的地位,从AlphaGo, AlphaZero到AlphaStar,再到如今的AlphaFold 2,每次DeepMind发布新产品似乎都要彻底消灭该行业

 

围棋界天才少年柯洁都不再下传统围棋,跑去练习云顶之弈。弈一时,悟一世,切换赛道誓在新概念围棋夺生涯第九冠(bushi)。

 

 

DeepMind在围棋、星际争霸和德州扑克等取得的巨大成就,实际上都归功于DeepMind于2013年发布的DQN算法,也是深度学习和强化学习的首次成功结合。

 

Deep Q-Networks (DQN) 于 2013 年首次发布,仅将游戏的像素值作为网络的输入,成功在一套雅达利(Atari)游戏中超越之前的所有模型的得分,甚至有三个还超越了骨灰级玩家的得分。

 

论文链接:https://arxiv.org/pdf/1312.5602.pdf

 

智能体直接从经验中进行学习,并成功学习到有效的行动在当年是一项重大突破,也让通用人工智能AGI的重回大众视线:计算机在诸多任务中获得的智能,也许比人类更强!

 

不过DQN的训练并没有想象中那么简单!

 

Q-learning是什么?

 

Q-learning是强化学习(RL)的经典算法,简单来说,RL智能体与环境进行交互,如果采取的行动是「好」的,就会获得奖励,否则获得惩罚,强化学习算法的目标是最大化智能体获得的长期奖励总和。

 

 

在强化学习智能体和环境之间的交互循环中,每个时间步(timestep),智能体需要选择一个行动(action)来改变环境(environment)状态(state)。环境也提供一个奖励信号(reward signal)以表示智能体的行动是否有利。

 

处于一个特定的游戏状态或采取一个行动的未来奖励是不难估计的,难的是你的行动对环境的影响可能是不确定的,这也意味着你得到的奖励也是不确定的。尤其是在我们不知道环境的运行规则,或是在很遥远的未来且状态数很多的情况下,我们怎么能知道一个行动会带来什么奖励呢?

 

比如说,玩《超级马里奥》某一关时,最佳的行动可能是在第一帧跳跃,但如果奖励一直在关卡的最后阶段,要怎么才能知道这个行动的价值?

 

 

Q-learning采取的方法是学习一个行动-价值函数(action-value function),也被称为Q函数。

 

Q函数为每个(状态,行动)组合分配一个价值,用来表示在某一状态下采取某一行动时预期未来回报的估计,并且Q函数为所有状态都定义了一个价值。

 

在Q-learning中,智能体通过与环境互动和更新采取的(状态,行动)的Q值来学习Q-函数估计价值。在采取一个行动之后,用环境中新状态的Q值来更新所有Q值。重复迭代,最终可以估计出该状态的Q值,并根据这一估计采取行动。

 

 

一些简单的游戏通过这种方式可以估计出所有的(状态,行动)对的价值,但对于雅达利游戏来说,(状态,行动)的组合数量实在是太多了,想存储在一个简单的表格中基本是无法实现的。

 

比如说在打砖块游戏中,如果只用球拍和球,在一个300*800像素的屏幕上,状态的数量就达到了10的9次方到10的11次方,海量的状态空间情况下,引入深度神经网络就显得很必要了。

 

神经网络不好训

 

Q-Learning和神经网络的结合在理论上是非常强大的。Q-learning可以让智能体学习任何决策任务,而神经网络可以表示任何函数。如果成功训练,就会有大量的潜在应用场景得以实现,比如自动驾驶汽车、机器人技术等。

 

 

但要训练Q-learning和神经网络的组合是非常困难的。即使经过多次在不同状态下采取行动并获得奖励的迭代,有时性能也不会提高。常见的情况就是,智能体的性能在明显改进之后开始出现下降。

 

在DeepMind发布DQN论文后,这种情况仍然很常见。

 

 

Q-learning算法的每个更新步骤都是基于该步的经历,但是,如果每走一步就更新的话,算法会因为抽样误差(sampling error)而导致不稳定的更新,而抽样误差是由任意分布中抽取数据点导致的。

 

如果你在最近的数据点的序列上进行训练,那么你看到的数据肯定都是相似的,因为通常需要很多个时间步才能遍历到整个状态空间,所以你访问的下一个状态与你当前所处的状态基本上就算密切相关。样本之间的这种相关性会使得学习效率低下,而将它们打散后,通过打破相关性可以改善学习效果。

 

为了缓解这种情况,DeepMind在DQN算法中引入了一种新机制:经验重放(Experience Replay),其中经验指的是智能体在一个时间段内观察到的状态、行动、奖励和下一个状态。经验重放将每个时间段的状态、行动、奖励和后续状态存储在内存中,并在每个时间段从中随机选择一批。

 

对数据进行抽样训练,使每次更新使用的经验随机化,就可以打破数据点之间的关联性,能够降低更新的方差。由于每一步的经验都被用于许多权重的更新,这也意味着训练需要更少的数据。

 

 

在Q-Learning中,有三个使用Q函数的地方:

 

  1. 为了得到第一个状态的Q值
  2. 用于评估哪个后续状态的Q值最高,以选择一个行动
  3. 找到该后续状态的Q值  

 

将Q-learning与神经网络结合起来,如果直接将同一个网络用于这三个地方,也就意味着如果模型高估了一个状态的价值,那前面的状态也会被高估,因为Q-learning使用最大行动价值作为最大预期行动价值的估计,可能会导致学习到一个错误的Q-函数估计。

 

不过在学习过程中,数值估计不精确是很正常的,也就是说,高估是很常见的。

 

 

如果对Q值的高估在各个状态都是一致的,那这就不是一个问题。如果所有的Q值都有类似的变化,那么我们选择的行动也会是一样的。但从经验上看,实际运行通常不是这样的,也就意味着由近似的Q值产生的策略(policy)不一定会收敛到最佳策略。

 

解决高估问题的方法是使用Double DQN,也是DeepMind在2015年发表的另一篇论文中提出的。

 

论文链接:https://arxiv.org/pdf/1509.06461.pdf

 

Double DQN指的是模型拥有两个深度神经网络,模型使用正在训练的网络在与环境互动时进行行动选择,Q-函数估计更新使用后续状态的Q值,这就是第二个目标网络派上用场的地方。

 

目标网络通常是网络的一个旧版本,用来寻找具有后续状态的最大Q值的行动,而原始网络用来评估这个后续行动的Q值。通过将用于行动选择和行动评估的Q值解耦,就不太可能选择到高估的值了。

 

 

自此,训练DQN的坑基本都被填上了,不过强化学习后续还取得了其他重大进展,比如围棋领域的AlphaGo,星际争霸、德州扑克等领域都被攻克。

 

但一切都是自DQN发布之后,深度强化学习才进入春天,DQN也展现了其解决通用问题的潜力。

参考资料:https://blog.delta-academy.xyz/why-deepmind-dqn-hard-to-train

相关文章
|
文字识别 算法 数据挖掘
文本检测 DBNet
文本检测 DBNet
617 0
|
4月前
|
JSON API 开发者
闲鱼商品详情API数据解析(附代码)
闲鱼商品详情API(goodfish.item_get)支持通过商品ID获取标题、价格、描述等信息,适用于比价、推荐系统及市场分析。接口支持GET/POST请求,返回JSON格式数据,并提供Python调用示例,便于开发者快速集成。
|
5月前
|
机器学习/深度学习 存储 算法
强化学习算法基准测试:6种算法在多智能体环境中的表现实测
本文系统研究了多智能体强化学习的算法性能与评估框架,选用井字棋和连珠四子作为基准环境,对比分析Q-learning、蒙特卡洛、Sarsa等表格方法在对抗场景中的表现。实验表明,表格方法在小规模状态空间(如井字棋)中可有效学习策略,但在大规模状态空间(如连珠四子)中因泛化能力不足而失效,揭示了向函数逼近技术演进的必要性。研究构建了标准化评估流程,明确了不同算法的适用边界,为理解强化学习的可扩展性问题提供了实证支持与理论参考。
290 0
强化学习算法基准测试:6种算法在多智能体环境中的表现实测
|
机器学习/深度学习 语音技术 开发工具
【独家秘籍】揭秘!如何用阿里云TTS魔法般将文字瞬间变成天籁之音,让你的作品开口说话,震撼人心!
【8月更文挑战第15天】通过阿里云语音合成服务(TTS),开发者可将文本转为自然语音,适用于有声阅读、客服等场景。首先注册并获取AccessKey ID/Secret,然后安装阿里云Python SDK。使用示例代码设置语音参数(如发音人xiaoyun、引擎wavenet),发送请求并保存生成的MP3文件。注意正确认证及异常处理,以确保应用稳定可靠。
998 0
|
C语言 Perl
西门子S7-1200编程实例,电动机起保停控制梯形图如何编写?
本篇我们通过一个电动机起保停控制的实例,介绍S7-1200的使用方法,按下瞬时启动按钮I0.6,电动机Q0.0启动,按下瞬时停止按钮I0.7,电动机Q0.0停止。
西门子S7-1200编程实例,电动机起保停控制梯形图如何编写?
|
存储 关系型数据库 API
必看!淘宝商品详情数据接口调用,助力商城上货实战全流程(仅供参考)
本文介绍了一个实战案例,通过调用淘宝商品详情数据接口,实现商品信息的自动获取与上架至自建电商平台。主要步骤包括需求分析、技术选型、接口调用、数据存储、自动上货及定时更新,旨在提升工作效率,减少人工操作。
|
机器学习/深度学习 人工智能 算法
基于深度学习的【蔬菜识别】系统实现~Python+人工智能+TensorFlow+算法模型
蔬菜识别系统,本系统使用Python作为主要编程语言,通过收集了8种常见的蔬菜图像数据集('土豆', '大白菜', '大葱', '莲藕', '菠菜', '西红柿', '韭菜', '黄瓜'),然后基于TensorFlow搭建卷积神经网络算法模型,通过多轮迭代训练最后得到一个识别精度较高的模型文件。在使用Django开发web网页端操作界面,实现用户上传一张蔬菜图片识别其名称。
668 0
基于深度学习的【蔬菜识别】系统实现~Python+人工智能+TensorFlow+算法模型
|
移动开发 JavaScript 前端开发
UniApp H5 跨域代理配置并使用(配置manifest.json、vue.config.js)
这篇文章介绍了在UniApp H5项目中处理跨域问题的两种方法:通过修改manifest.json文件配置h5设置,或在项目根目录创建vue.config.js文件进行代理配置,并提供了具体的配置代码示例。
UniApp H5 跨域代理配置并使用(配置manifest.json、vue.config.js)
WK
|
机器学习/深度学习 算法 PyTorch
如何计算损失函数关于参数的梯度
计算损失函数关于参数的梯度是深度学习优化的关键,涉及前向传播、损失计算、反向传播及参数更新等多个步骤。首先,输入数据经由模型各层前向传播生成预测结果;其次,利用损失函数评估预测与实际标签间的差距;再次,采用反向传播算法自输出层逐层向前计算梯度;过程中需考虑激活函数、输入数据及相邻层梯度影响。针对不同层类型,如线性层或非线性层(ReLU、Sigmoid),梯度计算方式各异。最终,借助梯度下降法或其他优化算法更新模型参数,直至满足特定停止条件。实际应用中还需解决梯度消失与爆炸问题,确保模型稳定训练。
WK
608 0
|
前端开发 API 数据库
Python网站开发指南:构建现代化、高效的Web应用
在当今数字化时代,网站已成为企业、组织以及个人展示自己的重要窗口。Python作为一种简洁、高效且易于学习的编程语言,被广泛运用于网站开发领域。本文将向您介绍如何使用Python进行网站开发,包括常用的Web框架、关键技术和最佳实践。