引言
在上一篇文章中,我们学习了自注意力机制,今天就来接着学习多头注意力机制。
多头注意力机制
多头注意力(Multi-Head Attention)是一种在Transformer模型中被广泛采用的注意力机制扩展形式,它通过并行地运行多个独立的注意力机制来获取输入序列的不同子空间的注意力分布,从而更全面地捕获序列中潜在的多种语义关联。
在多头注意力中,输入序列首先通过三个不同的线性变换层分别得到Query、Key和Value。然后,这些变换后的向量被划分为若干个“头”,每个头都有自己独立的Query、Key和Value矩阵。对于每个头,都执行一次Scaled Dot-Product Attention(缩放点积注意力)运算,即:
多头注意力的计算可以表示为:
最后,所有头的输出会被拼接(concatenate)在一起,然后再通过一个线性层进行融合,得到最终的注意力输出向量。
通过这种方式,多头注意力能够并行地从不同的角度对输入序列进行注意力处理,提高了模型理解和捕捉复杂依赖关系的能力。在实践中,多头注意力能显著提升Transformer模型在自然语言处理和其他序列数据处理任务上的性能。
多头注意力机制原理
输入变换与线性投影
多头注意力机制的输入变换与线性投影是其核心步骤之一。给定输入序列,首先通过三个不同的线性变换层生成查询(Query)、键(Key)和值(Value)矩阵。这些变换通常是通过全连接层实现的,其目的是将输入数据映射到不同的表示子空间中,为后续的注意力计算提供基础。
查询(Q)、键(K)和值(V)的生成
输入序列首先被映射到查询、键和值矩阵。这一步骤通过与权重矩阵WQ、WK和WV的矩阵乘法实现,其中每个矩阵都是模型中的可学习参数。数学上,这可以表示为:
线性投影的作用
线性投影不仅帮助模型将输入数据映射到不同的表示空间,而且还允许模型学习如何根据当前任务的需要动态地聚焦于输入数据的不同部分。这种动态聚焦是通过计算输入数据的加权表示来实现的,权重由模型学习得到。
分头计算与并行处理
多头注意力机制将查询、键和值矩阵分成多个头(即多个子空间),每个头具有不同的线性变换参数。每个头独立地计算注意力得分,并生成一个注意力加权后的输出。这些输出随后被合并,形成一个最终的、更复杂的表示。
分头计算
在多头注意力中,查询、键和值的线性变换实际上会进行多次,每个头都有自己的权重矩阵。这样,输入向量被分割到多个不同的子空间中,每个子空间执行自注意力操作。公式上表现为:
并行处理
由于每个头的计算是独立的,这些计算可以并行进行,从而提高模型的计算效率。这种并行性使得多头注意力机制在处理长序列数据时更加高效。
注意力权重计算
在多头注意力机制中,每个头的注意力权重计算是通过缩放点积注意力(Scaled Dot-Product Attention)实现的。具体来说,计算查询和键的点积,经过缩放、加上偏置后,使用softmax函数得到注意力权重。
缩放点积注意力
为了避免过大的点积导致梯度消失问题,通常会对点积结果进行缩放。缩放因子通常是键向量维度的倒数或平方根:
归一化注意力权重
使用softmax函数对缩放后的得分进行归一化,得到每个元素的注意力权重,这些权重之和为1:
拼接与融合
多头注意力机制的最后步骤是将所有头的输出拼接在一起,然后通过一个最终的线性变换,以整合来自不同头的信息,得到最终的多头注意力输出。
拼接
将所有头的输出拼接在一起,形成一个长向量。这一步骤整合了不同子空间学到的信息,增强模型的表达能力。
融合
对拼接后的向量进行一个最终的线性变换,以整合来自不同头的信息,得到最终的多头注意力输出。这一步骤对应着:
为什么使用多个注意力头?
1、增加模型的学习能力和表达力:通过多个注意力头,模型可以学习到更丰富的上下文信息,每个头可能关注输入的不同特征,这些特征综合起来可以更全面地理解和处理输入序列。
2、提高模型性能:实验证明,多头注意力机制相较于单头注意力,往往能带来性能提升。这是因为模型可以通过并行处理和集成多个注意力头的结果,从不同角度捕捉数据的多样性,增强了模型对复杂序列任务的理解和泛化能力。
多头自注意力(Multi-Head Self-Attention)的常见误区
多头自注意力(Multi-Head Self-Attention)是多头注意力的一种,都属于注意力机制在深度学习中的应用,尤其是自然语言处理(NLP)领域的Transformer模型中。
自注意力就是Q=K=V?
多头自注意力与多头注意力的区别
应用场景
多头注意力不仅限于自注意力场景,它可以应用于任何形式的注意力机制,包括但不限于跨序列的注意力,比如在一个序列上对另一个序列的注意力(Cross-Attention)。
import torch
import torch.nn as nn
import math
class MultiHeadAttention(nn.Module):
def __init__(self, embed_size, num_heads):
super(MultiHeadAttention, self).__init__()
self.embed_size = embed_size
self.num_heads = num_heads
self.head_dim = embed_size // num_heads
assert (
self.head_dim * num_heads == embed_size
), "Embedding size needs to be divisible by num_heads"
# Linear layers for Q, K, V transformations
self.queries = nn.Linear(embed_size, embed_size)
self.keys = nn.Linear(embed_size, embed_size)
self.values = nn.Linear(embed_size, embed_size)
self.fc_out = nn.Linear(embed_size, embed_size)
def forward(self, query, key, value, mask=None):
N = query.shape[0]
query_len, key_len, value_len = query.shape[1], key.shape[1], value.shape[1]
# Transformations
queries = self.queries(query).view(N, query_len, self.num_heads, self.head_dim)
keys = self.keys(key).view(N, key_len, self.num_heads, self.head_dim)
values = self.values(value).view(N, value_len, self.num_heads, self.head_dim)
# Transpose for attention dot product: (N, num_heads, seq_len, head_dim)
queries = queries.transpose(1, 2)
keys = keys.transpose(1, 2)
values = values.transpose(1, 2)
# Scaled dot-product attention
energy = torch.einsum("nhqd,nhkd->nhqk", [queries, keys])
if mask is not None:
energy = energy.masked_fill(mask == 0, float("-1e20"))
attention = torch.softmax(energy / (self.embed_size ** (1/2)), dim=3)
out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).transpose(1, 2).contiguous()
out = out.view(N, query_len, self.embed_size)
return self.fc_out(out)
# Example usage for cross-sequence attention:
embed_size = 512
num_heads = 8
multi_head_attention = MultiHeadAttention(embed_size, num_heads)
# Dummy input tensors for demonstration
query = torch.rand((64, 10, embed_size)) # Target language sequence
key = torch.rand((64, 15, embed_size)) # Source language sequence
value = torch.rand((64, 15, embed_size)) # Source language sequence
output = multi_head_attention(query, key, value, mask=None)
print(output.shape) # Should print: torch.Size([64, 10, 512])
多头自注意力特指在同一序列内部,每个元素对其它所有元素的注意力机制进行了多头处理,用于捕获序列内元素间的复杂依赖关系,常见于Transformer的编码器和解码器中。
class MultiHeadSelfAttention(nn.Module):
def __init__(self, embed_size, num_heads):
super(MultiHeadSelfAttention, self).__init__()
self.embed_size = embed_size
self.num_heads = num_heads
self.head_dim = embed_size // num_heads
assert (
self.head_dim * num_heads == embed_size
), "Embedding size needs to be divisible by num_heads"
# Linear layers for Q, K, V transformations
self.queries = nn.Linear(embed_size, embed_size)
self.keys = nn.Linear(embed_size, embed_size)
self.values = nn.Linear(embed_size, embed_size)
self.fc_out = nn.Linear(embed_size, embed_size)
def forward(self, x, mask=None):
N = x.shape[0]
seq_len = x.shape[1]
# Transformations
queries = self.queries(x).view(N, seq_len, self.num_heads, self.head_dim)
keys = self.keys(x).view(N, seq_len, self.num_heads, self.head_dim)
values = self.values(x).view(N, seq_len, self.num_heads, self.head_dim)
# Transpose for attention dot product: (N, num_heads, seq_len, head_dim)
queries = queries.transpose(1, 2)
keys = keys.transpose(1, 2)
values = values.transpose(1, 2)
# Scaled dot-product attention
energy = torch.einsum("nhqd,nhkd->nhqk", [queries, keys])
if mask is not None:
energy = energy.masked_fill(mask == 0, float("-1e20"))
attention = torch.softmax(energy / (self.embed_size ** (1/2)), dim=3)
out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).transpose(1, 2).contiguous()
out = out.view(N, seq_len, self.embed_size)
return self.fc_out(out)
# Example usage for text comprehension:
embed_size = 512
num_heads = 8
multi_head_self_attention = MultiHeadSelfAttention(embed_size, num_heads)
# Dummy input tensor for demonstration
x = torch.rand((64, 20, embed_size)) # Input sequence (could be source language or target language)
output = multi_head_self_attention(x, mask=None)
print(output.shape) # Should print: torch.Size([64, 20, 512])
功能聚焦点
多头注意力可以用来同时考虑多种类型的关联性,无论是否是同一序列内的元素间相互作用。
多头自注意力特别强调的是序列自身的自参照特性,即序列的每一个位置都能查看整个序列并据此调整自身的表现形式。
凝练一下就是,多头注意力是一个通用术语,当应用于序列本身时,就成为多头自注意力。两者都是为了通过并行处理多个注意力视角来增强模型的表达能力和捕捉
多头自注意力计算过程
计算过程详细分的话,主要分为六个步骤:
输入向量
这是输入序列中第i个元素的向量表示,例如在自然语言处理中,它可以是一个词的嵌入向量。
线性变换
每个输入向量通过三组权重矩阵转换成查询(Query) 向量,键 (Key) 向量,和值 (Value) 向量。具体地,对于每个头 h:
这里是每个头特有的参数矩阵。
多头机制
如图所示,有两个头,因此对于每个输入向量,我们会得到两组向量。
注意力函数
每个头计算一个注意力得分,该得分决定了在计算输出时对每个值向量V的加权重要性。注意力得分通常使用缩放的点积注意力来计算:
其中是键向量的维度,用于缩放点积,以防止梯度消失。
头输出
每个头的输出是通过应用注意力函数得到的加权和:
输出连接与变换
最后,所有头的输出被连接在一起形成一个单一的长向量,然后通过另一个权重矩阵进行变换以产生最终的输出向量
这里的表示连接操作,是一个权重矩阵,用于将多头输出融合成最终的输出。
多头注意力及多头自注意力应用场景
多头注意力应用场景-机器翻译
假设你是一名国际商务顾问,经常需要处理来自不同国家的文件。今天你收到一份重要的合同,这份合同是用中文写的,而你需要将其准确无误地翻译成英文。为了确保翻译的质量,你决定使用一个基于多头注意力机制的机器翻译系统来辅助你完成这项任务。
在翻译过程中,这个系统不仅会逐字逐句地进行翻译,还会特别注意源语言(中文)句子中关键词汇、语法结构以及它们与目标语言(英文)对应部分的关系。比如,它会考虑“银行”这个词在不同的上下文中可能指代的是金融机构还是河岸,并据此选择最合适的翻译。
多头注意力的应用:
词汇层面:一个“头”可以专注于捕捉词汇层面的对应关系,确保每个词都被正确翻译。
语义层面:另一个“头”可能会更注重理解整体语义或上下文信息,从而帮助模型做出更明智的选择。
长距离依赖:还有些“头”能够识别并处理句子中相隔较远元素之间的联系,这对于理解复杂的句子结构至关重要。
下面是基于PyTorch框架实现的多头注意力机制的一个简化版本,特别适用于上述提到的机器翻译任务。在这个例子中,我们假定查询(query)、键(key)和值(value)分别来自目标语言和源语言序列。
import torch
import torch.nn as nn
import math
class MultiHeadAttention(nn.Module):
def __init__(self, embed_size, num_heads):
super(MultiHeadAttention, self).__init__()
self.embed_size = embed_size
self.num_heads = num_heads
self.head_dim = embed_size // num_heads
assert (
self.head_dim * num_heads == embed_size
), "Embedding size needs to be divisible by num_heads"
# Linear layers for Q, K, V transformations
self.queries = nn.Linear(embed_size, embed_size)
self.keys = nn.Linear(embed_size, embed_size)
self.values = nn.Linear(embed_size, embed_size)
self.fc_out = nn.Linear(embed_size, embed_size)
def forward(self, query, key, value, mask=None):
N = query.shape[0]
query_len, key_len, value_len = query.shape[1], key.shape[1], value.shape[1]
# Transformations
queries = self.queries(query).view(N, query_len, self.num_heads, self.head_dim)
keys = self.keys(key).view(N, key_len, self.num_heads, self.head_dim)
values = self.values(value).view(N, value_len, self.num_heads, self.head_dim)
# Transpose for attention dot product: (N, num_heads, seq_len, head_dim)
queries = queries.transpose(1, 2)
keys = keys.transpose(1, 2)
values = values.transpose(1, 2)
# Scaled dot-product attention
energy = torch.einsum("nhqd,nhkd->nhqk", [queries, keys])
if mask is not None:
energy = energy.masked_fill(mask == 0, float("-1e20"))
attention = torch.softmax(energy / (self.embed_size ** (1/2)), dim=3)
out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).transpose(1, 2).contiguous()
out = out.view(N, query_len, self.embed_size)
return self.fc_out(out)
# Example usage for machine translation:
embed_size = 512
num_heads = 8
multi_head_attention = MultiHeadAttention(embed_size, num_heads)
# Dummy input tensors for demonstration
# Assume we have a batch of 64 sentences, each with up to 10 tokens in the target language (English),
# and up to 15 tokens in the source language (Chinese).
query = torch.rand((64, 10, embed_size)) # Target language sequence
key = torch.rand((64, 15, embed_size)) # Source language sequence
value = torch.rand((64, 15, embed_size)) # Source language sequence
output = multi_head_attention(query, key, value, mask=None)
print(output.shape) # Should print: torch.Size([64, 10, 512])
query代表目标语言(如英文)的嵌入表示,而key和value则来自源语言(如中文)。通过多头注意力机制,模型可以在翻译过程中同时关注源语言的不同方面,从而生成更准确、流畅的目标语言文本。这种机制对于提高机器翻译的质量非常关键,尤其是在处理复杂句子结构或具有多种含义的词汇时。
多头自注意力应用场景-阅读理解
此处还是以开头的看小说为例展开哈。想象你正在阅读一本小说,小说中的人物关系错综复杂,情节发展跌宕起伏。为了更好地理解和享受这本小说,你需要同时关注多个层面的信息:角色之间的对话、情感变化、故事情节的发展以及背景设定等。在这种情况下,你的大脑会自动选择性地关注某些重要信息,并忽略其他不重要的信息,以便更有效地处理和理解文本内容。
多头自注意力的应用:
局部细节:一个“头”可以专注于短语或句子内部的单词之间关系,帮助理解具体的语法结构。
全局上下文:另一个“头”可能会着眼于整个段落甚至章节之间的联系,以把握故事的整体走向。
人物互动:还有些“头”专门用于分析人物对话及其背后的情感色彩,有助于构建更加立体的角色形象。
下面是一个基于PyTorch框架实现的多头自注意力机制的简化版本。在这个例子中,查询(query)、键(key)和值(value)都来自同一个输入序列,即源语言序列本身。这种设计允许模型在同一位置上同时考虑序列中所有其他位置的信息,从而提高其捕捉长距离依赖关系的能力。
import torch
import torch.nn as nn
import math
class MultiHeadSelfAttention(nn.Module):
def __init__(self, embed_size, num_heads):
super(MultiHeadSelfAttention, self).__init__()
self.embed_size = embed_size
self.num_heads = num_heads
self.head_dim = embed_size // num_heads
assert (
self.head_dim * num_heads == embed_size
), "Embedding size needs to be divisible by num_heads"
# Linear layers for Q, K, V transformations
self.queries = nn.Linear(embed_size, embed_size)
self.keys = nn.Linear(embed_size, embed_size)
self.values = nn.Linear(embed_size, embed_size)
self.fc_out = nn.Linear(embed_size, embed_size)
def forward(self, x, mask=None):
N = x.shape[0]
seq_len = x.shape[1]
# Transformations
queries = self.queries(x).view(N, seq_len, self.num_heads, self.head_dim)
keys = self.keys(x).view(N, seq_len, self.num_heads, self.head_dim)
values = self.values(x).view(N, seq_len, self.num_heads, self.head_dim)
# Transpose for attention dot product: (N, num_heads, seq_len, head_dim)
queries = queries.transpose(1, 2)
keys = keys.transpose(1, 2)
values = values.transpose(1, 2)
# Scaled dot-product attention
energy = torch.einsum("nhqd,nhkd->nhqk", [queries, keys])
if mask is not None:
energy = energy.masked_fill(mask == 0, float("-1e20"))
attention = torch.softmax(energy / (self.embed_size ** (1/2)), dim=3)
out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).transpose(1, 2).contiguous()
out = out.view(N, seq_len, self.embed_size)
return self.fc_out(out)
# Example usage for text comprehension:
embed_size = 512
num_heads = 8
multi_head_self_attention = MultiHeadSelfAttention(embed_size, num_heads)
# Dummy input tensor for demonstration
# Assume we have a batch of 64 sentences, each with up to 20 tokens.
x = torch.rand((64, 20, embed_size)) # Input sequence (could be source language or target language)
output = multi_head_self_attention(x, mask=None)
print(output.shape) # Should print: torch.Size([64, 20, 512])
在这个代码示例中,x代表输入序列(可以是源语言或目标语言),并且查询(query)、键(key)和值(value)都是从同一个输入序列中派生出来的。通过多头自注意力机制,模型可以在处理每个位置时同时考虑该位置与其他所有位置之间的关系,从而更全面地理解文本内容。这种机制对于处理复杂的自然语言任务特别有用,因为它能有效地捕捉到文本中的长距离依赖关系,这对于理解语言的复杂性和细微差别至关重要