Bert Pytorch 源码分析:二、注意力层

简介: Bert Pytorch 源码分析:二、注意力层
# 注意力机制的具体模块
# 兼容单头和多头
class Attention(nn.Module):
    """
    Compute 'Scaled Dot Product Attention
    """
  # QKV 尺寸都是 BS * ML * ES
  # (或者多头情况下是 BS * HC * ML * HS,最后两维之外的维度不重要)
  # 从输入计算 QKV 的过程可以统一处理,不必放到每个头里面
    def forward(self, query, key, value, mask=None, dropout=None):
    # 将每个批量的 Q 和 K.T 做矩阵乘法,再除以√ES,
    # 得到相关性矩阵 S,尺寸为 BS * ML * ML
        scores = torch.matmul(query, key.transpose(-2, -1)) \
                 / math.sqrt(query.size(-1))
    # 如果存在掩码则使用它
    # 将 scores 的 mask == 0 的位置上的元素改为 -1e9
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        # 将 S 转换到概率空间,同时对其最后一维归一化
        p_attn = F.softmax(scores, dim=-1)
    # 如果存在 dropout 则使用
        if dropout is not None:
            p_attn = dropout(p_attn)
    # 最后将 S 与 V 相乘得到输出
        return torch.matmul(p_attn, value), p_attn
# 多头注意力就是包含很多(HC)个头,但是每个头的尺寸(HS)变为原来的 1/HC
# 把 qkv 切成小段分给每个头做运算,将结果拼起来作为整个层的输出
class MultiHeadedAttention(nn.Module):
    """
    Take in model size and number of heads.
    """
  # h 是头数(HC)
  # d_model 是嵌入向量大小(ES)
    def __init__(self, h, d_model, dropout=0.1):
        super().__init__()
    # 判断 ES 是否能被 HC 整除,以便结果能拼接回去
        assert d_model % h == 0
    # d_k 是每个头的大小 HS = ES // HC
        self.d_k = d_model // h
        self.h = h
    # 创建输入转换为QKV的权重矩阵,Wq, Wk, Wv,尺寸均为 ES * ES
        self.linear_layers = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(3)])
    # 输出应该还乘一个权重矩阵,Wo,尺寸也是 ES * ES
        self.output_linear = nn.Linear(d_model, d_model)
    # 创建执行注意力机制的具体模块
        self.attention = Attention()
    # 创建 droput 层
        self.dropout = nn.Dropout(p=dropout)
    def forward(self, query, key, value, mask=None):
    # 获取批量大小(BS)
        batch_size = query.size(0)
    '''
        query, key, value = [
      l(x).view(batch_size, -1, self.h, self.d_k).transpose(1, 2)
        for l, x in zip(self.linear_layers, (query, key, value))
    ]
    '''
    # 将 QKV 的每个与其相应权重矩阵 Wq, Wk, Wv 相乘
    lq, lk, lv = self.linear_layers
    query, key, value = lq(query), lk(key), lv(value) 
    # 然后将他们转型为 BS * ML * HC * HS
    # 也就是将最后一个维度按头部数量分割成小的向量
    query, key, value = [
      x.view(batch_size, -1, self.h, self.d_k)
      for x in (query, key, value)
    ]
    # 然后交换 1 和 2 维,变成 BS * HC * ML  * HS
    # 这样每个头的 QKV 是内存连续的,便于矩阵相乘
    query, key, value = [
      x.transpose(1, 2)
      for x in (query, key, value)
    ]
        # 对每个头应用注意力机制,输出尺寸不变
        x, attn = self.attention(query, key, value, mask=mask, dropout=self.dropout)
        # 交换 1 和 2 维恢复原状,然后把每个头的输出相连接,尺寸变为 BS * ML * ES
        x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.h * self.d_k)
    # 执行最后的矩阵相乘
        return self.output_linear(x)

缩写表

  • BS:批量大小,即一批数据中样本大小,训练集和测试集可能不同,那就是TBS和VBS
  • ES:嵌入大小,嵌入向量空间的维数,也是注意力层的隐藏单元数量,GPT 中一般是 768
  • ML:输入序列最大长度,一般是512或者1024,不够需要用<pad>填充
  • HC:头部的数量,需要能够整除ES,因为每个头的输出拼接起来才是层的输出
  • HS:头部大小,等于ES // HC
  • VS:词汇表大小,也就是词的种类数量

尺寸备注

  • 嵌入层的矩阵尺寸应该是VS * ES
  • 注意力层的输入尺寸是BS * ML * ES
  • 输出以及 Q K V 和输入形状相同
  • 每个头的 QKV 尺寸为BS * ML * HS
  • 权重矩阵尺寸为ES * ES
  • 相关矩阵 S 尺寸为BS * ML * ML
相关文章
|
8月前
|
机器学习/深度学习 关系型数据库 MySQL
大模型中常用的注意力机制GQA详解以及Pytorch代码实现
GQA是一种结合MQA和MHA优点的注意力机制,旨在保持MQA的速度并提供MHA的精度。它将查询头分成组,每组共享键和值。通过Pytorch和einops库,可以简洁实现这一概念。GQA在保持高效性的同时接近MHA的性能,是高负载系统优化的有力工具。相关论文和非官方Pytorch实现可进一步探究。
965 4
|
8月前
|
PyTorch 算法框架/工具
Bert Pytorch 源码分析:五、模型架构简图 REV1
Bert Pytorch 源码分析:五、模型架构简图 REV1
129 0
|
8月前
|
PyTorch 算法框架/工具
Bert Pytorch 源码分析:四、编解码器
Bert Pytorch 源码分析:四、编解码器
98 0
|
8月前
|
PyTorch 算法框架/工具
Bert Pytorch 源码分析:五、模型架构简图
Bert Pytorch 源码分析:五、模型架构简图
82 0
|
15天前
|
机器学习/深度学习 数据可视化 PyTorch
PyTorch FlexAttention技术实践:基于BlockMask实现因果注意力与变长序列处理
本文介绍了如何使用PyTorch 2.5及以上版本中的FlexAttention和BlockMask功能,实现因果注意力机制与填充输入的处理。通过attention-gym仓库安装相关工具,并详细展示了MultiheadFlexAttention类的实现,包括前向传播函数、因果掩码和填充掩码的生成方法。实验设置部分演示了如何组合这两种掩码并应用于多头注意力模块,最终通过可视化工具验证了实现的正确性。该方法适用于处理变长序列和屏蔽未来信息的任务。
54 17
|
4月前
|
机器学习/深度学习 PyTorch 算法框架/工具
CNN中的注意力机制综合指南:从理论到Pytorch代码实现
注意力机制已成为深度学习模型的关键组件,尤其在卷积神经网络(CNN)中发挥了重要作用。通过使模型关注输入数据中最相关的部分,注意力机制显著提升了CNN在图像分类、目标检测和语义分割等任务中的表现。本文将详细介绍CNN中的注意力机制,包括其基本概念、不同类型(如通道注意力、空间注意力和混合注意力)以及实际实现方法。此外,还将探讨注意力机制在多个计算机视觉任务中的应用效果及其面临的挑战。无论是图像分类还是医学图像分析,注意力机制都能显著提升模型性能,并在不断发展的深度学习领域中扮演重要角色。
163 10
|
3月前
|
机器学习/深度学习 自然语言处理 数据建模
三种Transformer模型中的注意力机制介绍及Pytorch实现:从自注意力到因果自注意力
本文深入探讨了Transformer模型中的三种关键注意力机制:自注意力、交叉注意力和因果自注意力,这些机制是GPT-4、Llama等大型语言模型的核心。文章不仅讲解了理论概念,还通过Python和PyTorch从零开始实现这些机制,帮助读者深入理解其内部工作原理。自注意力机制通过整合上下文信息增强了输入嵌入,多头注意力则通过多个并行的注意力头捕捉不同类型的依赖关系。交叉注意力则允许模型在两个不同输入序列间传递信息,适用于机器翻译和图像描述等任务。因果自注意力确保模型在生成文本时仅考虑先前的上下文,适用于解码器风格的模型。通过本文的详细解析和代码实现,读者可以全面掌握这些机制的应用潜力。
216 3
三种Transformer模型中的注意力机制介绍及Pytorch实现:从自注意力到因果自注意力
|
3月前
|
机器学习/深度学习 PyTorch 算法框架/工具
聊一聊计算机视觉中常用的注意力机制以及Pytorch代码实现
本文介绍了几种常用的计算机视觉注意力机制及其PyTorch实现,包括SENet、CBAM、BAM、ECA-Net、SA-Net、Polarized Self-Attention、Spatial Group-wise Enhance和Coordinate Attention等,每种方法都附有详细的网络结构说明和实验结果分析。通过这些注意力机制的应用,可以有效提升模型在目标检测任务上的性能。此外,作者还提供了实验数据集的基本情况及baseline模型的选择与实验结果,方便读者理解和复现。
165 0
聊一聊计算机视觉中常用的注意力机制以及Pytorch代码实现
|
6月前
|
机器学习/深度学习 数据采集 自然语言处理
注意力机制中三种掩码技术详解和Pytorch实现
**注意力机制中的掩码在深度学习中至关重要,如Transformer模型所用。掩码类型包括:填充掩码(忽略填充数据)、序列掩码(控制信息流)和前瞻掩码(自回归模型防止窥视未来信息)。通过创建不同掩码,如上三角矩阵,模型能正确处理变长序列并保持序列依赖性。在注意力计算中,掩码修改得分,确保模型学习的有效性。这些技术在现代NLP和序列任务中是核心组件。**
324 12
|
8月前
|
机器学习/深度学习 自然语言处理 PyTorch
Pytorch图像处理注意力机制SENet CBAM ECA模块解读
注意力机制最初是为了解决自然语言处理(NLP)任务中的问题而提出的,它使得模型能够在处理序列数据时动态地关注不同位置的信息。随后,注意力机制被引入到图像处理任务中,为深度学习模型提供了更加灵活和有效的信息提取能力。注意力机制的核心思想是根据输入数据的不同部分,动态地调整模型的注意力,从而更加关注对当前任务有用的信息。
436 0

热门文章

最新文章