最高加速9倍!字节跳动开源8比特混合精度Transformer引擎(2)

简介: 最高加速9倍!字节跳动开源8比特混合精度Transformer引擎


量化策略


将一个浮点数矩阵量化为 int8 整数矩阵有很多方法,LightSeq 采用的是对称量化,即将正负数范围对称的浮点数区间等比例地映射到整数区间 [-127, 127] 上。


而实际上浮点数矩阵的数值范围通常并不对称,存在极少的离群值。如果直接按照离群值的范围来量化矩阵,会影响到量化后的精度,所以需要先对矩阵进行数值截断。


LightSeq 采用 PACT 方法进行截断[6],将截断的范围当作模型可学习的参数,然后利用 STE 算法去估计参数的梯度,并进行反向传播优化。根据实践经验,权重 weight 的初始截断范围设为[-1, 1],中间结果的初始截断范围设为[-16, 16],可以在大部分任务上达到最好的效果。最后经过截断范围和其他模型参数的联合优化,量化模型的效果可以达到基本无损。


梯度通信量化


针对分布式训练场景,LightSeq 推出了梯度量化压缩技术。即对浮点精度的梯度进行 int8 量化,以减少梯度通信的时间消耗,从而加速训练,这就是梯度通信量化(GCQ)。


如上图所示,梯度通信量化的主要流程如下:


  1. 计算每张卡上各自梯度的截断范围;
  2. 对截断范围执行 all-reduce max 操作;
  3. 每张卡使用统一的截断范围对各自梯度进行 int8 量化;
  4. 对 int8 梯度执行 all-reduce sum 操作;
  5. 每张卡对 all-reduce 后的梯度进行反量化,还原为浮点数梯度,并进行参数更新。


为了解决 int8 梯度在 all-reduce 过程中溢出的问题,LightSeq 首先将每张卡上的浮点数梯度除以卡数,再使用除之前的截断范围进行量化,最后进行 all-reduce 操作。这样每张卡上量化后的 int8 整数 all-reduce 完就不会溢出,但是单卡实际用于量化的比特数也因此而减少,所以目前方案在 2 机 8 卡效果几乎无损,但随着卡数的上涨,训练效果会有所下降。以 en2de 和 en2fr 翻译任务为例,在 4 机 8 卡上进行分布式量化训练,BLEU 值分别会下降 0.4 和 1.5 左右。未来 LightSeq 将会持续探索更好的方法来解决这一问题。


通用技术


除了上一章节中提到的量化技术以外,此次更新 LightSeq 还提出了几种通用的优化技术,不仅可以应用在量化模型中,也适用于其它所有精度模型的训练与推理。


算子融合


上图是 encoder 模块量化训练的计算图,LightSeq 将两次 GEMM 运算之间的所有操作融合成一个算子[7],减少了 kernel 调用的次数,因此减少了总的计算时间。


图中黄色矩形表示 int8 GEMM,绿色矩形表示 float GEMM。这里采用 float GEMM 是由于 shape 的限制,不适合使用 int8 GEMM 加速。红色箭头表示流动数据的类型是 int8,绿色箭头表示第二层 FFN 的 GEMM 输出是 int32 数据类型。int8 GEMM 输入输出的量化与反量化操作都被融合到了前后 kernel 里,这不仅可以减少数据搬运,还可以减小显存占用。


在推理时,LightSeq 还针对 decoder 做了优化。如上图所示,在计算 self-attention 时,注意力得分的维度是(batch size, 1, sequence length)。因此在计算 value 乘积时,可以不采用 GEMM 运算,而直接手写加权求和的算子,从而将图中虚线框中的计算融合成一个 kernel。


自动显存管理


模型量化引入了更复杂的张量类型和张量依赖关系,这给显存管理带来新的挑战。为此,LightSeq 设计了新的显存管理机制。如上图所示,主要包括以下过程:


  1. 训练启动前,根据每个算子的拓扑依赖关系,自动计算每个张量的生命周期及显存空间大小。其中,包含动态维度的张量按照此维度的最大量进行计算,例如机器翻译任务中的最大句长和最大 batch 句子数量。这些最大量在训练前已被指定;
  2. 张量确定生命周期和大小后,分析显存复用关系。其中,无生命周期重合的张量可以共用一片显存空间,所有显存空间都是无数据类型的,可以被分配到任意数据类型的张量上;
  3. 根据张量显存复用关系,申请多段显存空间,为每个张量分配实际的显存起止地址。


张量显存复用的分析,LightSeq 借鉴了论文 [3] 中提出的 Greedy by Size for Offset Calculation 方法,做了三个改进:


  1. 支持了整个训练过程的显存复用(forward/backward);
  2. 不同数据类型能做到显存复用(int8/fp16/fp32);
  3. 在多段显存空间上容纳所有张量,而非一段非常大的显存空间,这样能有效提升显存利用率。


自动 GEMM 调优


LightSeq 的 int8 GEMM 采用了 NVIDIA 的 cuBLASLt 库,这也是目前 NVIDIA 显卡上最为高效的矩阵运算库。但是输入数据的 shape 或者显卡不同的话,GEMM 所采用的最优配置(例如数据排布、GEMM 算法等等)也可能不同,因此需要进行自动选取。LightSeq 采取的自动调优方案如下:


  1. 在多种型号显卡上(例如 T4 和 A100)进行不同 shape 的 GEMM 最优配置搜索,并将结果保存到配置文件中,用户只需要下载即可;
  2. 模型初始化时,加载对应型号显卡的配置文件,解析并保存到键值对为 (shape, 最优配置) 的字典中。如果没有对应型号显卡的配置文件,或者没有需要的 GEMM shape,那么用户可以选择自己搜索并保存,或者直接使用默认配置;
  3. 模型前向或后向计算时,根据输入的 shape 在字典中寻找最优配置,然后进行 GEMM 计算。如果没有找到对应的 shape,那么直接采用默认的配置。


未来工作


未来 LightSeq 还将继续探索移动端的低精度量化、反向传播中梯度的量化、大模型量化等方向。


引用


[1] Wang, Xiaohui, et al. "LightSeq2: Accelerated training for transformer-based models on gpus." arXiv preprint arXiv:2110.05722 (2021).

[2] Micikevicius, Paulius, et al. "Mixed precision training." arXiv preprint arXiv:1710.03740 (2017).

[3] Pisarchyk, Yury, and Juhyun Lee. "Efficient memory management for deep neural net inference." arXiv preprint arXiv:2001.03288 (2020).

[4] Jacob, Benoit, et al. "Quantization and training of neural networks for efficient integer-arithmetic-only inference." Proceedings of the IEEE conference on computer vision and pattern recognition. 2018.

[5] Alistarh, Dan, et al. "QSGD: Communication-efficient SGD via gradient quantization and encoding." Advances in neural information processing systems 30 (2017).

[6] Choi, Jungwook, et al. "Pact: Parameterized clipping activation for quantized neural networks." arXiv preprint arXiv:1805.06085 (2018).

[7] Wang, Xiaohui, et al. "LightSeq: A high performance inference library for transformers." arXiv preprint arXiv:2010.13887 (2020).

相关文章
|
存储 Prometheus 监控
CentOS7下简单搭建Prometheus+Grafana监控系统(上)
CentOS7下简单搭建Prometheus+Grafana监控系统
1047 0
CentOS7下简单搭建Prometheus+Grafana监控系统(上)
|
10月前
|
机器学习/深度学习 人工智能 自然语言处理
ICLR 2025 | EDiT:一种基于 Local SGD 策略的大模型高效分布式训练方法
蚂蚁 AI Infra 团队在深度学习最核心之一的训练框架方向上持续投入与创新,实现了提升资源利用率、加速训练、提升训练稳定性等目标。我们提出的 EDiT 方法,即为其中一项工作。
|
12月前
|
小程序
微信小程序数据绑定与事件处理:打造动态交互体验
在上一篇中,我们学习了搭建微信小程序开发环境并创建“Hello World”页面。本文深入探讨数据绑定和事件处理机制,通过具体案例帮助你打造更具交互性的小程序。数据绑定使用双花括号`{{}}`语法,实现页面与逻辑层数据的动态关联;事件处理则通过`bind`或`catch`前缀响应用户操作。最后,通过一个简单的计数器案例,巩固所学知识。掌握这些核心技能,将助你开发更复杂的小程序。
|
Dart UED 开发者
flutter鸿蒙版本通过底部导航栏的实现熟悉架构及语法
这篇博客详细解析了一个 Flutter 应用的完整代码,实现了带有底部导航栏的功能,允许用户在不同页面之间切换。通过逐行讲解,帮助读者理解 Flutter 的结构、状态管理和组件交互。代码涵盖了从引入包、创建主入口、定义无状态和有状态组件,到构建用户界面的全过程。希望对 Flutter 开发者有所帮助。
359 3
|
机器学习/深度学习 存储 算法
关于深度学习量化的操作
0. 简介 深度学习中做量化提升运行速度是最常用的方法,尤其是大模型这类非常吃GPU显存的方法。一般是高精度浮点数表示的网络权值以及激活值用低精度(例如8比特定点)来近似表示达到模型轻量化,加速深度学习模型推理,目前8比特推理已经比较成熟。比如int8量化,就是让原来32bit存储的数字映射到8bit存储。int8范围是[-128,127], uint8范围是[0,255]。 使用低精度的模型推理的优点:1. 模型存储主要是每个层的权值,量化后模型占用空间小,32比特可以缩减至8比特,并且激活值用8比特后,减小了内存的访问带宽需求。2:单位时间内处理定点运算指令比浮点数运算指令多。 1.
390 12
|
监控
DDN是什么,DDN专线的优势详解
数字数据产品是基于数字数据网(Digital Data Network,简称DDN),利用数字信道提供永久性或半永久性连接电路来传输数据信号的服务。
874 0
|
缓存 Ubuntu Linux
在Linux中,如何进行系统更新和升级?
在Linux中,如何进行系统更新和升级?
|
安全 算法 Java
JDK 9新特性:增强的加密算法支持
本文将深入探讨JDK 9中增强的加密算法支持这一新特性。随着网络安全威胁的日益严重,加密算法在保障数据安全方面起着至关重要的作用。JDK 9通过引入更多高效、安全的加密算法,提升了Java应用程序的加密能力。本文将详细介绍这些新加密算法的特点,以及如何在实际项目中应用这些新特性来提高数据的安全性。
|
机器学习/深度学习 存储 PyTorch
PyTorch应用实战五:实现二值化神经网络
PyTorch应用实战五:实现二值化神经网络
1224 0

热门文章

最新文章