Pytorch CIFAR10图像分类 Swin Transformer篇(一)

简介: Pytorch CIFAR10图像分类 Swin Transformer篇(一)

再次介绍一下我的专栏,很适合大家初入深度学习或者是Pytorch和Keras,希望这能够帮助初学深度学习的同学一个入门Pytorch或者Keras的项目和在这之中更加了解Pytorch&Keras和各个图像分类的模型

他有比较清晰的可视化结构和架构,除此之外,我是用jupyter写的,所以说在文章整体架构可以说是非常清晰,可以帮助你快速学习到各个模块的知识,而不是通过 python脚本Q一行一行的看,这样的方式是符合初学者的。

除此之外,如果你需要变成脚本形式,也是很简单的。

这里贴一下汇总篇: 汇总篇

4.定义网络(Swin Transformer)

自从Transformer在NLPQ任务上取得突破性的进展之后,业内一直尝试着把Transformer用于在CV领域。之前的若干尝试,例如iGPT,ViT都是将Transformer用在了图像分类领域,ViT我们之前也有介绍过在图像分类上的方法,但目前这些方法都有两个非常严峻的问题

1.受限于图像的矩阵性质,一个能表达信息的图片往往至少需要几百个像素点,而建模这种几百个长序列的数据恰恰是Transformer的天生缺陷

2.目前的基于Transformer框架更多的是用来进行图像分类,理论上来进解决检测问题应该也比较容易,但是对实例分割这种密集预测的场景Transformer并不擅长解决。

而这篇微软亚洲研究院提出的的Swin Transformer解决了这两个问题,并且在分类,检测,分割任务上都取得了SOTA的效果,同时获得了ICCV2021的best paper。Swin Transformer的最大贡献是提出了一个可以广泛应用到所有计算机视觉领域的backbone,并且大多数在CNN网络中常见的超参数在Swin Transformer中也是可以人工调整的,例如可以调整的网络块数,每一块的层数,输入图像的大小等等。该网络架构的设计非常巧妙,是一个非常精彩的将Transformer应用到图像领域的结构,值得每个AI领域的人前去学习。

fcf37b4f40ea672d7c83e6795c31b307.png

实际上的,Swin Transformer 是在 Vision Transformer 的基础上使用滑动窗口 (shifted windows,SW)进行改造而来。它将 Vision Transformer 中固定大小的采样快按照层次分成不同大小的块(Windows),每一个块之间的信息并不共通、独立运算从而大大提高了计算效率。从 SwinTransformer 的架构图中可以看出其与 Vision Transformer 的结构很相似,不同的点在于其采用的Transformer Block 是由两个连续的 Swin Transformer Block 构成的,这两个 Block 块与 VisionTransformer中的 Block 块大致相同,只是将 Multi-head Self-Attention (MSA) 替换成了含有不同大小Windows 的 W-MSA与SW-MAS (具有滑动窗.SW),通过 Windows 和 Shifted Windows 的Multi-head Self-Attention 提高运算效率并最终提高分类的准确率。

Swin Transformer整体架构

从 Swin Transformer 网络的整体框架图我们可以看到,首先将输入图像1输入到 Patch Partition 进行一个分块操作,这一部分其实是和VIT是一样的,然后送入 Linear Embedding 模块中进行通道数channel 的调整。最后通过 stage 1,2,3 和 4 的特征提取和下采样得到最终的预测结果,值得注意的是每经过一个 stage,size 就会缩小为原来的 1/2,channel 就会扩大为原来的 2倍与resnet 网络类似。每个 stage 中的 Swin Transformer Block 都由两个相连的分别以 W-MSA和 SW-MSA为基础的 Transformer Block 构成,通过 Window 和 Shifted Window 机制提高计算性能。最右边两个图为Swim Transformer的每个块结构,类似于ViT的块结构,其核心修改的地方就是将原本的MSA变为WMSA。


6e5a7ae0e7962d9c91edd9d6fba8d0af.png

Patch Merging

Patch Merging 模块将 尺寸为 H X W 的 Patch 块首先进行拼接并在 channel 维度上进行concatenate 构成了 H/2 x W/2  4C 的特征图,然后再进行 Layer Normalization 操作进行正则化,然后通过一个 Linear 层形成了一个 H/2 x W/2  2C ,完成了特征图的下采样过程。其中size 缩小为原来的 1/2,channel 扩大为原来的 2倍。

这里也可以举个例子,假设我们的输入是4x4大小单通道的特征图,首先我们会隔一个取一个小Patch组合在一起,最后4x4的特征图会行成4个2x2的特征图。接下来将4个Patch进行拼接,现在得到的特征图尺寸为2x2x4。然后会经过一个LN层,这里当然会改变特征图的值,我改变了一些颜色象征性的表示了一下,LN层后特征图尺寸不会改变,仍为2x2x4。最后会经过一个全连接层,将特征图尺寸由2x2x4变为2x2x2。

576454ee9c0d8f547f0fa8a410791143.png

W-MSA

ViT 网络中的 MSA通过 self-Attention 使得每一个像素点都可以和其他的像素点进行内积从而得到所有像素点的信息,从而获得丰富的全局信息。但是每个像素点都需要和其他像素点进行信息交换计算量巨大,网络的执行效率低下。因此 Swin-T 将 MSA 分个多个固定的 Windows 构成了 W-MSA,每个 Windows 之间的像素点只能与该 Windows 中的其他像素点进行内积从而获得信息,这样便大幅的减小了计算量,提高了网络的运算效率。

48b6643afd70239ce1596efe96e2eb24.png

MSA和 W-MAS 的计算量如下所示

image.png

其中 h、w 和 C 分别代表特征图的高度、!宽度和深度,M 代表每个 Windows 的大小。

假定h =w =64,M =4,C = 96

采用MSA模块的计算复杂度为 4 x 64 x 64  962 2  (64  64)2  96 = 3372220416采用W-MSA模块的计算复杂度为 4 x 64 x 6  962 -2  4  64  64  96 = 163577856可以计算出 W-MSA 节省了3208642560 FLOPs。

SW-MSA

虽然 W-MSA 通过划分 Windows 的方法减少了计算量,但是由于各个 Windows 之间无法进行信息的交互,因此可以看作其“感受野”缩小,无法得到较全局准确的信息从而影响网络的准确度。为了实现不同窗口之间的信息交互,我们可以将窗口滑动,偏移窗口使其包合不同的像素点,然后再进行 W.MSA计算,将两次 W-MSA计算的结果进行连接便可结合两个不同的 Windows 中的像素点所包含的信息从而实现 Windows 之间的信息共通。

偏移窗口的 W-MSA构成了 SW-MSA 模块,其 Windows 在 W-MSA的基础上向右下角偏移了两个Patch,形成了9个大小不一的块,然后使用 cyclic shift 将这9 个块平移拼接成与 W-MSA 对应的4个大小相同的块,再通过 masked MSA 对这 4 个拼接块进行对应的模板计算完成信息的提取,最后通过 reverse cyclic shift 将信息数据 patch 平移回原先的位置。通过 SW-MSA机制完成了偏移窗口的象素点的 MSA计算并实现了不同窗口间像素点的信息交流,从而间接扩大了网络的“感受野”,提高了信息的利用效率


012841c2c86adc89f015e92f748858ca.png

我们仔细说明一下这一部分,上面可能比较抽象,这一块我认为也是Swin Transformer的核心。可以发现通过将窗口进行偏移后,就到达了窗口与窗口之间的相互通信。虽然已经能达到窗口与窗口之间的通信,但是原来的特征图只有4个窗口,经过移动窗口后,得到了9个窗口,窗口的数量有所增加并且9个窗口的大小也不是完全相同,这就导致计算难度增加。因此,作者又提出而了Efficient batchcomputation for shifted configuration,一种更加高效的计算方法。如下图所示:


1f3e56c0623906b74215752f5db459c1.png

先将012区域移动到最下方,再将360区域移动到最右方,此时移动完后,4是一个单独的窗口;将5和3合并成一个窗口;7和1合并成一个窗口; 8、6、2和0合并成一个窗口。这样又和原来一样是4个4x4的窗口了,所以能够保证计算量是一样的。这里肯定有人会想,把不同的区域合并在一起(比如5和3)进行MSA,这信息不就乱窜了吗? 是的,为了防止这个问题,在实际计算中使用的是maskedMSA即带蒙板mask的MSA,这样就能够通过设置Mask来隔绝不同区域的信息了

Relative position bias

Swin-T 网络还在 Attention 计算中引入了相对位置偏置机制去提高网络的整体准确率表现,通过引入相对位置偏置机制,其准确度能够提高 1.2%~2.3% 不等。以 2x2 的特征图为例,首先我们需要对特征图的各个块进行绝对位置的编号,得到每个块的绝对位置索引。然后对每个块计算其与其他块之间的相对位置,计算方法为该块的绝对位置索引减去其他块的绝对位置索引,可以得到每个块的相对位置索引矩阵。将每个块的相对位置索引矩阵展平连接构成了整个特征图的相对位置索引矩阵,具体的计算流程如下图所示。

fcc0f73225b1f8d757b61a8926d1b245.png

Swin-T并不是使用二维元组形式的相对位置索引矩阵,而是通过将二维元组形式的相对位置索引映射为一维的相对位置偏置(Relative position bias) 来构成相应的矩阵,具体的映射方法如下: 1.将对应的相对位置行索引和列索引分别加上 M-1,2.将行索引乘以 2M-1,3.将行索引和列索引相加,再使用对应的相对位置偏置表(Relative position bias table) 进行映射即可得到最终的相对位置偏置B。具体的计算流程如下所示

8e8f49bfe5aade3f2e8d3043843ea27a.png

如果这一部分看的比较迷糊,也可以简单看看直接从相对位置进行映射,我们就可以得到相对位置偏置



a292fd49e810881e1a4f3426e0ef88d0.png

加入了相对位置偏置机制的 Attention 计算公式如下所示

image.png

其中 B 即为上述计算得到的相对位置偏置

Swin Transformer 网络结构

下表是原论文中给出的关于不同Swin Transformer的配置,T(Tiny),S(Small),B(Base),L(Large),其中:

win.sz.7x7表示使用的窗 (Windows) 的大小

dim表示feature map的channel深度 (或者说token的向量长度)

head表示多头注意力模块中head的个数

37ee2eed74e2830e747dbdceb03456e8.png

首先我们还是得判断是否可以利用GPU,因为GPU的速度可能会比我们用CPU的速度快20-50倍左右,特别是对卷积神经网络来说,更是提升特别明显。

device = 'cuda' if torch.cuda.is_available() else 'cpu'

Patch Embedding

在输入进Block前,我们需要将图片切成一个个patch,然后嵌入向量。

具体做法是对原始图片裁成一个个 patch_size * patch_size 的窗口大小,然后进行嵌入。

这里可以通过二维卷积层,将stride,kernelsize设置为patch size大小。设定输出通道来确定嵌入向量的大小。最后将H,W维度展开,并移动到第一维度

class PatchEmbed(nn.Module):
    def __init__(self,
                 patch_size=4,
                 in_c=3,
                 embed_dim=96,
                 norm_layer=None):
        super(PatchEmbed, self).__init__()
        self.patch_size = patch_size
        self.in_c = in_c
        self.embed_dim = embed_dim
        self.proj = nn.Conv2d(
            in_c, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
    def forward(self, x):
        # 如果图片的H,W不是patch_size的整数倍,需要padding
        _, _, h, w = x.shape
        if (h % self.patch_size != 0) or (w % self.patch_size != 0):
            x = F.pad(x, (0, self.patch_size - w % self.patch_size,
                          0, self.patch_size - h % self.patch_size,
                          0, 0))
        x = self.proj(x)
        _, _, h, w = x.shape
        # (b,c,h,w) -> (b,c,hw) -> (b,hw,c)
        x = x.flatten(2).transpose(1, 2)
        x = self.norm(x)
        return x, h, w

Patch Merging

该模块的作用是在每个Stage开始前做降采样,用于缩小分辨率,调整通道数 进而形成层次化的设计,同时也能节省一定运算量。


在CNN中,则是在每个Stage开始前用 stride=2 的卷积/池化层来降低分辨率

每次降采样是两倍,因此在行方向和列方向上,间隔2选取元素

然后拼接在一起作为一整个张量,最后展开。此时通道维度会变成原先的4倍 (因为H.W各缩小2倍),此时再通过一个全连接层再调整通道维度为原来的两倍

class PatchMerging(nn.Module):
    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super(PatchMerging, self).__init__()
        self.dim = dim
        self.reduction = nn.Linear(4*dim, 2*dim, bias=False)
        self.norm = norm_layer(4*dim)
    def forward(self, x, h, w):
        # (b,hw,c)
        b, l, c = x.shape
        # (b,hw,c) -> (b,h,w,c)
        x = x.view(b, h, w, c)
        # 如果h,w不是2的整数倍,需要padding
        if (h % 2 == 1) or (w % 2 == 1):
            x = F.pad(x, (0, 0, 0, w % 2, 0, h % 2))
        # (b,h/2,w/2,c)
        x0 = x[:, 0::2, 0::2, :]
        x1 = x[:, 1::2, 0::2, :]
        x2 = x[:, 0::2, 1::2, :]
        x3 = x[:, 1::2, 1::2, :]
        # (b,h/2,w/2,c)*4 -> (b,h/2,w/2,4c)
        x = torch.cat([x0, x1, x2, x3], -1)
        # (b,hw/4,4c)
        x = x.view(b, -1, 4*c)
        x = self.norm(x)
        # (b,hw/4,4c) -> (b,hw/4,2c)
        x = self.reduction(x)
        return x

下面是一个示意图 (输入张量N=1,H=W=8,C=1,不包含最后的全连接层调整)

aec06139eadcd42f39f4831002142f35.png

Window Partition/Reverse

window partition 函数是用于对张量划分窗口,指定窗口大小。将原本的张量从 N H W c,划分成num_windows*B,window_size, window_size, C,其中 num windows = H*W/window size*window size),即窗口的个数。而window reverse 函数则是对应的逆过程。这两个函数会在后面的 window Attention 用到。

def window_partition(x, window_size):
    """
    将feature map按照window_size分割成windows
    """
    b, h, w, c = x.shape
    # (b,h,w,c) -> (b,h//m,m,w//m,m,c)
    x = x.view(b, h//window_size, window_size, w//window_size, window_size, c)
    # (b,h//m,m,w//m,m,c) -> (b,h//m,w//m,m,m,c)
    # -> (b,h//m*w//m,m,m,c) -> (b*n_windows,m,m,c)
    windows = (x
               .permute(0, 1, 3, 2, 4, 5)
               .contiguous()
               .view(-1, window_size, window_size, c))
    return windows
def window_reverse(x,window_size,h,w):
    """
    将分割后的windows还原成feature map
    """
    b = int(x.shape[0] / (h*w/window_size/window_size))
    # (b,h//m,w//m,m,m,c)
    x = x.view(b,h//window_size,w//window_size,window_size,window_size,-1)
    # (b,h//m,w//m,m,m,c) -> (b,h//m,m,w//m,m,c)
    # -> (b,h,w,c)
    x = x.permute(0,1,3,2,4,5).contiguous().view(b,h,w,-1)
    return x
class MLP(nn.Module):
    def __init__(self,
                 in_features,
                 hid_features=None,
                 out_features=None,
                 dropout=0.):
        super(MLP, self).__init__()
        out_features = out_features or in_features
        hid_features = hid_features or in_features
        self.fc1 = nn.Linear(in_features, hid_features)
        self.act = nn.GELU()
        self.drop1 = nn.Dropout(dropout)
        self.fc2 = nn.Linear(hid_features, out_features)
        self.drop2 = nn.Dropout(dropout)
    def forward(self, x):
        x = self.drop1(self.act(self.fc1(x)))
        x = self.drop2(self.fc2(x))
        return x

Window Attention

这是这篇文章的关键。传统的Transformer都是基于全局来计算注意力的,因此计算复杂度十分高。而Swin Transformer则将注意力的计算限制在每个窗口内,进而减少了计算量。

我们先简单看下公式

image.png

主要区别是在原始计算Attention的公式中的Q,K时加入了相对位置编码。后续实验有证明相对位置编码的加入提升了模型性能。

class WindowAttention(nn.Module):
    def __init__(self,
                 dim,
                 window_size,
                 n_heads,
                 qkv_bias=True,
                 attn_dropout=0.,
                 proj_dropout=0.):
        super(WindowAttention, self).__init__()
        self.dim = dim
        self.window_size = window_size
        self.n_heads = n_heads
        self.scale = (dim // n_heads) ** -.5
        # ((2m-1)*(2m-1),n_heads)
        # 相对位置参数表长为(2m-1)*(2m-1)
        # 行索引和列索引各有2m-1种可能,故其排列组合有(2m-1)*(2m-1)种可能
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2*window_size - 1) * (2*window_size - 1),
                        n_heads))
        # 构建窗口的绝对位置索引
        # 以window_size=2为例
        # coord_h = coord_w = [0,1]
        # meshgrid([0,1], [0,1])
        # -> [[0,0], [[0,1]
        #     [1,1]], [0,1]]
        # -> [[0,0,1,1],
        #     [0,1,0,1]]
        # (m,)
        coord_h = torch.arange(self.window_size)
        coord_w = torch.arange(self.window_size)
        # (m,)*2 -> (m,m)*2 -> (2,m,m)
        coords = torch.stack(torch.meshgrid([coord_h, coord_w]))
        # (2,m*m)
        coord_flatten = torch.flatten(coords, 1)
        # 构建窗口的相对位置索引
        # (2,m*m,1) - (2,1,m*m) -> (2,m*m,m*m)
        # 以coord_flatten为
        # [[0,0,1,1]
        #  [0,1,0,1]]为例
        # 对于第一个元素[0,0,1,1]
        # [[0],[0],[1],[1]] - [[0,0,1,1]]
        # -> [[0,0,0,0] - [[0,0,1,1] = [[0,0,-1,-1]
        #     [0,0,0,0]    [0,0,1,1]    [0,0,-1,-1]
        #     [1,1,1,1]    [0,0,1,1]    [1,1, 0, 0]
        #     [1,1,1,1]]   [0,0,1,1]]   [1,1, 0, 0]]
        # 相当于每个元素的h减去每个元素的h
        # 例如,第一行[0,0,0,0] - [0,0,1,1] -> [0,0,-1,-1]
        # 即为元素(0,0)相对(0,0)(0,1)(1,0)(1,1)为列(h)方向的差
        # 第二个元素即为每个元素的w减去每个元素的w
        # 于是得到窗口内每个元素相对每个元素高和宽的差
        # 例如relative_coords[0,1,2]
        # 即为窗口的第1个像素(0,1)和第2个像素(1,0)在列(h)方向的差
        relative_coords = coord_flatten[:, :, None] - coord_flatten[:, None, :]
        # (m*m,m*m,2)
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
        # 论文中提到的,将二维相对位置索引转为一维的过程
        # 1. 行列都加上m-1
        # 2. 行乘以2m-1
        # 3. 行列相加
        relative_coords[:, :, 0] += self.window_size - 1
        relative_coords[:, :, 1] += self.window_size - 1
        relative_coords[:, :, 0] *= 2 * self.window_size - 1
        # (m*m,m*m,2) -> (m*m,m*m)
        relative_pos_idx = relative_coords.sum(-1)
        self.register_buffer('relative_pos_idx', relative_pos_idx)
        self.qkv = nn.Linear(dim, dim*3, bias=qkv_bias)
        self.attn_dropout = nn.Dropout(attn_dropout)
        self.proj = nn.Linear(dim, dim)
        self.proj_dropout = nn.Dropout(proj_dropout)
        nn.init.trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)
    def forward(self, x, mask):
        b, n, c = x.shape
        # (b*n_windows,m*m,total_embed_dim)
        # -> (b*n_windows,m*m,3*total_embed_dim)
        # -> (b*n_windows,m*m,3,n_heads,embed_dim_per_head)
        # -> (3,b*n_windows,n_heads,m*m,embed_dim_per_head)
        qkv = (self.qkv(x)
               .reshape(b, n, 3, self.n_heads, c//self.n_heads)
               .permute(2, 0, 3, 1, 4))
        # (b*n_windows,n_heads,m*m,embed_dim_per_head)
        q, k, v = qkv.unbind(0)
        q = q * self.scale
        # (b*n_windows,n_heads,m*m,m*m)
        attn = (q @ k.transpose(-2, -1))
        # (m*m*m*m,n_heads)
        # -> (m*m,m*m,n_heads)
        # -> (n_heads,m*m,m*m)
        # -> (b*n_windows,n_heads,m*m,m*m) + (1,n_heads,m*m,m*m)
        # -> (b*n_windows,n_heads,m*m,m*m)
        relative_pos_bias = (self.relative_position_bias_table[self.relative_pos_idx.view(-1)]
                             .view(self.window_size*self.window_size, self.window_size*self.window_size, -1))
        relative_pos_bias = relative_pos_bias.permute(2, 0, 1).contiguous()
        attn = attn + relative_pos_bias.unsqueeze(0)
        if mask is not None:
            # mask: (n_windows,m*m,m*m)
            nw = mask.shape[0]
            # (b*n_windows,n_heads,m*m,m*m)
            # -> (b,n_windows,n_heads,m*m,m*m)
            # + (1,n_windows,1,m*m,m*m)
            # -> (b,n_windows,n_heads,m*m,m*m)
            attn = (attn.view(b//nw, nw, self.n_heads, n, n)
                    + mask.unsqueeze(1).unsqueeze(0))
            # (b,n_windows,n_heads,m*m,m*m)
            # -> (b*n_windows,n_heads,m*m,m*m)
            attn = attn.view(-1, self.n_heads, n, n)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)
        attn = self.attn_dropout(attn)
        # (b*n_windows,n_heads,m*m,embed_dim_per_head)
        # -> (b*n_windows,m*m,n_heads,embed_dim_per_head)
        # -> (b*n_windows,m*m,total_embed_dim)
        x = (attn @ v).transpose(1, 2).reshape(b, n, c)
        x = self.proj(x)
        x = self.proj_dropout(x)
        return x

首先输入张量形状为 numwindows*B,window_size * window size,C然后经过 self.qkv 这个全连接层后,进行reshape,调整轴的顺序,得到形状为 3,numwindows*B,num heads,window_size*window_size, c//num heads ,并分配给q,k,v。根据公式,我们对g 乘以一个 scale 缩放系数,然后与 (为了满足矩阵乘要求,需要将最后两维度调换) 进行相乘。得到形状为 (numwindows*B,num heads,window size*window size,vindow size*window size)的 attn 张量

之前我们针对位置编码设置了个形状为(2*window size-1*2*window size-1,numHeads)的可学习变量。我们用计算得到的相对编码位置索引 self.relative_position_index 选取,得到形状为(window_size*window_size,window_size*window size,numHeads)的编码,加到attn张量上暂不考虑mask的情况,剩下就是跟transformer一样的softmax,dropout,与v矩阵乘,再经过-吴全连接层和dropout

Pytorch CIFAR10图像分类 Swin Transformer篇(二):https://developer.aliyun.com/article/1410618

相关文章
|
3月前
|
机器学习/深度学习 自然语言处理 PyTorch
Transformer自回归关键技术:掩码注意力原理与PyTorch完整实现
掩码注意力是生成模型的核心,通过上三角掩码限制模型仅关注当前及之前token,确保自回归因果性。相比BERT的双向注意力,它实现单向生成,是GPT等模型逐词预测的关键机制,核心仅需一步`masked_fill_`操作。
343 0
Transformer自回归关键技术:掩码注意力原理与PyTorch完整实现
|
3月前
|
机器学习/深度学习 人工智能 自然语言处理
编码器-解码器架构详解:Transformer如何在PyTorch中工作
本文深入解析Transformer架构,结合论文与PyTorch源码,详解编码器、解码器、位置编码及多头注意力机制的设计原理与实现细节,助你掌握大模型核心基础。建议点赞收藏,干货满满。
939 3
|
2月前
|
机器学习/深度学习 自然语言处理 监控
23_Transformer架构详解:从原理到PyTorch实现
Transformer架构自2017年Google发表的论文《Attention Is All You Need》中提出以来,彻底改变了深度学习特别是自然语言处理领域的格局。在短短几年内,Transformer已成为几乎所有现代大型语言模型(LLM)的基础架构,包括BERT、GPT系列、T5等革命性模型。与传统的RNN和LSTM相比,Transformer通过自注意力机制实现了并行化训练,极大提高了模型的训练效率和性能。
|
7月前
|
算法 PyTorch 算法框架/工具
PyTorch 实现FCN网络用于图像语义分割
本文详细讲解了在昇腾平台上使用PyTorch实现FCN(Fully Convolutional Networks)网络在VOC2012数据集上的训练过程。内容涵盖FCN的创新点分析、网络架构解析、代码实现以及端到端训练流程。重点包括全卷积结构替换全连接层、多尺度特征融合、跳跃连接和反卷积操作等技术细节。通过定义VOCSegDataset类处理数据集,构建FCN8s模型并完成训练与测试。实验结果展示了模型在图像分割任务中的应用效果,同时提供了内存使用优化的参考。
|
7月前
|
算法 PyTorch 算法框架/工具
昇腾910-PyTorch 实现 Vggnet图像分类
本实验基于昇腾平台,使用PyTorch实现Vggnet模型对CIFAR10数据集进行图像分类。内容涵盖Vggnet模型创新点(小卷积核堆叠、深层网络结构)、网络架构剖析及代码实战分析。通过定义`blockVGG`函数构建卷积块,实现VGG11网络,并结合数据预处理、训练与测试模块完成分类任务。实验展示了深度学习中增加网络深度对性能提升的重要性。
|
11月前
|
机器学习/深度学习 算法 PyTorch
昇腾910-PyTorch 实现 ResNet50图像分类
本实验基于PyTorch,在昇腾平台上使用ResNet50对CIFAR10数据集进行图像分类训练。内容涵盖ResNet50的网络架构、残差模块分析及训练代码详解。通过端到端的实战讲解,帮助读者理解如何在深度学习中应用ResNet50模型,并实现高效的图像分类任务。实验包括数据预处理、模型搭建、训练与测试等环节,旨在提升模型的准确率和训练效率。
543 54
|
11月前
|
机器学习/深度学习 算法 PyTorch
PyTorch 实现MobileNetV1用于图像分类
本实验基于PyTorch和昇腾平台,详细讲解了如何使用MobileNetV1模型对CIFAR10数据集进行图像分类。内容涵盖MobileNetV1的特点、网络架构剖析(尤其是深度可分离卷积)、代码实现及训练过程。通过该实验,读者可以掌握轻量级CNN模型在移动端或嵌入式设备中的应用,并了解其在资源受限环境下的高效表现。实验包括数据预处理、模型训练与测试等环节,帮助用户快速上手并优化模型性能。
411 53
|
11月前
|
机器学习/深度学习 算法 PyTorch
昇腾910-PyTorch 实现 GoogleNet图像分类
本实验基于PyTorch在昇腾平台上实现GoogleNet模型,针对CIFAR-10数据集进行图像分类。内容涵盖GoogleNet的创新点(如Inception模块、1x1卷积、全局平均池化等)、网络架构解析及代码实战分析。通过详细讲解模型搭建、数据预处理、训练与测试过程,帮助读者掌握如何使用经典CNN模型进行高效图像分类。实验中还介绍了辅助分类器、梯度传播优化等技术细节,并提供了完整的训练和测试代码示例。
|
11月前
|
机器学习/深度学习 算法 PyTorch
昇腾910-PyTorch 实现 Alexnet图像分类
本文介绍了在昇腾平台上使用PyTorch实现AlexNet对CIFAR-10数据集进行图像分类的实战。内容涵盖AlexNet的创新点、网络架构解析及代码实现,包括ReLU激活函数、Dropout、重叠最大池化等技术的应用。实验中详细展示了如何构建模型、加载数据集、定义训练和测试模块,并通过60个epoch的训练验证模型性能。
|
机器学习/深度学习 人工智能 PyTorch
Transformer模型变长序列优化:解析PyTorch上的FlashAttention2与xFormers
本文探讨了Transformer模型中变长输入序列的优化策略,旨在解决深度学习中常见的计算效率问题。文章首先介绍了批处理变长输入的技术挑战,特别是填充方法导致的资源浪费。随后,提出了多种优化技术,包括动态填充、PyTorch NestedTensors、FlashAttention2和XFormers的memory_efficient_attention。这些技术通过减少冗余计算、优化内存管理和改进计算模式,显著提升了模型的性能。实验结果显示,使用FlashAttention2和无填充策略的组合可以将步骤时间减少至323毫秒,相比未优化版本提升了约2.5倍。
579 3
Transformer模型变长序列优化:解析PyTorch上的FlashAttention2与xFormers

热门文章

最新文章

推荐镜像

更多