【YOLOv8改进】HAT(Hybrid Attention Transformer,)混合注意力机制 (论文笔记+引入代码)

简介: YOLO目标检测专栏介绍了YOLO系列的改进方法和实战应用,包括卷积、主干网络、注意力机制和检测头的创新。提出的Hybrid Attention Transformer (HAT)结合通道注意力和窗口自注意力,激活更多像素以提升图像超分辨率效果。通过交叉窗口信息聚合和同任务预训练策略,HAT优化了Transformer在低级视觉任务中的性能。实验显示,HAT在图像超分辨率任务上显著优于现有方法。模型结构包含浅层和深层特征提取以及图像重建阶段。此外,提供了HAT模型的PyTorch实现代码。更多详细配置和任务说明可参考相关链接。

YOLO目标检测创新改进与实战案例专栏

专栏目录: YOLO有效改进系列及项目实战目录 包含卷积,主干 注意力,检测头等创新机制 以及 各种目标检测分割项目实战案例

专栏链接: YOLO基础解析+创新改进+实战案例

摘要

基于Transformer的方法在低级视觉任务中表现出色,例如图像超分辨率。然而,通过归因分析,我们发现这些网络只能利用输入信息的有限空间范围。这表明Transformer在现有网络中的潜力尚未完全发挥。为了激活更多的输入像素以获得更好的重建效果,我们提出了一种新颖的混合注意力Transformer(Hybrid Attention Transformer, HAT)。它结合了通道注意力和基于窗口的自注意力机制,从而利用了它们能够利用全局统计信息和强大的局部拟合能力的互补优势。此外,为了更好地聚合跨窗口信息,我们引入了一个重叠交叉注意模块,以增强相邻窗口特征之间的交互。在训练阶段,我们还采用了同任务预训练策略,以进一步挖掘模型的潜力。大量实验表明了所提模块的有效性,我们进一步扩大了模型规模,证明了该任务的性能可以大幅提高。我们的方法整体上显著优于最先进的方法,超过了1dB。

创新点

  1. 更多像素的激活:通过结合不同的注意力机制,HAT能够激活更多的输入像素,这在图像超分辨率领域尤为重要,因为它直接关系到重建图像的细节和质量。

  2. 交叉窗口信息的有效聚合:通过重叠交叉注意力模块,HAT模型能够更有效地聚合跨窗口的信息,避免了传统Transformer模型中窗口间信息隔离的问题。

  3. 针对图像超分辨率优化的预训练策略:HAT采用的同任务预训练策略针对性强,能够更有效地利用大规模数据预训练的优势,提高模型在特定超分辨率任务上的表现。

    HAT模型整体架构:

  4. 浅层特征提取(Shallow Feature Extraction)
    输入图像首先通过一个浅层特征提取模块,该模块使用卷积操作提取初始的低层次特征。

  5. 深层特征提取(Deep Feature Extraction)
    提取的特征输入到多个Residual Hybrid Attention Groups (RHAG)中进行深层特征提取。每个RHAG由多个Hybrid Attention Blocks (HAB)和一个Overlapping Cross-Attention Block (OCAB)组成。

  6. 图像重建(Image Reconstruction)
    最后,经过深层特征提取的特征通过一个图像重建模块,将高层次特征转化为输出的超分辨率图像。

yolov8 引入

class HAT(nn.Module):
    r"""混合注意力变换器 (Hybrid Attention Transformer)
        该PyTorch实现基于 `Activating More Pixels in Image Super-Resolution Transformer`。
        部分代码基于SwinIR。
    参数:
        img_size (int | tuple(int)): 输入图像大小。默认值64
        patch_size (int | tuple(int)): Patch大小。默认值1
        in_chans (int): 输入图像通道数。默认值3
        embed_dim (int): Patch嵌入维度。默认值96
        depths (tuple(int)): 每个Swin Transformer层的深度。
        num_heads (tuple(int)): 不同层的注意力头数量。
        window_size (int): 窗口大小。默认值7
        mlp_ratio (float): MLP隐藏层维度与嵌入维度的比例。默认值4
        qkv_bias (bool): 如果为True,为查询、键和值添加可学习的偏差。默认值True
        qk_scale (float): 如果设置,覆盖head_dim ** -0.5的默认qk缩放比例。默认值None
        drop_rate (float): Dropout率。默认值0
        attn_drop_rate (float): 注意力dropout率。默认值0
        drop_path_rate (float): 随机深度率。默认值0.1
        norm_layer (nn.Module): 规范化层。默认值nn.LayerNorm。
        ape (bool): 如果为True,为Patch嵌入添加绝对位置嵌入。默认值False
        patch_norm (bool): 如果为True,在Patch嵌入后添加规范化。默认值True
        use_checkpoint (bool): 是否使用checkpointing来节省内存。默认值False
        upscale: 上采样因子。用于图像SR的2/3/4/8,1用于降噪和压缩伪影减少
        img_range: 图像范围。1或255。
        upsampler: 重建模块。'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None
        resi_connection: 残差连接前的卷积块。'1conv'/'3conv'
    """

    def __init__(self,
                 in_chans=3,
                 img_size=64,
                 patch_size=1,
                 embed_dim=96,
                 depths=(6, 6, 6, 6),
                 num_heads=(6, 6, 6, 6),
                 window_size=7,
                 compress_ratio=3,
                 squeeze_factor=30,
                 conv_scale=0.01,
                 overlap_ratio=0.5,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm,
                 ape=False,
                 patch_norm=True,
                 use_checkpoint=False,
                 upscale=2,
                 img_range=1.,
                 upsampler='',
                 resi_connection='1conv',
                 **kwargs):
        super(HAT, self).__init__()

        self.window_size = window_size
        self.shift_size = window_size // 2
        self.overlap_ratio = overlap_ratio

        num_in_ch = in_chans
        num_out_ch = in_chans
        num_feat = 64
        self.img_range = img_range
        if in_chans == 3:
            rgb_mean = (0.4488, 0.4371, 0.4040)
            self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
        else:
            self.mean = torch.zeros(1, 1, 1, 1)
        self.upscale = upscale
        self.upsampler = upsampler

        # 相对位置索引
        relative_position_index_SA = self.calculate_rpi_sa()
        relative_position_index_OCA = self.calculate_rpi_oca()
        self.register_buffer('relative_position_index_SA', relative_position_index_SA)
        self.register_buffer('relative_position_index_OCA', relative_position_index_OCA)

        # ------------------------- 1,浅层特征提取 ------------------------- #
        self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)

        # ------------------------- 2,深层特征提取 ------------------------- #
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        self.num_features = embed_dim
        self.mlp_ratio = mlp_ratio

        # 将图像分割为非重叠patch
        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=embed_dim,
            embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)
        num_patches = self.patch_embed.num_patches
        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution

        # 将非重叠的patch合并为图像
        self.patch_unembed = PatchUnEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=embed_dim,
            embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)

        # 绝对位置嵌入
        if self.ape:
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
            trunc_normal_(self.absolute_pos_embed, std=.02)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # 随机深度
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # 随机深度衰减规则

        # 构建残差混合注意力组 (RHAG)
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = RHAG(
                dim=embed_dim,
                input_resolution=(patches_resolution[0], patches_resolution[1]),
                depth=depths[i_layer],
                num_heads=num_heads[i_layer],
                window_size=window_size,
                compress_ratio=compress_ratio,
                squeeze_factor=squeeze_factor,
                conv_scale=conv_scale,
                overlap_ratio=overlap_ratio,
                mlp_ratio=self.mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],  # 对SR结果无影响
                norm_layer=norm_layer,
                downsample=None,
                use_checkpoint=use_checkpoint,
                img_size=img_size,
                patch_size=patch_size,
                resi_connection=resi_connection)
            self.layers.append(layer)
        self.norm = norm_layer(self.num_features)

        # 构建深层特征提取中的最后一个卷积层
        if resi_connection == '1conv':
            self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
        elif resi_connection == 'identity':
            self.conv_after_body = nn.Identity()

        # ------------------------- 3,高质量图像重建 ------------------------- #
        if self.upsampler == 'pixelshuffle':
            # 用于经典SR
            self.conv_before_upsample = nn.Sequential(
                nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True))
            self.upsample = Upsample(upscale, num_feat)
            self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)

        self.apply(self._init_weights)

task与yaml配置

详见:https://blog.csdn.net/shangyanaf/article/details/139142532

相关文章
|
5月前
|
机器学习/深度学习 编解码 计算机视觉
【YOLOv10改进-注意力机制】HAT(Hybrid Attention Transformer,)混合注意力机制
YOLOv10专栏介绍了一种名为HAT的新方法,旨在改善Transformer在图像超分辨率中的表现。HAT结合通道和窗口注意力,激活更多像素并增强跨窗口信息交互。亮点包括:1) 更多像素激活,2) 有效跨窗口信息聚合,3) 任务特定的预训练策略。HAT模型包含浅层特征提取、深层特征提取和图像重建阶段。提供的代码片段展示了HAT类的定义,参数包括不同层的深度、注意力头数量、窗口大小等。欲了解更多详情和配置,请参考给定链接。
|
机器学习/深度学习 编解码 文字识别
Hybrid Multiple Attention Network for Semantic Segmentation in Aerial Images(二)
Hybrid Multiple Attention Network for Semantic Segmentation in Aerial Images
206 0
Hybrid Multiple Attention Network for Semantic Segmentation in Aerial Images(二)
|
4月前
|
SQL 网络协议 数据库连接
【Azure 应用服务】遇见“无法创建hybrid connection for App Service”的解决办法
【Azure 应用服务】遇见“无法创建hybrid connection for App Service”的解决办法
|
编解码 移动开发 JavaScript
开发Hybrid App的技术选型
开发Hybrid App的技术选型
198 0
|
移动开发 前端开发 Java
移动端开发必备知识-Hybrid App
开发必备知识-Hybrid App
656 0
|
移动开发 小程序 JavaScript
Hybrid app本地开发如何调用JSBridge
Hybrid app本地开发如何调用JSBridge
191 0
Hybrid app本地开发如何调用JSBridge
|
Web App开发 XML 移动开发
Hybrid App 应用 开发中 9 个必备知识点复习(WebView / 调试 等) 下
Hybrid App 应用 开发中 9 个必备知识点复习(WebView / 调试 等) 下
531 0
|
存储 移动开发 缓存
Hybrid App 应用 开发中 9 个必备知识点复习(WebView / 调试 等) 上
Hybrid App 应用 开发中 9 个必备知识点复习(WebView / 调试 等) 上
378 0