表现优于ViT和DeiT,华为利用内外Transformer块构建新型视觉骨干模型TNT

简介: 华为诺亚实验室的研究者提出了一种新型视觉 Transformer 网络架构 Transformer in Transformer,它的表现优于谷歌的 ViT 和 Facebook 的 DeiT。论文提出了一个全新的 TNT 模块(Transformer iN Transformer),旨在通过内外两个 transformer 联合提取图像局部和全局特征。通过堆叠 TNT 模块,研究者搭建了全新的纯 Transformer 网络架构——TNT。值得注意的是,TNT 还暗合了 Geoffrey Hinton 最新提出的 part-whole hierarchies 思想。在 ImageNet 图像

微信图片_20211205095640.jpgTransformer 网络推动了诸多自然语言处理任务的进步,而近期 transformer 开始在计算机视觉领域崭露头角。例如,DETR 将目标检测视为一个直接集预测问题,并使用 transformer 编码器 - 解码器架构来解决它;IPT 利用 transformer 在单个模型中处理多个底层视觉任务。与现有主流 CNN 模型(如 ResNet)相比,这些基于 transformer 的模型在视觉任务上也显示出了良好的性能。

谷歌 ViT(Vision Transformer)模型是一个用于视觉任务的纯 transformer 经典技术方案。它将输入图片切分为若干个图像块(patch),然后将 patch 用向量来表示,用 transformer 来处理图像 patch 序列,最终的输出做图像识别。但是 ViT 的缺点也十分明显,它将图像切块输入 Transformer,图像块拉直成向量进行处理,因此,图像块内部结构信息被破坏,忽略了图像的特有性质。

微信图片_20211205095551.jpg

图 1:谷歌 ViT 网络架构。

在这篇论文中,来自华为诺亚实验室的研究者提出一种用于基于结构嵌套的 Transformer 结构,被称为 Transformer-iN-Transformer (TNT) 架构。同样地,TNT 将图像切块,构成 Patch 序列。不过,TNT 不把 Patch 拉直为向量,而是将 Patch 看作像素(组)的序列。

微信图片_20211205095554.jpg


论文链接:https://arxiv.org/pdf/2103.00112.pdf

具体而言,新提出的 TNT block 使用一个外 Transformer block 来对 patch 之间的关系进行建模,用一个内 Transformer block 来对像素之间的关系进行建模。通过 TNT 结构,研究者既保留了 patch 层面的信息提取,又做到了像素层面的信息提取,从而能够显著提升模型对局部结构的建模能力,提升模型的识别效果。

在 ImageNet 基准测试和下游任务上的实验均表明了该方法在精度和计算复杂度方面的优越性。例如, TNT-S 仅用 5.2B FLOPs 就达到了 81.3% 的 ImageNet top-1 正确率,这比计算量相近的 DeiT 高出了 1.5%。


方法


图像预处理


图像预处理主要是将 2D 图像转化为 transformer 能够处理的 1D 序列。这里将图像转化成 patch embedding 序列和 pixel embedding 序列。图像首先被均匀切分成若干个 patch,每个 patch 通过 im2col 操作转化成像素向量序列,像素向量通过线性层映射为 pixel embedding。而 patch embedding(包括一个 class token)是一组初始化为零的向量。具体地,对于一张图像,研究者将其均匀切分为 n 个 patch:

微信图片_20211205095559.jpg


其中是 patch 的尺寸。

Pixel embedding 生成:对于每个 patch,进一步通过 pytorch unfold 操作将其转化成 m 个像素向量,然后用一个全连接层将 m 个像素向量映射为 m 个 pixel embedding:微信图片_20211205095602.jpg


其中微信图片_20211205095606.jpg微信图片_20211205095609.jpg,c 是 pixel embedding 的长度。N 个 patch 就有 n 个 pixel embedding 组:微信图片_20211205095612.jpg


Patch embedding 生成:初始化 n+1 个 patch embedding 来存储模型的特征,它们都初始化为零:

微信图片_20211205095615.jpg


其中第一个 patch embedding 又叫 class token。

Position encoding:对每个 patch embedding 加一个 patch position encoding: 

微信图片_20211205095618.jpg

微信图片_20211205095628.jpg


对每个 pixel embedding 加一个 pixel position encoding:

微信图片_20211205095632.jpg

微信图片_20211205095635.jpg


两种 Position encoding 在训练过程中都是可学习的参数。

微信图片_20211205095640.jpg

图 2:位置编码。


Transformer in Transformer 架构


TNT 网络主要由若干个 TNT block 堆叠构成,这里首先介绍 TNT block。TNT block 有 2 个输入,一个是 pixel embedding,一个是 patch embedding。对应地, TNT block 包含 2 个标准的 transformer block。

如下图 3 所示,研究者只展示了一个 patch 对应的 TNT block,其他 patch 是一样的操作。首先,该 patch 对应的 m 个 pixel embedding 输入到内 transformer block 进行特征处理,输出处理过的 m 个 pixel embedding。Patch embedding 输入到外 transformer block 进行特征处理。其中,这 m 个 pixel embedding 拼接起来构成一个长向量,通过一个全连接层映射到 patch embedding 所在的空间,加到 patch embedding 上。最终,TNT block 输出处理过后的 pixel embedding 和 patch embedding。

微信图片_20211205095647.jpg

图 3:Transformer in Transformer 架构。


通过堆叠 L 个 TNT block,构成了 TNT 网络结构,如下表 1 所示,其中 depth 是 block 个数,#heads 是 Multi-head attention 的头个数。


微信图片_20211205095650.jpg

表 1:TNT 网络结构参数。


实验


ImageNet 实验


研究者在 ImageNet 2012 数据集上训练和验证 TNT 模型。从下表 2 可以看出,在纯 transformer 的模型中,TNT 优于所有其他的纯 transformer 模型。TNT-S 达到 81.3% 的 top-1 精度,比基线模型 DeiT-S 高 1.5%,这表明引入 TNT 框架有利于在 patch 中保留局部结构信息。通过添加 SE 模块,进一步改进 TNT-S 模型,得到 81.6% 的 top-1 精度。与 CNNs 相比,TNT 的性能优于广泛使用的 ResNet 和 RegNet。不过,所有基于 transformer 的模型仍然低于使用特殊 depthwise 卷积的 EfficientNet,因此如何使用纯 transformer 打败 EfficientNet 仍然是一个挑战。

微信图片_20211205095653.jpg

表 2:TNT 与其他 SOTA 模型在 ImageNet 数据集上的对比。


在精度和 FLOPS、参数量的 trade-off 上,TNT 同样优于纯 transformer 模型 DeiT 和 ViT,并超越了 ResNet 和 RegNet 代表的 CNN 模型。具体表现如下图 4 所示:

微信图片_20211205095656.jpg

图 4:TNT 与其他 SOTA 模型在精度、FLOPS 和参数量指标上的变化曲线。


特征图可视化


研究者将学习到的 DeiT 和 TNT 特征可视化,以进一步探究该方法的工作机制。为了更好地可视化,输入图像的大小被调整为 1024x1024。此外,根据空间位置对 patch embedding 进行重排,形成特征图。第 1、6 和 12 个 block 的特征图如下图 5(a) 所示,其中每个块随机抽取 12 个特征图。与 DeiT 相比,TNT 能更好地保留局部信息。

研究者还使用 t-SNE 对输出特征进行可视化(图 5(b))。由此可见,TNT 的特征比 DeiT 的特征更为多样,所包含的信息也更为丰富。这要归功于内部 transformer block 的引入,能够建模局部特征。

微信图片_20211205095659.jpg

图 5:DeiT 和 TNT 特征图可视化。


迁移学习实验


为了证明 TNT 具有很强的泛化能力,研究者在 ImageNet 上训练的 TNT-S、TNT-B 模型迁移到其他数据集。更具体地说,他们在 4 个图像分类数据集上评估 TNT 模型,包括 CIFAR-10、CIFAR-100、Oxford IIIT Pets 和 Oxford 102 Flowers。所有模型微调的图像分辨率为 384x384。

下表 3 对比了 TNT 与 ViT、DeiT 和其他网络的迁移学习结果。研究者发现,TNT 在大多数数据集上都优于 DeiT,这表明在获得更好的特征时,对像素级关系进行建模具有优越性。

微信图片_20211205095702.jpg

表 3:TNT 在下游任务的表现。


总结


该研究提出了一种用于视觉任务的 transformer in transformer(TNT)网络结构。TNT 将图像均匀分割为图像块序列,并将每个图像块视为像素序列。本文还提出了一种 TNT block,其中外 transformer block 用于处理 patch embedding,内 transformer block 用于建模像素嵌入之间的关系。在线性层投影后,将像素嵌入信息加入到图像块嵌入向量中。通过堆叠 TNT block,构建全新 TNT 架构。与传统的视觉 transformer(ViT)相比,TNT 能更好地保存和建模局部信息,用于视觉识别。在 ImageNet 和下游任务上的大量实验都证明了所提出的 TNT 架构的优越性。

相关文章
|
6月前
|
机器学习/深度学习 自然语言处理 Shell
【CaiT】如何才能使VIT网络往更深层发展
【CaiT】如何才能使VIT网络往更深层发展
71 0
|
10天前
|
机器学习/深度学习 计算机视觉 网络架构
【YOLO11改进 - C3k2融合】C3k2DWRSeg二次创新C3k2_DWR:扩张式残差分割网络,提高特征提取效率和多尺度信息获取能力,助力小目标检测
【YOLO11改进 - C3k2融合】C3k2DWRSeg二次创新C3k2_DWR:扩张式残差分割网络,提高特征提取效率和多尺度信息获取能力,助力小目DWRSeg是一种高效的实时语义分割网络,通过将多尺度特征提取分为区域残差化和语义残差化两步,提高了特征提取效率。它引入了Dilation-wise Residual (DWR) 和 Simple Inverted Residual (SIR) 模块,优化了不同网络阶段的感受野。在Cityscapes和CamVid数据集上的实验表明,DWRSeg在准确性和推理速度之间取得了最佳平衡,达到了72.7%的mIoU,每秒319.5帧。代码和模型已公开。
【YOLO11改进 - C3k2融合】C3k2DWRSeg二次创新C3k2_DWR:扩张式残差分割网络,提高特征提取效率和多尺度信息获取能力,助力小目标检测
|
25天前
|
机器学习/深度学习 编解码 负载均衡
MoH:融合混合专家机制的高效多头注意力模型及其在视觉语言任务中的应用
本文提出了一种名为混合头注意力(MoH)的新架构,旨在提高Transformer模型中注意力机制的效率。MoH通过动态注意力头路由机制,使每个token能够自适应选择合适的注意力头,从而在减少激活头数量的同时保持或提升模型性能。实验结果显示,MoH在图像分类、类条件图像生成和大语言模型等多个任务中均表现出色,尤其在减少计算资源消耗方面有显著优势。
40 1
|
29天前
|
机器学习/深度学习 算法 语音技术
超越传统模型:探讨门控循环单元(GRU)在语音识别领域的最新进展与挑战
【10月更文挑战第7天】随着人工智能技术的不断进步,语音识别已经从一个相对小众的研究领域发展成为日常生活中的常见技术。无论是智能手机上的语音助手,还是智能家居设备,甚至是自动字幕生成系统,都离不开高质量的语音识别技术的支持。在众多用于语音识别的技术中,基于深度学习的方法尤其是递归神经网络(RNNs)及其变体如长短期记忆网络(LSTMs)和门控循环单元(GRUs)已经成为了研究和应用的热点。
24 2
|
10天前
|
机器学习/深度学习 计算机视觉 网络架构
【YOLO11改进 - C3k2融合】C3k2融合DWRSeg二次创新C3k2_DWRSeg:扩张式残差分割网络,提高特征提取效率和多尺度信息获取能力,助力小目标检测
【YOLO11改进 - C3k2融合】C3k2融合DWRSDWRSeg是一种高效的实时语义分割网络,通过将多尺度特征提取方法分解为区域残差化和语义残差化两步,提高了多尺度信息获取的效率。网络设计了Dilation-wise Residual (DWR) 和 Simple Inverted Residual (SIR) 模块,分别用于高阶段和低阶段,以充分利用不同感受野的特征图。实验结果表明,DWRSeg在Cityscapes和CamVid数据集上表现出色,以每秒319.5帧的速度在NVIDIA GeForce GTX 1080 Ti上达到72.7%的mIoU,超越了现有方法。代码和模型已公开。
|
6月前
|
机器学习/深度学习 编解码 算法
YOLOv8改进 | 主干网络 | 增加网络结构增强小目标检测能力【独家创新——附结构图】
YOLOv8在小目标检测上存在挑战,因卷积导致信息丢失。本文教程将原网络结构替换为更适合小目标检测的backbone,并提供结构图。通过讲解原理和手把手教学,指导如何修改代码,提供完整代码实现,适合新手实践。文章探讨了大特征图对小目标检测的重要性,如细节保留、定位精度、特征丰富度和上下文信息,并介绍了FPN等方法。YOLOv8流程包括预处理、特征提取、融合和检测。修改后的网络结构增加了上采样和concatenate步骤,以利用更大特征图检测小目标。完整代码和修改后的结构图可在文中链接获取。
|
5月前
|
机器学习/深度学习 自然语言处理 物联网
ICML 2024:脱离LoRA架构,训练参数大幅减少,新型傅立叶微调来了
【6月更文挑战第4天】在ICML 2024上,研究团队提出了傅立叶变换微调(FourierFT),一种减少训练参数的新方法,替代了依赖LoRA的微调。FourierFT通过学习权重变化矩阵的稀疏频谱系数,实现了LFMs的高效微调。在多项任务上,FourierFT展示出与LoRA相当或更优的性能,参数量却大幅减少,如在LLaMA2-7B模型上,仅需0.064M参数,对比LoRA的33.5M。广泛实验验证了其在NLP和CV任务上的效果,但未来还需探索其适用性和泛化能力。论文链接:[arxiv.org/abs/2405.03003](https://arxiv.org/abs/2405.03003)
102 0
|
6月前
|
机器学习/深度学习 人工智能
大模型架构将迎来除 Transformer 之外的突破
大模型架构将迎来除 Transformer 之外的突破
116 2
大模型架构将迎来除 Transformer 之外的突破
|
6月前
|
机器学习/深度学习 存储 编解码
沈春华团队最新 | SegViT v2对SegViT进行全面升级,让基于ViT的分割模型更轻更强
沈春华团队最新 | SegViT v2对SegViT进行全面升级,让基于ViT的分割模型更轻更强
110 0
|
机器学习/深度学习 人工智能 编解码
一文梳理视觉Transformer架构进展:与CNN相比,ViT赢在哪儿?(1)
一文梳理视觉Transformer架构进展:与CNN相比,ViT赢在哪儿?
600 0
下一篇
无影云桌面