循环神经网络(RNN)模型与前向反向传播算法

简介:

在前面我们讲到了DNN,以及DNN的特例CNN的模型和前向反向传播算法,这些算法都是前向反馈的,模型的输出和模型本身没有关联关系。今天我们就讨论另一类输出和模型间有反馈的神经网络:循环神经网络(Recurrent Neural Networks ,以下简称RNN),它广泛的用于自然语言处理中的语音识别,手写书别以及机器翻译等领域。

1. RNN概述

    在前面讲到的DNN和CNN中,训练样本的输入和输出是比较的确定的。但是有一类问题DNN和CNN不好解决,就是训练样本输入是连续的序列,且序列的长短不一,比如基于时间的序列:一段段连续的语音,一段段连续的手写文字。这些序列比较长,且长度不一,比较难直接的拆分成一个个独立的样本来通过DNN/CNN进行训练。

    而对于这类问题,RNN则比较的擅长。那么RNN是怎么做到的呢?RNN假设我们的样本是基于序列的。比如是从序列索引1到序列索引 τ 的。对于这其中的任意序列索引号 t ,它对应的输入是对应的样本序列中的 x ( t ) 。而模型在序列索引号 t 位置的隐藏状态 h ( t ) ,则由 x ( t ) 和在 t 1 位置的隐藏状态 h ( t 1 ) 共同决定。在任意序列索引号 t ,我们也有对应的模型预测输出 o ( t ) 。通过预测输出 o ( t ) 和训练序列真实输出 y ( t ) ,以及损失函数 L ( t ) ,我们就可以用DNN类似的方法来训练模型,接着用来预测测试序列中的一些位置的输出。

    下面我们来看看RNN的模型。

2. RNN模型

    RNN模型有比较多的变种,这里介绍最主流的RNN模型结构如下:

    上图中左边是RNN模型没有按时间展开的图,如果按时间序列展开,则是上图中的右边部分。我们重点观察右边部分的图。

    这幅图描述了在序列索引号 t 附近RNN的模型。其中:

    1) x ( t ) 代表在序列索引号 t 时训练样本的输入。同样的, x ( t 1 ) x ( t + 1 ) 代表在序列索引号 t 1 t + 1 时训练样本的输入。

    2) h ( t ) 代表在序列索引号 t 时模型的隐藏状态。 h ( t ) x ( t ) h ( t 1 ) 共同决定。

    3) o ( t ) 代表在序列索引号 t 时模型的输出。 o ( t ) 只由模型当前的隐藏状态 h ( t ) 决定。

    4) L ( t ) 代表在序列索引号 t 时模型的损失函数。

    5) y ( t ) 代表在序列索引号 t 时训练样本序列的真实输出。

    6) U , W , V 这三个矩阵是我们的模型的线性关系参数,它在整个RNN网络中是共享的,这点和DNN很不相同。 也正因为是共享了,它体现了RNN的模型的“循环反馈”的思想。  

3. RNN前向传播算法

    有了上面的模型,RNN的前向传播算法就很容易得到了。

    对于任意一个序列索引号 t ,我们隐藏状态 h ( t ) x ( t ) h ( t 1 ) 得到:

h ( t ) = σ ( z ( t ) ) = σ ( U x ( t ) + W h ( t 1 ) + b )

    其中 σ 为RNN的激活函数,一般为 t a n h b 为线性关系的偏倚。

    序列索引号 t 时模型的输出 o ( t ) 的表达式比较简单:

o ( t ) = V h ( t ) + c

    在最终在序列索引号 t 时我们的预测输出为:

y ^ ( t ) = σ ( o ( t ) )

    通常由于RNN是识别类的分类模型,所以上面这个激活函数一般是softmax。

    通过损失函数 L ( t ) ,比如对数似然损失函数,我们可以量化模型在当前位置的损失,即 y ^ ( t ) y ( t ) 的差距。

4. RNN反向传播算法推导

    有了RNN前向传播算法的基础,就容易推导出RNN反向传播算法的流程了。RNN反向传播算法的思路和DNN是一样的,即通过梯度下降法一轮轮的迭代,得到合适的RNN模型参数 U , W , V , b , c 。由于我们是基于时间反向传播,所以RNN的反向传播有时也叫做BPTT(back-propagation through time)。当然这里的BPTT和DNN也有很大的不同点,即这里所有的 U , W , V , b , c 在序列的各个位置是共享的,反向传播时我们更新的是相同的参数。

    为了简化描述,这里的损失函数我们为对数损失函数,输出的激活函数为softmax函数,隐藏层的激活函数为tanh函数。

    对于RNN,由于我们在序列的每个位置都有损失函数,因此最终的损失 L 为:

L = t = 1 τ L ( t )

    其中 V , c , 的梯度计算是比较简单的:

L c = t = 1 τ L ( t ) c = t = 1 τ L ( t ) o ( t ) o ( t ) c = t = 1 τ y ^ ( t ) y ( t )
L V = t = 1 τ L ( t ) V = t = 1 τ L ( t ) o ( t ) o ( t ) V = t = 1 τ ( y ^ ( t ) y ( t ) ) ( h ( t ) ) T

    但是 W , U , b 的梯度计算就比较的复杂了。从RNN的模型可以看出,在反向传播时,在在某一序列位置t的梯度损失由当前位置的输出对应的梯度损失和序列索引位置 t + 1 时的梯度损失两部分共同决定。对于 W 在某一序列位置t的梯度损失需要反向传播一步步的计算。我们定义序列索引 t 位置的隐藏状态的梯度为:

δ ( t ) = L h ( t )

    这样我们可以像DNN一样从 δ ( t + 1 ) 递推 δ ( t )  。

δ ( t ) = L o ( t ) o ( t ) h ( t ) + L h ( t + 1 ) h ( t + 1 ) h ( t ) = V T ( y ^ ( t ) y ( t ) ) + W T δ ( t + 1 ) d i a g ( 1 ( h ( t + 1 ) ) 2 )

    对于 δ ( τ ) ,由于它的后面没有其他的序列索引了,因此有:

δ ( τ ) = L o ( τ ) o ( τ ) h ( τ ) = V T ( y ^ ( τ ) y ( τ ) )

    有了 δ ( t ) ,计算 W , U , b 就容易了,这里给出 W , U , b 的梯度计算表达式:

L W = t = 1 τ L h ( t ) h ( t ) W = t = 1 τ d i a g ( 1 ( h ( t ) ) 2 ) δ ( t ) ( h ( t 1 ) ) T
L b = t = 1 τ L h ( t ) h ( t ) b = t = 1 τ d i a g ( 1 ( h ( t ) ) 2 ) δ ( t )
L U = t = 1 τ L h ( t ) h ( t ) U = t = 1 τ d i a g ( 1 ( h ( t ) ) 2 ) δ ( t ) ( x ( t ) ) T

    除了梯度表达式不同,RNN的反向传播算法和DNN区别不大,因此这里就不再重复总结了。

5. RNN小结

    上面总结了通用的RNN模型和前向反向传播算法。当然,有些RNN模型会有些不同,自然前向反向传播的公式会有些不一样,但是原理基本类似。

    RNN虽然理论上可以很漂亮的解决序列数据的训练,但是它也像DNN一样有梯度消失时的问题,当序列很长的时候问题尤其严重。因此,上面的RNN模型一般不能直接用于应用领域。在语音识别,手写书别以及机器翻译等NLP领域实际应用比较广泛的是基于RNN模型的一个特例LSTM,下一篇我们就来讨论LSTM模型。


本文转自刘建平Pinard博客园博客,原文链接:http://www.cnblogs.com/pinard/p/6509630.html,如需转载请自行联系原作者


相关文章
|
3月前
|
传感器 机器学习/深度学习 算法
【UASNs、AUV】无人机自主水下传感网络中遗传算法的路径规划问题研究(Matlab代码实现)
【UASNs、AUV】无人机自主水下传感网络中遗传算法的路径规划问题研究(Matlab代码实现)
123 0
|
2月前
|
存储 机器学习/深度学习 监控
网络管理监控软件的 C# 区间树性能阈值查询算法
针对网络管理监控软件的高效区间查询需求,本文提出基于区间树的优化方案。传统线性遍历效率低,10万条数据查询超800ms,难以满足实时性要求。区间树以平衡二叉搜索树结构,结合节点最大值剪枝策略,将查询复杂度从O(N)降至O(logN+K),显著提升性能。通过C#实现,支持按指标类型分组建树、增量插入与多维度联合查询,在10万记录下查询耗时仅约2.8ms,内存占用降低35%。测试表明,该方案有效解决高负载场景下的响应延迟问题,助力管理员快速定位异常设备,提升运维效率与系统稳定性。
222 4
|
2月前
|
机器学习/深度学习 算法
采用蚁群算法对BP神经网络进行优化
使用蚁群算法来优化BP神经网络的权重和偏置,克服传统BP算法容易陷入局部极小值、收敛速度慢、对初始权重敏感等问题。
322 5
|
2月前
|
机器学习/深度学习 数据采集 人工智能
深度学习实战指南:从神经网络基础到模型优化的完整攻略
🌟 蒋星熠Jaxonic,AI探索者。深耕深度学习,从神经网络到Transformer,用代码践行智能革命。分享实战经验,助你构建CV、NLP模型,共赴二进制星辰大海。
|
3月前
|
机器学习/深度学习 传感器 算法
【无人车路径跟踪】基于神经网络的数据驱动迭代学习控制(ILC)算法,用于具有未知模型和重复任务的非线性单输入单输出(SISO)离散时间系统的无人车的路径跟踪(Matlab代码实现)
【无人车路径跟踪】基于神经网络的数据驱动迭代学习控制(ILC)算法,用于具有未知模型和重复任务的非线性单输入单输出(SISO)离散时间系统的无人车的路径跟踪(Matlab代码实现)
236 2
|
2月前
|
机器学习/深度学习 人工智能 算法
【基于TTNRBO优化DBN回归预测】基于瞬态三角牛顿-拉夫逊优化算法(TTNRBO)优化深度信念网络(DBN)数据回归预测研究(Matlab代码实现)
【基于TTNRBO优化DBN回归预测】基于瞬态三角牛顿-拉夫逊优化算法(TTNRBO)优化深度信念网络(DBN)数据回归预测研究(Matlab代码实现)
152 0
|
3月前
|
算法 数据挖掘 区块链
基于遗传算法的多式联运车辆路径网络优优化研究(Matlab代码实现)
基于遗传算法的多式联运车辆路径网络优优化研究(Matlab代码实现)
127 2
|
机器学习/深度学习
【从零开始学习深度学习】33.语言模型的计算方式及循环神经网络RNN简介
【从零开始学习深度学习】33.语言模型的计算方式及循环神经网络RNN简介
【从零开始学习深度学习】33.语言模型的计算方式及循环神经网络RNN简介
|
机器学习/深度学习 数据采集 存储
时间序列预测新突破:深入解析循环神经网络(RNN)在金融数据分析中的应用
【10月更文挑战第7天】时间序列预测是数据科学领域的一个重要课题,特别是在金融行业中。准确的时间序列预测能够帮助投资者做出更明智的决策,比如股票价格预测、汇率变动预测等。近年来,随着深度学习技术的发展,尤其是循环神经网络(Recurrent Neural Networks, RNNs)及其变体如长短期记忆网络(LSTM)和门控循环单元(GRU),在处理时间序列数据方面展现出了巨大的潜力。本文将探讨RNN的基本概念,并通过具体的代码示例展示如何使用这些模型来进行金融数据分析。
1277 2
|
机器学习/深度学习 自然语言处理 算法
RNN-循环神经网络
自然语言处理(Nature language Processing, NLP)研究的主要是通过计算机算法来理解自然语言。对于自然语言来说,处理的数据主要就是人类的语言,我们在进行文本数据处理时,需要将文本进行数据值化,然后进行后续的训练工作。

热门文章

最新文章