深入剖析Transformer架构中的多头注意力机制

本文涉及的产品
文本翻译,文本翻译 100万字符
文档翻译,文档翻译 1千页
NLP 自学习平台,3个模型定制额度 1个月
简介: 多头注意力机制(Multi-Head Attention)是Transformer模型中的核心组件,通过并行运行多个独立的注意力机制,捕捉输入序列中不同子空间的语义关联。每个“头”独立处理Query、Key和Value矩阵,经过缩放点积注意力运算后,所有头的输出被拼接并通过线性层融合,最终生成更全面的表示。多头注意力不仅增强了模型对复杂依赖关系的理解,还在自然语言处理任务如机器翻译和阅读理解中表现出色。通过多头自注意力机制,模型在同一序列内部进行多角度的注意力计算,进一步提升了表达能力和泛化性能。

引言

在上一篇文章中,我们学习了自注意力机制,今天就来接着学习多头注意力机制。

多头注意力机制

多头注意力(Multi-Head Attention)是一种在Transformer模型中被广泛采用的注意力机制扩展形式,它通过并行地运行多个独立的注意力机制来获取输入序列的不同子空间的注意力分布,从而更全面地捕获序列中潜在的多种语义关联。

image.png

在多头注意力中,输入序列首先通过三个不同的线性变换层分别得到Query、Key和Value。然后,这些变换后的向量被划分为若干个“头”,每个头都有自己独立的Query、Key和Value矩阵。对于每个头,都执行一次Scaled Dot-Product Attention(缩放点积注意力)运算,即:

多头注意力的计算可以表示为:

image.png

最后,所有头的输出会被拼接(concatenate)在一起,然后再通过一个线性层进行融合,得到最终的注意力输出向量。

通过这种方式,多头注意力能够并行地从不同的角度对输入序列进行注意力处理,提高了模型理解和捕捉复杂依赖关系的能力。在实践中,多头注意力能显著提升Transformer模型在自然语言处理和其他序列数据处理任务上的性能。

多头注意力机制原理

image.png

输入变换与线性投影

多头注意力机制的输入变换与线性投影是其核心步骤之一。给定输入序列,首先通过三个不同的线性变换层生成查询(Query)、键(Key)和值(Value)矩阵。这些变换通常是通过全连接层实现的,其目的是将输入数据映射到不同的表示子空间中,为后续的注意力计算提供基础。

查询(Q)、键(K)和值(V)的生成

输入序列首先被映射到查询、键和值矩阵。这一步骤通过与权重矩阵WQ、WK和WV的矩阵乘法实现,其中每个矩阵都是模型中的可学习参数。数学上,这可以表示为:

image.png

线性投影的作用

线性投影不仅帮助模型将输入数据映射到不同的表示空间,而且还允许模型学习如何根据当前任务的需要动态地聚焦于输入数据的不同部分。这种动态聚焦是通过计算输入数据的加权表示来实现的,权重由模型学习得到。

分头计算与并行处理

多头注意力机制将查询、键和值矩阵分成多个头(即多个子空间),每个头具有不同的线性变换参数。每个头独立地计算注意力得分,并生成一个注意力加权后的输出。这些输出随后被合并,形成一个最终的、更复杂的表示。

分头计算

在多头注意力中,查询、键和值的线性变换实际上会进行多次,每个头都有自己的权重矩阵。这样,输入向量被分割到多个不同的子空间中,每个子空间执行自注意力操作。公式上表现为:

image.png

并行处理

由于每个头的计算是独立的,这些计算可以并行进行,从而提高模型的计算效率。这种并行性使得多头注意力机制在处理长序列数据时更加高效。

注意力权重计算

在多头注意力机制中,每个头的注意力权重计算是通过缩放点积注意力(Scaled Dot-Product Attention)实现的。具体来说,计算查询和键的点积,经过缩放、加上偏置后,使用softmax函数得到注意力权重。

缩放点积注意力

为了避免过大的点积导致梯度消失问题,通常会对点积结果进行缩放。缩放因子通常是键向量维度的倒数或平方根:

image.png

归一化注意力权重

使用softmax函数对缩放后的得分进行归一化,得到每个元素的注意力权重,这些权重之和为1:

image.png

拼接与融合

多头注意力机制的最后步骤是将所有头的输出拼接在一起,然后通过一个最终的线性变换,以整合来自不同头的信息,得到最终的多头注意力输出。

拼接

将所有头的输出拼接在一起,形成一个长向量。这一步骤整合了不同子空间学到的信息,增强模型的表达能力。

融合

对拼接后的向量进行一个最终的线性变换,以整合来自不同头的信息,得到最终的多头注意力输出。这一步骤对应着:

image.png

为什么使用多个注意力头?

1、增加模型的学习能力和表达力:通过多个注意力头,模型可以学习到更丰富的上下文信息,每个头可能关注输入的不同特征,这些特征综合起来可以更全面地理解和处理输入序列。

2、提高模型性能:实验证明,多头注意力机制相较于单头注意力,往往能带来性能提升。这是因为模型可以通过并行处理和集成多个注意力头的结果,从不同角度捕捉数据的多样性,增强了模型对复杂序列任务的理解和泛化能力。

多头自注意力(Multi-Head Self-Attention)的常见误区

多头自注意力(Multi-Head Self-Attention)是多头注意力的一种,都属于注意力机制在深度学习中的应用,尤其是自然语言处理(NLP)领域的Transformer模型中。

自注意力就是Q=K=V?

image.png

多头自注意力与多头注意力的区别

应用场景

多头注意力不仅限于自注意力场景,它可以应用于任何形式的注意力机制,包括但不限于跨序列的注意力,比如在一个序列上对另一个序列的注意力(Cross-Attention)。

import torch
import torch.nn as nn
import math

class MultiHeadAttention(nn.Module):
    def __init__(self, embed_size, num_heads):
        super(MultiHeadAttention, self).__init__()
        self.embed_size = embed_size
        self.num_heads = num_heads
        self.head_dim = embed_size // num_heads

        assert (
            self.head_dim * num_heads == embed_size
        ), "Embedding size needs to be divisible by num_heads"

        # Linear layers for Q, K, V transformations
        self.queries = nn.Linear(embed_size, embed_size)
        self.keys = nn.Linear(embed_size, embed_size)
        self.values = nn.Linear(embed_size, embed_size)

        self.fc_out = nn.Linear(embed_size, embed_size)

    def forward(self, query, key, value, mask=None):
        N = query.shape[0]
        query_len, key_len, value_len = query.shape[1], key.shape[1], value.shape[1]

        # Transformations
        queries = self.queries(query).view(N, query_len, self.num_heads, self.head_dim)
        keys = self.keys(key).view(N, key_len, self.num_heads, self.head_dim)
        values = self.values(value).view(N, value_len, self.num_heads, self.head_dim)

        # Transpose for attention dot product: (N, num_heads, seq_len, head_dim)
        queries = queries.transpose(1, 2)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)

        # Scaled dot-product attention
        energy = torch.einsum("nhqd,nhkd->nhqk", [queries, keys])
        if mask is not None:
            energy = energy.masked_fill(mask == 0, float("-1e20"))

        attention = torch.softmax(energy / (self.embed_size ** (1/2)), dim=3)

        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).transpose(1, 2).contiguous()
        out = out.view(N, query_len, self.embed_size)

        return self.fc_out(out)

# Example usage for cross-sequence attention:
embed_size = 512
num_heads = 8
multi_head_attention = MultiHeadAttention(embed_size, num_heads)

# Dummy input tensors for demonstration
query = torch.rand((64, 10, embed_size))  # Target language sequence
key = torch.rand((64, 15, embed_size))    # Source language sequence
value = torch.rand((64, 15, embed_size))  # Source language sequence

output = multi_head_attention(query, key, value, mask=None)
print(output.shape)  # Should print: torch.Size([64, 10, 512])

多头自注意力特指在同一序列内部,每个元素对其它所有元素的注意力机制进行了多头处理,用于捕获序列内元素间的复杂依赖关系,常见于Transformer的编码器和解码器中。

class MultiHeadSelfAttention(nn.Module):
    def __init__(self, embed_size, num_heads):
        super(MultiHeadSelfAttention, self).__init__()
        self.embed_size = embed_size
        self.num_heads = num_heads
        self.head_dim = embed_size // num_heads

        assert (
            self.head_dim * num_heads == embed_size
        ), "Embedding size needs to be divisible by num_heads"

        # Linear layers for Q, K, V transformations
        self.queries = nn.Linear(embed_size, embed_size)
        self.keys = nn.Linear(embed_size, embed_size)
        self.values = nn.Linear(embed_size, embed_size)

        self.fc_out = nn.Linear(embed_size, embed_size)

    def forward(self, x, mask=None):
        N = x.shape[0]
        seq_len = x.shape[1]

        # Transformations
        queries = self.queries(x).view(N, seq_len, self.num_heads, self.head_dim)
        keys = self.keys(x).view(N, seq_len, self.num_heads, self.head_dim)
        values = self.values(x).view(N, seq_len, self.num_heads, self.head_dim)

        # Transpose for attention dot product: (N, num_heads, seq_len, head_dim)
        queries = queries.transpose(1, 2)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)

        # Scaled dot-product attention
        energy = torch.einsum("nhqd,nhkd->nhqk", [queries, keys])
        if mask is not None:
            energy = energy.masked_fill(mask == 0, float("-1e20"))

        attention = torch.softmax(energy / (self.embed_size ** (1/2)), dim=3)

        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).transpose(1, 2).contiguous()
        out = out.view(N, seq_len, self.embed_size)

        return self.fc_out(out)

# Example usage for text comprehension:
embed_size = 512
num_heads = 8
multi_head_self_attention = MultiHeadSelfAttention(embed_size, num_heads)

# Dummy input tensor for demonstration
x = torch.rand((64, 20, embed_size))  # Input sequence (could be source language or target language)

output = multi_head_self_attention(x, mask=None)
print(output.shape)  # Should print: torch.Size([64, 20, 512])

功能聚焦点

多头注意力可以用来同时考虑多种类型的关联性,无论是否是同一序列内的元素间相互作用。

多头自注意力特别强调的是序列自身的自参照特性,即序列的每一个位置都能查看整个序列并据此调整自身的表现形式。

凝练一下就是,多头注意力是一个通用术语,当应用于序列本身时,就成为多头自注意力。两者都是为了通过并行处理多个注意力视角来增强模型的表达能力和捕捉

多头自注意力计算过程

image.png

计算过程详细分的话,主要分为六个步骤:

输入向量

这是输入序列中第i个元素的向量表示,例如在自然语言处理中,它可以是一个词的嵌入向量。

image.png

线性变换

每个输入向量image.png通过三组权重矩阵image.png转换成查询(Query) 向量image.png,键 (Key) 向量image.png,和值 (Value) 向量image.png。具体地,对于每个头 h:

image.png

这里image.png是每个头特有的参数矩阵。

多头机制

如图所示,有两个头,因此对于每个输入向量,我们会得到两组image.png向量。

注意力函数

每个头计算一个注意力得分,该得分决定了在计算输出时对每个值向量V的加权重要性。注意力得分通常使用缩放的点积注意力来计算:

image.png

其中image.png是键向量的维度,用于缩放点积,以防止梯度消失。

头输出

每个头的输出image.png是通过应用注意力函数得到的加权和:

image.png

输出连接与变换

最后,所有头的输出被连接在一起形成一个单一的长向量,然后通过另一个权重矩阵image.png进行变换以产生最终的输出向量

image.png

这里的image.png表示连接操作,image.png是一个权重矩阵,用于将多头输出融合成最终的输出。

多头注意力及多头自注意力应用场景

多头注意力应用场景-机器翻译

假设你是一名国际商务顾问,经常需要处理来自不同国家的文件。今天你收到一份重要的合同,这份合同是用中文写的,而你需要将其准确无误地翻译成英文。为了确保翻译的质量,你决定使用一个基于多头注意力机制的机器翻译系统来辅助你完成这项任务。

在翻译过程中,这个系统不仅会逐字逐句地进行翻译,还会特别注意源语言(中文)句子中关键词汇、语法结构以及它们与目标语言(英文)对应部分的关系。比如,它会考虑“银行”这个词在不同的上下文中可能指代的是金融机构还是河岸,并据此选择最合适的翻译。

多头注意力的应用:

  • 词汇层面:一个“头”可以专注于捕捉词汇层面的对应关系,确保每个词都被正确翻译。

  • 语义层面:另一个“头”可能会更注重理解整体语义或上下文信息,从而帮助模型做出更明智的选择。

  • 长距离依赖:还有些“头”能够识别并处理句子中相隔较远元素之间的联系,这对于理解复杂的句子结构至关重要。

下面是基于PyTorch框架实现的多头注意力机制的一个简化版本,特别适用于上述提到的机器翻译任务。在这个例子中,我们假定查询(query)、键(key)和值(value)分别来自目标语言和源语言序列。

import torch
import torch.nn as nn
import math

class MultiHeadAttention(nn.Module):
    def __init__(self, embed_size, num_heads):
        super(MultiHeadAttention, self).__init__()
        self.embed_size = embed_size
        self.num_heads = num_heads
        self.head_dim = embed_size // num_heads

        assert (
            self.head_dim * num_heads == embed_size
        ), "Embedding size needs to be divisible by num_heads"

        # Linear layers for Q, K, V transformations
        self.queries = nn.Linear(embed_size, embed_size)
        self.keys = nn.Linear(embed_size, embed_size)
        self.values = nn.Linear(embed_size, embed_size)

        self.fc_out = nn.Linear(embed_size, embed_size)

    def forward(self, query, key, value, mask=None):
        N = query.shape[0]
        query_len, key_len, value_len = query.shape[1], key.shape[1], value.shape[1]

        # Transformations
        queries = self.queries(query).view(N, query_len, self.num_heads, self.head_dim)
        keys = self.keys(key).view(N, key_len, self.num_heads, self.head_dim)
        values = self.values(value).view(N, value_len, self.num_heads, self.head_dim)

        # Transpose for attention dot product: (N, num_heads, seq_len, head_dim)
        queries = queries.transpose(1, 2)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)

        # Scaled dot-product attention
        energy = torch.einsum("nhqd,nhkd->nhqk", [queries, keys])
        if mask is not None:
            energy = energy.masked_fill(mask == 0, float("-1e20"))

        attention = torch.softmax(energy / (self.embed_size ** (1/2)), dim=3)

        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).transpose(1, 2).contiguous()
        out = out.view(N, query_len, self.embed_size)

        return self.fc_out(out)

# Example usage for machine translation:
embed_size = 512
num_heads = 8
multi_head_attention = MultiHeadAttention(embed_size, num_heads)

# Dummy input tensors for demonstration
# Assume we have a batch of 64 sentences, each with up to 10 tokens in the target language (English),
# and up to 15 tokens in the source language (Chinese).
query = torch.rand((64, 10, embed_size))  # Target language sequence
key = torch.rand((64, 15, embed_size))    # Source language sequence
value = torch.rand((64, 15, embed_size))  # Source language sequence

output = multi_head_attention(query, key, value, mask=None)
print(output.shape)  # Should print: torch.Size([64, 10, 512])

query代表目标语言(如英文)的嵌入表示,而key和value则来自源语言(如中文)。通过多头注意力机制,模型可以在翻译过程中同时关注源语言的不同方面,从而生成更准确、流畅的目标语言文本。这种机制对于提高机器翻译的质量非常关键,尤其是在处理复杂句子结构或具有多种含义的词汇时。

多头自注意力应用场景-阅读理解

此处还是以开头的看小说为例展开哈。想象你正在阅读一本小说,小说中的人物关系错综复杂,情节发展跌宕起伏。为了更好地理解和享受这本小说,你需要同时关注多个层面的信息:角色之间的对话、情感变化、故事情节的发展以及背景设定等。在这种情况下,你的大脑会自动选择性地关注某些重要信息,并忽略其他不重要的信息,以便更有效地处理和理解文本内容。

多头自注意力的应用:

  • 局部细节:一个“头”可以专注于短语或句子内部的单词之间关系,帮助理解具体的语法结构。

  • 全局上下文:另一个“头”可能会着眼于整个段落甚至章节之间的联系,以把握故事的整体走向。

  • 人物互动:还有些“头”专门用于分析人物对话及其背后的情感色彩,有助于构建更加立体的角色形象。

下面是一个基于PyTorch框架实现的多头自注意力机制的简化版本。在这个例子中,查询(query)、键(key)和值(value)都来自同一个输入序列,即源语言序列本身。这种设计允许模型在同一位置上同时考虑序列中所有其他位置的信息,从而提高其捕捉长距离依赖关系的能力。

import torch
import torch.nn as nn
import math

class MultiHeadSelfAttention(nn.Module):
    def __init__(self, embed_size, num_heads):
        super(MultiHeadSelfAttention, self).__init__()
        self.embed_size = embed_size
        self.num_heads = num_heads
        self.head_dim = embed_size // num_heads

        assert (
            self.head_dim * num_heads == embed_size
        ), "Embedding size needs to be divisible by num_heads"

        # Linear layers for Q, K, V transformations
        self.queries = nn.Linear(embed_size, embed_size)
        self.keys = nn.Linear(embed_size, embed_size)
        self.values = nn.Linear(embed_size, embed_size)

        self.fc_out = nn.Linear(embed_size, embed_size)

    def forward(self, x, mask=None):
        N = x.shape[0]
        seq_len = x.shape[1]

        # Transformations
        queries = self.queries(x).view(N, seq_len, self.num_heads, self.head_dim)
        keys = self.keys(x).view(N, seq_len, self.num_heads, self.head_dim)
        values = self.values(x).view(N, seq_len, self.num_heads, self.head_dim)

        # Transpose for attention dot product: (N, num_heads, seq_len, head_dim)
        queries = queries.transpose(1, 2)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)

        # Scaled dot-product attention
        energy = torch.einsum("nhqd,nhkd->nhqk", [queries, keys])
        if mask is not None:
            energy = energy.masked_fill(mask == 0, float("-1e20"))

        attention = torch.softmax(energy / (self.embed_size ** (1/2)), dim=3)

        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).transpose(1, 2).contiguous()
        out = out.view(N, seq_len, self.embed_size)

        return self.fc_out(out)

# Example usage for text comprehension:
embed_size = 512
num_heads = 8
multi_head_self_attention = MultiHeadSelfAttention(embed_size, num_heads)

# Dummy input tensor for demonstration
# Assume we have a batch of 64 sentences, each with up to 20 tokens.
x = torch.rand((64, 20, embed_size))  # Input sequence (could be source language or target language)

output = multi_head_self_attention(x, mask=None)
print(output.shape)  # Should print: torch.Size([64, 20, 512])

在这个代码示例中,x代表输入序列(可以是源语言或目标语言),并且查询(query)、键(key)和值(value)都是从同一个输入序列中派生出来的。通过多头自注意力机制,模型可以在处理每个位置时同时考虑该位置与其他所有位置之间的关系,从而更全面地理解文本内容。这种机制对于处理复杂的自然语言任务特别有用,因为它能有效地捕捉到文本中的长距离依赖关系,这对于理解语言的复杂性和细微差别至关重要

相关文章
|
8月前
|
机器学习/深度学习 自然语言处理 并行计算
大模型开发:什么是Transformer架构及其重要性?
Transformer模型革新了NLP,以其高效的并行计算和自注意力机制解决了长距离依赖问题。从机器翻译到各种NLP任务,Transformer展现出卓越性能,其编码器-解码器结构结合自注意力层和前馈网络,实现高效训练。此架构已成为领域内重要里程碑。
225 2
|
8月前
|
机器学习/深度学习 XML 自然语言处理
Transformer 架构—Encoder-Decoder
Transformer 架构—Encoder-Decoder
357 1
|
3月前
|
人工智能 测试技术 数据处理
首个Mamba+Transformer混合架构多模态大模型来了,实现单卡千图推理
【10月更文挑战第18天】《LongLLaVA: Scaling Multi-modal LLMs to 1000 Images Efficiently via Hybrid Architecture》提出了一种新型多模态大模型LongLLaVA,结合了Mamba和Transformer架构,通过系统优化实现在单张A100 80GB GPU上处理近千张图像的突破。该模型在视频理解、高分辨率图像分析和多模态智能体任务中表现出色,显著提升了计算效率。
179 64
|
1月前
|
机器学习/深度学习 编解码 人工智能
超越Transformer,全面升级!MIT等华人团队发布通用时序TimeMixer++架构,8项任务全面领先
一支由麻省理工学院、香港科技大学(广州)、浙江大学和格里菲斯大学的华人研究团队,开发了名为TimeMixer++的时间序列分析模型。该模型在8项任务中超越现有技术,通过多尺度时间图像转换、双轴注意力机制和多尺度多分辨率混合等技术,实现了性能的显著提升。论文已发布于arXiv。
165 83
|
6月前
|
机器学习/深度学习 人工智能 自然语言处理
大模型最强架构TTT问世!斯坦福UCSD等5年磨一剑, 一夜推翻Transformer
【7月更文挑战第21天】历经五年研发,斯坦福、UCSD等顶尖学府联合推出TTT架构,革新NLP领域。此架构以线性复杂度处理长序列,增强表达力及泛化能力,自监督学习下,测试阶段动态调整隐藏状态,显著提升效率与准确性。实验显示,TTT在语言模型与长序列任务中超越Transformer,论文详述于此:[https://arxiv.org/abs/2407.04620](https://arxiv.org/abs/2407.04620)。尽管如此,TTT仍需克服内存与计算效率挑战。
187 2
|
2月前
|
机器学习/深度学习 自然语言处理 计算机视觉
探索深度学习中的Transformer架构
探索深度学习中的Transformer架构
63 0
|
2月前
|
机器学习/深度学习 人工智能 自然语言处理
Tokenformer:基于参数标记化的高效可扩展Transformer架构
本文是对发表于arXiv的论文 "TOKENFORMER: RETHINKING TRANSFORMER SCALING WITH TOKENIZED MODEL PARAMETERS" 的深入解读与扩展分析。主要探讨了一种革新性的Transformer架构设计方案,该方案通过参数标记化实现了模型的高效扩展和计算优化。
238 0
|
4月前
|
机器学习/深度学习 存储 算法
Transformer、RNN和SSM的相似性探究:揭示看似不相关的LLM架构之间的联系
通过探索大语言模型(LLM)架构之间的潜在联系,我们可能开辟新途径,促进不同模型间的知识交流并提高整体效率。尽管Transformer仍是主流,但Mamba等线性循环神经网络(RNN)和状态空间模型(SSM)展现出巨大潜力。近期研究揭示了Transformer、RNN、SSM和矩阵混合器之间的深层联系,为跨架构的思想迁移提供了可能。本文深入探讨了这些架构间的相似性和差异,包括Transformer与RNN的关系、状态空间模型在自注意力机制中的隐含作用以及Mamba在特定条件下的重写方式。
190 7
Transformer、RNN和SSM的相似性探究:揭示看似不相关的LLM架构之间的联系
|
3月前
|
机器学习/深度学习 人工智能
【AI大模型】深入Transformer架构:编码器部分的实现与解析(下)
【AI大模型】深入Transformer架构:编码器部分的实现与解析(下)
|
5月前
|
机器学习/深度学习 自然语言处理 知识图谱

热门文章

最新文章