PyTorch快餐教程2019 (2) - Multi-Head Attention

简介: # PyTorch快餐教程2019 (2) - Multi-Head Attention 上一节我们为了让一个完整的语言模型跑起来,可能给大家带来的学习负担过重了。没关系,我们这一节开始来还上节没讲清楚的债。 还记得我们上节提到的两个Attention吗? ![两种Attention机制](https://upload-images.jianshu.io/upload_images/

PyTorch快餐教程2019 (2) - Multi-Head Attention

上一节我们为了让一个完整的语言模型跑起来,可能给大家带来的学习负担过重了。没关系,我们这一节开始来还上节没讲清楚的债。

还记得我们上节提到的两个Attention吗?
两种Attention机制

上节我们给大家一个印象,现在我们正式开始介绍其原理。

Scaled Dot-Product Attention

首先说Scaled Dot-Product Attention,其计算公式为:
$
Attention(Q,K,V)=softmax(frac{QK^T}{sqrt{d_k}})V
$

Q乘以K的转置,再除以$d_k$的平方根进行缩放,经过一个可选的Mask,经过softmax之后,再与V相乘。
用代码实现如下:

def attention(query, key, value, mask=None, dropout=None):
    "Compute 'Scaled Dot Product Attention'"
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) \
             / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    p_attn = F.softmax(scores, dim = -1)
    if dropout is not None:
        p_attn = dropout(p_attn)
    return torch.matmul(p_attn, value), p_attn

Multi-Head Attention

有了缩放点积注意力机制之后,我们就可以来定义多头注意力。

$
MultiHead(Q,K,V)=concat(head_1,...,head_n)W^O
$
其中,$head_i=Attention(QW_i^Q,KW_i^K,VW_i^V)$
这个Attention是我们上面介绍的Scaled Dot-Product Attention.

这些W都是要训练的参数矩阵。
$
W_i^Qin mathbb{R}^{d_{model} times d_k},
W_i^Kinmathbb{R}^{d_{model} times d_k}, W_i^Vinmathbb{R}^{d_{model} times d_v}, W_oinmathbb{R}^{hd_v times d_{model}}
$
h是multi-head中的head数。在《Attention is all you need》论文中,h取值为8。
$d_k=d_v=d_{model}/h=64$
这样我们需要的参数就是d_model和h.

大家看公式有点要晕的节奏,别怕,我们上代码:

class MultiHeadedAttention(nn.Module):
    def __init__(self, h, d_model, dropout=0.1):
        "初始化时指定头数h和模型维度d_model"
        super(MultiHeadedAttention, self).__init__()
        # 二者是一定整除的
        assert d_model % h == 0
        # 按照文中的简化,我们让d_v与d_k相等
        self.d_k = d_model // h
        self.h = h
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)

其中,clones是复制几个一模一样的模型的函数,其定义如下:

def clones(module, N):
    "生成n个相同的层"
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])

Attention的逻辑主要分为4步。第一步是计算一下mask。

    def forward(self, query, key, value, mask=None):
        "实现多头注意力模型"
        if mask is not None:
            # Same mask applied to all h heads.
            mask = mask.unsqueeze(1)
        nbatches = query.size(0)

第二步是将这一批次的数据进行变形 d_model => h x d_k

        query, key, value = \
            [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
             for l, x in zip(self.linears, (query, key, value))]

第三步,针对所有变量计算scaled dot product attention

        x, self.attn = attention(query, key, value, mask=mask, 
                                 dropout=self.dropout)

最后,将attention计算结果串联在一起,其实对张量进行一次变形:

        x = x.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d_k)
        return self.linears[-1](x)

再看一种写法巩固一下

上面引用的代码来自:http://nlp.seas.harvard.edu/2018/04/03/attention.html

为了加深印象,我们再看另一种写法。
这个的命名更偏工程,d_model叫做hid_dim,h叫做n_heads,但是意思是一回事。

class SelfAttention(nn.Module):
    def __init__(self, hid_dim, n_heads, dropout, device):
        super().__init__()

        self.hid_dim = hid_dim
        self.n_heads = n_heads
        
        # d_model // h 仍然是要能整除,换个名字仍然意义不变
        assert hid_dim % n_heads == 0

        self.w_q = nn.Linear(hid_dim, hid_dim)
        self.w_k = nn.Linear(hid_dim, hid_dim)
        self.w_v = nn.Linear(hid_dim, hid_dim)

        self.fc = nn.Linear(hid_dim, hid_dim)

        self.do = nn.Dropout(dropout)

        self.scale = torch.sqrt(torch.FloatTensor([hid_dim // n_heads])).to(device)

下面是处理数据的过程:

    def forward(self, query, key, value, mask=None):

# Q,K,V计算与变形:

        bsz = query.shape[0]

        Q = self.w_q(query)
        K = self.w_k(key)
        V = self.w_v(value)

        Q = Q.view(bsz, -1, self.n_heads, self.hid_dim //
                   self.n_heads).permute(0, 2, 1, 3)
        K = K.view(bsz, -1, self.n_heads, self.hid_dim //
                   self.n_heads).permute(0, 2, 1, 3)
        V = V.view(bsz, -1, self.n_heads, self.hid_dim //
                   self.n_heads).permute(0, 2, 1, 3)

# Q, K相乘除以scale,这是计算scaled dot product attention的第一步

        energy = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale

# 如果没有mask,就生成一个

        if mask is not None:
            energy = energy.masked_fill(mask == 0, -1e10)

# 然后对Q,K相乘的结果计算softmax加上dropout,这是计算scaled dot product attention的第二步:

        attention = self.do(torch.softmax(energy, dim=-1))

# 第三步,attention结果与V相乘

        x = torch.matmul(attention, V)

# 最后将多头排列好,就是multi-head attention的结果了

        x = x.permute(0, 2, 1, 3).contiguous()

        x = x.view(bsz, -1, self.n_heads * (self.hid_dim // self.n_heads))

        x = self.fc(x)

        return x

第二种实现取自:https://github.com/bentrevett/pytorch-seq2seq

目录
相关文章
|
2月前
|
存储 物联网 PyTorch
基于PyTorch的大语言模型微调指南:Torchtune完整教程与代码示例
**Torchtune**是由PyTorch团队开发的一个专门用于LLM微调的库。它旨在简化LLM的微调流程,提供了一系列高级API和预置的最佳实践
231 59
基于PyTorch的大语言模型微调指南:Torchtune完整教程与代码示例
|
2月前
|
并行计算 监控 搜索推荐
使用 PyTorch-BigGraph 构建和部署大规模图嵌入的完整教程
当处理大规模图数据时,复杂性难以避免。PyTorch-BigGraph (PBG) 是一款专为此设计的工具,能够高效处理数十亿节点和边的图数据。PBG通过多GPU或节点无缝扩展,利用高效的分区技术,生成准确的嵌入表示,适用于社交网络、推荐系统和知识图谱等领域。本文详细介绍PBG的设置、训练和优化方法,涵盖环境配置、数据准备、模型训练、性能优化和实际应用案例,帮助读者高效处理大规模图数据。
72 5
|
5月前
|
并行计算 Ubuntu PyTorch
Ubuntu下CUDA、Conda、Pytorch联合教程
本文是一份Ubuntu系统下安装和配置CUDA、Conda和Pytorch的教程,涵盖了查看显卡驱动、下载安装CUDA、添加环境变量、卸载CUDA、Anaconda的下载安装、环境管理以及Pytorch的安装和验证等步骤。
981 1
Ubuntu下CUDA、Conda、Pytorch联合教程
|
8月前
|
PyTorch 算法框架/工具 异构计算
PyTorch 2.2 中文官方教程(十八)(1)
PyTorch 2.2 中文官方教程(十八)
255 2
PyTorch 2.2 中文官方教程(十八)(1)
|
8月前
|
并行计算 PyTorch 算法框架/工具
PyTorch 2.2 中文官方教程(十七)(4)
PyTorch 2.2 中文官方教程(十七)
253 2
PyTorch 2.2 中文官方教程(十七)(4)
|
8月前
|
PyTorch 算法框架/工具 异构计算
PyTorch 2.2 中文官方教程(十九)(1)
PyTorch 2.2 中文官方教程(十九)
154 1
PyTorch 2.2 中文官方教程(十九)(1)
|
8月前
|
机器学习/深度学习 PyTorch 算法框架/工具
PyTorch 2.2 中文官方教程(十八)(3)
PyTorch 2.2 中文官方教程(十八)
107 1
PyTorch 2.2 中文官方教程(十八)(3)
|
8月前
|
API PyTorch 算法框架/工具
PyTorch 2.2 中文官方教程(十八)(2)
PyTorch 2.2 中文官方教程(十八)
218 1
PyTorch 2.2 中文官方教程(十八)(2)
|
8月前
|
异构计算 PyTorch 算法框架/工具
PyTorch 2.2 中文官方教程(十七)(3)
PyTorch 2.2 中文官方教程(十七)
108 1
PyTorch 2.2 中文官方教程(十七)(3)
|
8月前
|
PyTorch 算法框架/工具 机器学习/深度学习
PyTorch 2.2 中文官方教程(十七)(2)
PyTorch 2.2 中文官方教程(十七)
173 1
PyTorch 2.2 中文官方教程(十七)(2)