LLM 加速技巧:Muti Query Attention

本文涉及的产品
实时计算 Flink 版,5000CU*H 3个月
检索分析服务 Elasticsearch 版,2核4GB开发者规格 1个月
实时数仓Hologres,5000CU*H 100GB 3个月
简介: MQA 是 19 年提出的一种新的 Attention 机制,其能够在保证模型效果的同时加快 decoder 生成 token 的速度。在大语言模型时代被广泛使用,很多LLM都采用了MQA,如Falcon、PaLM、StarCoder等。

在介绍MQA 之前,我们先回顾一下传统的多头注意力

Multi-Head Attention(MHA)

多头注意力是transformer 模型的默认注意力机制,如下图所示:

在文本生成方面,基于transformer 的自回归语言模型存在一个问题。在训练过程中可以获得真实的目标序列,并且可以有效地实现并行化。

但是在推理过程中,每个位置的查询都要处理在该位置或之前生成的所有键值对。也就是说自注意力层在特定位置的输出影响下一个令牌的生成,所以无法并行化,这使得推理变得非常的慢。

下图是基于transformer 解码器的自回归语言模型中自注意层的解码过程:

 defMHAForDecoder(x, prev_K, prev_V, P_q, P_k, P_v, P_o):
     q=tf.einsum("bd, hdk−>bhk", x, P_q)
     new_K=tf.concat([prev_K, tf.expand_dims(tf.einsum ("bd, hdk−>bhk", x, P_k), axis=2)], axis=2)
     new_V=tf.concat([prev_V, tf.expand_dims(tf.einsum("bd, hdv−>bhv", x, P_v), axis=2)], axis=2)
     logits=tf.einsum("bhk, bhmk−>bhm", q, new_K)
     weights=tf.softmax(logits)
     O=tf.einsum("bhm, bhmv−>bhv", weights, new_V)
     Y=tf.einsum("bhv, hdv−>bd", O, P_o)
     returnY, new_K, new_V

其中:

X:当前的输入张量,m为当前步,m+1为阶跃,形状为[b, d]

P_q, P_k:查询和键投影张量,形状为[h, d, k]

P_v:值投影张量,形状为[h, d, v]

P_o:学习到的线性投影,形状为[h, d, v]

Prev_K:上一步的关键张量,形状为[b, h, m, k]

Prev_V:前一步的Value张量,形状为[b, h, m, v]

new_K:加上当前步的键张量,形状为[b, h, m+1, k]

new_V:加了当前步长的Value张量,形状为[b, h, m+1, v]

维度表示如下:

M:先前执行的步骤数

B:批量大小

D:输入和输出的尺寸

H:注意力头数

k:Q,K张量的另一个维度

v: v张量的另一个维度

Multi-Query Attention(MQA)

MQA是多头注意的一种变体。

MQA的方法是保持Q的初始头数,但K和V只有一个头,这意味着所有Q个头共享相同的K和V,因此称为Multi-Query,如下图所示:

从论文的解释中可以看到,MQA 让所有的头之间 共享 同一份 Key 和 Value 矩阵,每个头只单独保留了一份 Query 参数,从而大大减少 Key 和 Value 矩阵的参数量。

MQA解码过程的代码本质上与MHA的代码相同,只是从中删除了表示头部尺寸的字母“h”。K, V, P_k, P_v的和方程:

 defMQAForDecoder(x, prev_K, prev_V, P_q, P_k, P_v, P_o):
     q=tf.einsum("bd, hdk−>bhk", x, P_q)
     new_K=tf.concat([prev_K, tf.expand_dims(tf.einsum ("bd, dk−>bk", x, P_k), axis=2)], axis=2)
     new_V=tf.concat([prev_V, tf.expand_dims(tf.einsum("bd, dv−>bv", x, P_v), axis=2)], axis=2)
     logits=tf.einsum("bhk, bmk−>bhm", q, new_K)
     weights=tf.softmax(logits)
     O=tf.einsum("bhm, bmv−>bhv", weights, new_V)
     Y=tf.einsum("bhv, hdv−>bd", O, P_o)
     returnY, new_K, new_V

上面都是tf的代码,如果阅读有问题,我从 llm-foundry项目中找到了pytorch的代码实现,这里只做个摘抄,有兴趣的请看原项目

 classMultiheadAttention(nn.Module):

     def__init__(
             self,
             d_model: int,
             n_heads: int,
             device: str
         ):
         """
         Multi Head init func.

         Args:
             d_model (int): hidden state size, e.g. 768
             n_heads (int): 设定的注意力头数, e.g. 8
             device (str): _description_
         """
         super().__init__()

         self.d_model=d_model
         self.n_heads=n_heads

         self.Wqkv=nn.Linear(                       # Multi-Head Attention 的创建方法
             self.d_model, 
             3*self.d_model,                        # 有 query, key, value 3 个矩阵, 所以是 3 * d_model
             device=device
         )                                            # (d_model, 3 * d_model)
         self.attn_fn=scaled_multihead_dot_product_attention
         self.out_proj=nn.Linear(
             self.d_model, 
             self.d_model, 
             device=device
         )

     defforward(
         self,
         x
     ):
         """
         forward func.

         Args:
             x (tensor): (batch, hidden_state, d_model) e.g. -> (1, 768, 512)

         Returns:
             _type_: _description_
         """
         qkv=self.Wqkv(x)                            # (1, 768, 3 * 768)

         query, key, value=qkv.chunk(                # 每个 tensor 都是 (1, 512, 768)
             3, 
             dim=2
         )     

         context, attn_weights, past_key_value=self.attn_fn(
             query,
             key,
             value,
             self.n_heads
         )                                             # (1, 512, 768)

         returnself.out_proj(context), attn_weights, past_key_value


 classMultiQueryAttention(nn.Module):
     """Multi-Query self attention.

     Using torch or triton attention implemetation enables user to also use
     additive bias.
     """

     def__init__(
         self,
         d_model: int,
         n_heads: int,
         device: Optional[str] =None,
     ):
         super().__init__()

         self.d_model=d_model
         self.n_heads=n_heads
         self.head_dim=d_model//n_heads

         self.Wqkv=nn.Linear(                           # Multi-Query Attention 的创建方法
             d_model,
             d_model+2*self.head_dim,                 # 只创建 query 的 head 向量,所以只有 1 个 d_model
             device=device,                               # 而 key 和 value 则只共享各自的一个 head_dim 的向量
         )

         self.attn_fn=scaled_multihead_dot_product_attention
         self.out_proj=nn.Linear(
             self.d_model, 
             self.d_model, 
             device=device
         )
         self.out_proj._is_residual=True  # type: ignore

     defforward(
         self,
         x,
     ):
         qkv=self.Wqkv(x)                                           # (1, 512, 960)

         query, key, value=qkv.split(                               # query -> (1, 512, 768)
             [self.d_model, self.head_dim, self.head_dim],            # key   -> (1, 512, 96)
             dim=2                                                    # value -> (1, 512, 96)
         )

         context, attn_weights, past_key_value=self.attn_fn(
             query,
             key,
             value,
             self.n_heads,
             multiquery=True,
         )

         returnself.out_proj(context), attn_weights, past_key_value

从代码中可以看到所有 头之间共享一份 key 和 value 的参数,但是如何将这 1 份参数同时让 8 个头都使用呢?

代码里使用矩阵乘法 matmul 来广播,使得每个头都乘以这同一个 tensor,以此来实现参数共享,主要是这个函数:scaled_multihead_dot_product_attention

 defscaled_multihead_dot_product_attention(
         query,
         key,
         value,
         n_heads,
         past_key_value=None,
         softmax_scale=None,
         attn_bias=None,
         key_padding_mask=None,
         is_causal=False,
         dropout_p=0.0,
         training=False,
         needs_weights=False,
         multiquery=False,
     ):
     q=rearrange(query, 'b s (h d) -> b h s d', h=n_heads)         # (1, 512, 768) -> (1, 8, 512, 96)
     kv_n_heads=1ifmultiqueryelsen_heads
     k=rearrange(key, 'b s (h d) -> b h d s', h=kv_n_heads)        # (1, 512, 768) -> (1, 8, 96, 512) if not multiquery 
                                                                     # (1, 512, 96) -> (1, 1, 96, 512)  if multiquery
     v=rearrange(value, 'b s (h d) -> b h s d', h=kv_n_heads)      # (1, 512, 768) -> (1, 8, 512, 96) if not multiquery 
                                                                     # (1, 512, 96) -> (1, 1, 512, 96)  if multiquery

     attn_weight=q.matmul(k) *softmax_scale                       # (1, 8, 512, 512)
     attn_weight=torch.softmax(attn_weight, dim=-1)                # (1, 8, 512, 512)

     out=attn_weight.matmul(v)                                     # (1, 8, 512, 512) * (1, 1, 512, 96) = (1, 8, 512, 96)
     out=rearrange(out, 'b h s d -> b s (h d)')                    # (1, 512, 768)

     returnout, attn_weight, past_key_value

MQA指标测试

MQA能在多大程度上提高速度?让我们看看原文中提供的结果图表:

从上表可以看出,MQA在编码器上的速度提升不是很显著,但在解码器上的速度提升是相当显著的。

论文中也有关于质量的实验,结果表明MQA的性能与基线相比只是稍微低一些。降低应该是肯定的因为毕竟共享了参数,但是只要再可接受范围内并且能够大量提升速度这个降低就是可以接受的,对吧。

为什么MQA可以实现推理加速?

在MQA中,键张量和值张量的大小分别为b k和b v,而在MHA中,键张量和值张量的大小分别为b h k和b h v,其中h表示头的个数。

MQA通过以下方法实现推理加速:

1、KV缓存大小减少了h(头数量),这意味着需要存储在GPU内存中的张量也减少了。节省的空间可以用来增加批大小,从而提高效率。

2、减少了从内存中读取的数据量,从而减少了计算单元的等待时间,提高了计算利用率。

3、MQA有一个相对较小的KV数量,可以放入缓存(SRAM)中。MHA则需要较大的KV数量,不能完全存储在缓存中,需要从GPU内存(DRAM)读取,这很耗时。

总结

MQA是在2019年提出的,当时的应用还没有那么广泛。这是因为以前的模型不需要关心这些方面,例如,LSTM只需要维护一个状态,而不需要保留任何缓存。

当transformer最初被提出时,它主要用于Seq2Seq任务,特别是在Encoder-Decoder模型中。由于模型的规模不是很大,也并且没有太多的实际需求,所以MQA并没有引起太多的关注。

直到近年来(尤其是2023年开始)基于transformer的大型语言模型(如GPT)得到广泛应用后,推理的瓶颈才被人们重视。所以MQA才被发现非常有用,这主要是由于对大规模gpt式生成模型的实际需求。

最后我们再回顾以下这个论文:

https://avoid.overfit.cn/post/877de0f5a56d478d8133d75a05064e7e

作者:Florian June

目录
相关文章
|
机器学习/深度学习 人工智能 自然语言处理
LLM系列 | 11: 基于ChatGPT构建智能客服系统(query分类&安全检查&防注入)
本文主要介绍如何使用ChatGPT对智能客服领域中的客户咨询进行分类。此外还补充构建真实应用中如何对用户咨询内容和模型生成内容进行安全检查及其如何预防用户注入。
|
24天前
|
前端开发 机器人 API
前端大模型入门(一):用 js+langchain 构建基于 LLM 的应用
本文介绍了大语言模型(LLM)的HTTP API流式调用机制及其在前端的实现方法。通过流式调用,服务器可以逐步发送生成的文本内容,前端则实时处理并展示这些数据块,从而提升用户体验和实时性。文章详细讲解了如何使用`fetch`发起流式请求、处理响应流数据、逐步更新界面、处理中断和错误,以及优化用户交互。流式调用特别适用于聊天机器人、搜索建议等应用场景,能够显著减少用户的等待时间,增强交互性。
174 2
|
18天前
|
机器学习/深度学习 人工智能 运维
企业内训|LLM大模型在服务器和IT网络运维中的应用-某日企IT运维部门
本课程是为某在华日资企业集团的IT运维部门专门定制开发的企业培训课程,本课程旨在深入探讨大型语言模型(LLM)在服务器及IT网络运维中的应用,结合当前技术趋势与行业需求,帮助学员掌握LLM如何为运维工作赋能。通过系统的理论讲解与实践操作,学员将了解LLM的基本知识、模型架构及其在实际运维场景中的应用,如日志分析、故障诊断、网络安全与性能优化等。
43 2
|
21天前
|
机器学习/深度学习 数据采集 人工智能
文档智能 & RAG 让AI大模型更懂业务 —— 阿里云LLM知识库解决方案评测
随着数字化转型的深入,企业对文档管理和知识提取的需求日益增长。阿里云推出的文档智能 & RAG(Retrieval-Augmented Generation)解决方案,通过高效的内容清洗、向量化处理、精准的问答召回和灵活的Prompt设计,帮助企业构建强大的LLM知识库,显著提升企业级文档管理的效率和准确性。
|
5天前
|
人工智能 自然语言处理 算法
政务培训|LLM大模型在政府/公共卫生系统的应用
本课程是TsingtaoAI公司面向某卫生统计部门的政府职员设计的大模型技术应用课程,旨在系统讲解大语言模型(LLM)的前沿应用及其在政府业务中的实践落地。课程涵盖从LLM基础知识到智能化办公、数据处理、报告生成、智能问答系统构建等多个模块,全面解析大模型在卫生统计数据分析、报告撰写和决策支持等环节中的赋能价值。
22 2
|
23天前
|
人工智能 自然语言处理 运维
前端大模型应用笔记(一):两个指令反过来说大模型就理解不了啦?或许该让第三者插足啦 -通过引入中间LLM预处理用户输入以提高多任务处理能力
本文探讨了在多任务处理场景下,自然语言指令解析的困境及解决方案。通过增加一个LLM解析层,将复杂的指令拆解为多个明确的步骤,明确操作类型与对象识别,处理任务依赖关系,并将自然语言转化为具体的工具命令,从而提高指令解析的准确性和执行效率。
|
22天前
|
人工智能 前端开发
大模型体验体验报告:OpenAI-O1内置思维链和多个llm组合出的COT有啥区别?传统道家理论+中学生物理奥赛题测试,名不虚传还是名副其实?
一个月前,o1发布时,虽然让人提前体验,但自己并未进行测试。近期终于有机会使用,却仍忘记第一时间测试。本文通过两个测试案例展示了o1的强大能力:一是关于丹田及练气的详细解答,二是解决一道复杂的中学生物理奥赛题。o1的知识面广泛、推理迅速,令人印象深刻。未来,或许可以通过赋予o1更多能力,使其在更多领域发挥作用。如果你有好的测试题,欢迎留言,一起探索o1的潜力。
|
23天前
|
机器学习/深度学习 人工智能 自然语言处理
前端大模型入门(三):编码(Tokenizer)和嵌入(Embedding)解析 - llm的输入
本文介绍了大规模语言模型(LLM)中的两个核心概念:Tokenizer和Embedding。Tokenizer将文本转换为模型可处理的数字ID,而Embedding则将这些ID转化为能捕捉语义关系的稠密向量。文章通过具体示例和代码展示了两者的实现方法,帮助读者理解其基本原理和应用场景。
132 1
|
1月前
|
机器学习/深度学习 人工智能 自然语言处理
企业内训|LLM大模型技术在金融领域的应用及实践-某商业银行分行IT团队
本企业培训是TsingtaoAI技术团队专们为某商业银行分行IT团队开发的LLM大模型技术课程。课程深入分析大模型在金融行业中的发展趋势、底层技术及应用场景,重点提升学员在大模型应用中的实际操作能力与业务场景适应力。通过对全球商用 LLM 产品及国内外技术生态的深度对比,学员将了解大模型在不同企业中的发展路径,掌握如 GPT 系列、Claude 系列、文心一言等大模型的前沿技术。针对金融行业的业务需求,学员将学会如何结合多模态技术改进用户体验、数据分析等服务流程,并掌握大模型训练与工具链的实操技术,尤其是模型的微调、迁移学习与压缩技术。
52 2
|
1月前
|
机器学习/深度学习 人工智能 自然语言处理
【AI大模型】LLM主流开源大模型介绍
【AI大模型】LLM主流开源大模型介绍