AWD-LSTM为什么这么棒?

简介: AWD-LSTM为什么这么棒,看完你就明白啦!

AWD-LSTM是目前最优秀的语言模型之一。在众多的顶会论文中,对字级模型的研究都采用了AWD-LSTMs,并且它在字符级模型中的表现也同样出色。

本文回顾了论文——Regularizing and Optimizing LSTM Language Models ,在介绍AWD-LSTM模型的同时并解释其中所涉及的各项策略。该论文提出了一系列基于词的语言模型的正则化和优化策略。这些策略不仅行之有效,而且能够在不改变现有LSTM模型的基础上使用。

AWD-LSTM即ASGD Weight-Dropped LSTM。它使用了DropConnect及平均随机梯度下降的方法,除此之外还有包含一些其它的正则化策略。我们将在后文详细讲解这些策略。本文将着重于介绍它们在语言模型中的成功应用。

实验代码获取:awd-lstm-lm GitHub repository

LSTM中的数学公式:

it = σ(Wixt + Uiht-1)
ft = σ(Wfxt + Ufht-1)
ot = σ(Woxt + Uoht-1)
c’t = tanh(Wcxt + Ucht-1)
ct = it ⊙ c’t + ft ⊙ c’t-1
ht = ot ⊙ tanh(ct)


其中, Wi, Wf, Wo, Wc, Ui, Uf, Uo, Uc都是权重矩阵, xt表示输入向量, ht表示隐藏单元向量, ct表示单元状态向量, 表示element-wise乘法。
接下来我们将逐一介绍作者提出的策略:

权重下降的LSTM

RNN的循环连接容易导致过拟合问题,如何解决这一问题也成了一个较为热门的研究领域。Dropouts的引入在前馈神经网络和卷积网络中取得了巨大的成功。但将Dropouts引入到RNN中却反响甚微,这是由于Dropouts的加入破坏了RNN长期依赖的能力。

研究学者们就此提出了许多解决方案,但是这些方法要么作用于隐藏状态向量ht-1,要么是对单元状态向量ct进行更新。上述操作能够解决高度优化的“黑盒”RNN,例如NVIDIA’s cuDNN LSTM中的过拟合问题。

但仅如此是不够的,为了更好的解决这个问题,研究学者们引入了DropConnect。DropConnect是在神经网络中对全连接层进行规范化处理。Dropout是指在模型训练时随机的将隐层节点的权重变成0,暂时认为这些节点不是网络结构的一部分,但是会把它们的权重保留下来。与Dropout不同的是DropConnect在训练神经网络模型过程中,并不随机的将隐层节点的输出变成0,而是将节点中的每个与其相连的输入权值以1-p的概率变成0。

screenshot
Regularization of Neural Networks using DropConnect

DropConnect作用在hidden-to-hidden权重矩阵(Ui、Uf、Uo、Uc)上。在前向和后向遍历之前,只执行一次dropout操作,这对训练速度的影响较小,可以用于任何标准优化的“黑盒”RNN中。通过对hidden-to-hidden权重矩阵进行dropout操作,可以避免LSTM循环连接中的过度拟合问题。

你可以在 awd-lstm-lm 中找到weight_drop.py 模块用于实现。

作者表示,尽管DropConnect是通过作用在hidden-to-hidden权重矩阵以防止过拟合问题,但它也可以作用于LSTM的非循环权重。

使用非单调条件来确定平均触发器

研究发现,对于特定的语言建模任务,传统的不带动量的SGD算法优于带动量的SGD、Adam、Adagrad及RMSProp等算法。因此,作者基于传统的SGD算法提出了ASGD(Average SGD)算法。

Average SGD

ASGD算法采用了与SGD算法相同的梯度更新步骤,不同的是,ASGD没有返回当前迭代中计算出的权值,而是考虑的这一步和前一次迭代的平均值。

传统的SGD梯度更新:

w_t = w_prev - lr_t * grad(w_prev)

AGSD梯度更新:

avg_fact = 1 / max(t - K, 1)
if avg_fact != 1:
  w_t = avg_fact * (sum(w_prevs) + (w_prev - lr_t * grad(w_prev)))
else:
  w_t = w_prev - lr_t * grad(w_prev)

其中,k是在加权平均开始之前运行的最小迭代次数。在k次迭代开始之前,ASGD与传统的SGD类似。t是当前完成的迭代次数,sum(w_prevs)是迭代k到t的权重之和,lr_t是迭代次数t的学习效率,由学习率调度器决定。

你可以在这里找到AGSD的PyTorch实现。

但作者也强调,该方法有如下两个缺点:

  • 学习率调度器的调优方案不明确
  • 如何选取合适的迭代次数k。值太小会对方法的有效性产生负面影响,值太大可能需要额外的迭代才能收敛。

基于此,作者在论文中提出了使用非单调条件来确定平均触发器,即NT-ASGD,其中:

  • 当验证度量不能改善多个循环时,就会触发平均值。这是由非单调区间的超参数n保证的。因此,每当验证度量没有在n个周期内得到改进时,就会使用到ASGD算法。通过实验发现,当n=5的时候效果最好。
  • 整个实验中使用一个恒定的学习速率,不需要进一步的调整。

正则化方法

除了上述提及的两种方法外,作者还使用了一些其它的正则化方法防止过拟合问题及提高数据效率。

长度可变的反向传播序列

作者指出,使用固定长度的基于时间的反向传播算法(BPTT)效率较低。试想,在一个时间窗口大小固定为10的BPTT算法中,有100个元素要进行反向传播操作。在这种情况下,任何可以被10整除的元素都不会有可以反向支撑的元素。这导致了1/10的数据无法以循环的方式进行自我改进,8/10的数据只能使用到部分的BPTT窗口。

为了解决这个问题,作者提出了使用可变长度的反向传播序列。首先选取长度为bptt的序列,概率为p以及长度为bptt/2的序列,概率为1-p。在PyTorch中,作者将p设为0.95。

base_bptt = bptt if np.random.random() < 0.95 else bptt / 2

其中,base_bptt用于获取seq_len,即序列长度,在N(base_bptt, s)中,s表示标准差,N表示服从正态分布。代码如下:

seq_len = max(5, int(np.random.normal(base_bptt, 5)))

学习率会根据seq_length进行调整。由于当学习速率固定时,会更倾向于对段序列而非长序列进行采样,所以需要进行缩放。

lr2 = lr * seq_len / bptt

Variational Dropout

在标准的Dropout中,每次调用dropout连接时都会采样到一个新的dropout mask。而在Variational Dropout中,dropout mask在第一次调用时只采样一次,然后locked dropout mask将重复用于前向和后向传播中的所有连接。

虽然使用了DropConnect而非Variational Dropout以规范RNN中hidden-to-hidden的转换,但是对于其它的dropout操作均使用的Variational Dropout,特别是在特定的前向和后向传播中,对LSTM的所有输入和输出使用相同的dropout mask。

点击查看官方awd-lstm-lm GitHub存储库的Variational dropout实现。详情请参阅原文

Embedding Dropout

论文中所提到的Embedding Dropout首次出现在——《A Theoretically Grounded Application of Dropout in Recurrent Neural Networks》一文中。该方法是指将dropout作用于嵌入矩阵中,且贯穿整个前向和反向传播过程。在该过程中出现的所有特定单词均会消失。

Weight Tying(权重绑定)

权重绑定共享嵌入层和softmax层之间的权重,能够减少模型中大量的参数。

Reduction in Embedding Size

对于语言模型来说,想要减少总参数的数量,最简单的方法是降低词向量的维数。即使这样无法帮助缓解过拟合问题,但它能够减少嵌入层的维度。对LSTM的第一层和最后一层进行修改,可以使得输入和输出的尺寸等于减小后的嵌入尺寸。

Activation Regularization(激活正则化)

L2正则化是对权重施加范数约束以减少过拟合问题,它同样可以用于单个单元的激活,即激活正则化。激活正则化可作为一种调解网络的方法。

loss = loss + alpha * dropped_rnn_h.pow(2).mean()

Temporal Activation Regularization(时域激活正则化)

同时,L2正则化能对RNN在不同时间步骤上的输出差值进行范数约束。它通过在隐藏层产生较大变化对模型进行惩罚。

loss = loss + beta * (rnn_h[1:] - rnn_h[:-1]).pow(2).mean()

其中,alpha和beta是缩放系数,AR和TAR损失函数仅对RNN最后一层的输出起作用。

模型分析

作者就上述模型在不同的数据集中进行了实验,为了对分分析,每次去掉一种策略。

screenshot

图中的每一行表示去掉特定策略的困惑度(perplexity)分值,从该图中我们能够直观的看出各策略对结果的影响。

实验细节

数据——来自Penn Tree-bank(PTB)数据集和WikiText-2(WT2)数据集。

网络体系结构
——所有的实验均使用的是3层LSTM模型。

批尺寸——WT2数据集的批尺寸为80,PTB数据集的批尺寸为40。根据以往经验来看,较大批尺寸(40-80)的性能优于较小批尺寸(10-20)。

其它超参数的选择请参考原文

总结

该论文很好的总结了现有的正则化及优化策略在语言模型中的应用,对于NLP初学者甚至研究者都大有裨益。论文中强调,虽然这些策略在语言建模中获得了成功,但它们同样适用于其他序列学习任务。

相关文章
|
资源调度 算法 JavaScript
Python基础专题 - 超级详细的 Random(随机)原理解析与编程实践
Python基础专题 - 超级详细的 Random(随机)原理解析与编程实践
1665 0
|
SQL XML Java
8、Mybatis-Plus 分页插件、自定义分页
这篇文章介绍了Mybatis-Plus的分页功能,包括如何配置分页插件、使用Mybatis-Plus提供的Page对象进行分页查询,以及如何在XML中自定义分页SQL。文章通过具体的代码示例和测试结果,展示了分页插件的使用和自定义分页的方法。
8、Mybatis-Plus 分页插件、自定义分页
|
10月前
|
机器学习/深度学习 人工智能 测试技术
MoBA:LLM长文本救星!月之暗面开源新一代注意力机制:处理1000万token能快16倍,已在Kimi上进行验证
MoBA 是一种新型注意力机制,通过块稀疏注意力和无参数门控机制,显著提升大型语言模型在长上下文任务中的效率。
607 3
|
人工智能 Shell iOS开发
AI Shell:在命令行里“对话” AI ,微软推出将 AI 助手引入命令行的 CLI 工具,打造对话式交互命令行
AI Shell 是一款强大的 CLI 工具,将人工智能直接集成到命令行中,帮助用户提高生产力。AI Shell 支持多种 AI 模型和助手,通过多代理框架提供丰富的功能和灵活的使用模式。
1765 7
|
机器学习/深度学习 人工智能 监控
一文读懂deepSpeed:深度学习训练的并行化
DeepSpeed 是由微软开发的开源深度学习优化库,旨在提高大规模模型训练的效率和可扩展性。通过创新的并行化策略、内存优化技术(如 ZeRO)及混合精度训练,DeepSpeed 显著提升了训练速度并降低了资源需求。它支持多种并行方法,包括数据并行、模型并行和流水线并行,同时与 PyTorch 等主流框架无缝集成,提供了易用的 API 和丰富的文档支持。DeepSpeed 不仅大幅减少了内存占用,还通过自动混合精度训练提高了计算效率,降低了能耗。其开源特性促进了 AI 行业的整体进步,使得更多研究者和开发者能够利用先进优化技术,推动了 AI 在各个领域的广泛应用。
|
机器学习/深度学习 人工智能 算法
一文搞懂模型量化算法基础
一文搞懂模型量化算法基础
5273 0
|
分布式计算 C语言 Python
基于Python实现MapReduce
一、什么是MapReduce 首先,将这个单词分解为Map、Reduce。 • Map阶段:在这个阶段,输入数据集被分割成小块,并由多个Map任务处理。每个Map任务将输入数据映射为一系列(key, value)对,并生成中间结果。 • Reduce阶段:在这个阶段,中间结果被重新分组和排序,以便相同key的中间结果被传递到同一个Reduce任务。每个Reduce任务将具有相同key的中间结果合并、计算,并生成最终的输出。
|
机器学习/深度学习 并行计算 关系型数据库
【RetNet】论文解读:Retentive Network: A Successor to Transformer for Large Language Models
【RetNet】论文解读:Retentive Network: A Successor to Transformer for Large Language Models
568 1
|
机器学习/深度学习 数据采集 自然语言处理
【传知代码】BERT论文解读及情感分类实战-论文复现
本文介绍了BERT模型的架构和技术细节,包括双向编码器、预训练任务(掩码语言模型和下一句预测)以及模型微调。文章还提供了使用BERT在IMDB数据集上进行情感分类的实战,包括数据集处理、模型训练和评估,测试集准确率超过93%。BERT是基于Transformer的预训练模型,适用于多种NLP任务。在实践中,BERT模型加载预训练权重,对输入数据进行预处理,然后通过微调适应情感分类任务。
1157 0
【传知代码】BERT论文解读及情感分类实战-论文复现
|
存储 运维 Java
Java中的字节码与JVM指令集详解
Java中的字节码与JVM指令集详解