表现优于ViT和DeiT,华为利用内外Transformer块构建新型视觉骨干模型TNT

简介: 华为诺亚实验室的研究者提出了一种新型视觉 Transformer 网络架构 Transformer in Transformer,它的表现优于谷歌的 ViT 和 Facebook 的 DeiT。论文提出了一个全新的 TNT 模块(Transformer iN Transformer),旨在通过内外两个 transformer 联合提取图像局部和全局特征。通过堆叠 TNT 模块,研究者搭建了全新的纯 Transformer 网络架构——TNT。值得注意的是,TNT 还暗合了 Geoffrey Hinton 最新提出的 part-whole hierarchies 思想。在 ImageNet 图像

微信图片_20211205095640.jpgTransformer 网络推动了诸多自然语言处理任务的进步,而近期 transformer 开始在计算机视觉领域崭露头角。例如,DETR 将目标检测视为一个直接集预测问题,并使用 transformer 编码器 - 解码器架构来解决它;IPT 利用 transformer 在单个模型中处理多个底层视觉任务。与现有主流 CNN 模型(如 ResNet)相比,这些基于 transformer 的模型在视觉任务上也显示出了良好的性能。

谷歌 ViT(Vision Transformer)模型是一个用于视觉任务的纯 transformer 经典技术方案。它将输入图片切分为若干个图像块(patch),然后将 patch 用向量来表示,用 transformer 来处理图像 patch 序列,最终的输出做图像识别。但是 ViT 的缺点也十分明显,它将图像切块输入 Transformer,图像块拉直成向量进行处理,因此,图像块内部结构信息被破坏,忽略了图像的特有性质。

微信图片_20211205095551.jpg

图 1:谷歌 ViT 网络架构。

在这篇论文中,来自华为诺亚实验室的研究者提出一种用于基于结构嵌套的 Transformer 结构,被称为 Transformer-iN-Transformer (TNT) 架构。同样地,TNT 将图像切块,构成 Patch 序列。不过,TNT 不把 Patch 拉直为向量,而是将 Patch 看作像素(组)的序列。

微信图片_20211205095554.jpg


论文链接:https://arxiv.org/pdf/2103.00112.pdf

具体而言,新提出的 TNT block 使用一个外 Transformer block 来对 patch 之间的关系进行建模,用一个内 Transformer block 来对像素之间的关系进行建模。通过 TNT 结构,研究者既保留了 patch 层面的信息提取,又做到了像素层面的信息提取,从而能够显著提升模型对局部结构的建模能力,提升模型的识别效果。

在 ImageNet 基准测试和下游任务上的实验均表明了该方法在精度和计算复杂度方面的优越性。例如, TNT-S 仅用 5.2B FLOPs 就达到了 81.3% 的 ImageNet top-1 正确率,这比计算量相近的 DeiT 高出了 1.5%。


方法


图像预处理


图像预处理主要是将 2D 图像转化为 transformer 能够处理的 1D 序列。这里将图像转化成 patch embedding 序列和 pixel embedding 序列。图像首先被均匀切分成若干个 patch,每个 patch 通过 im2col 操作转化成像素向量序列,像素向量通过线性层映射为 pixel embedding。而 patch embedding(包括一个 class token)是一组初始化为零的向量。具体地,对于一张图像,研究者将其均匀切分为 n 个 patch:

微信图片_20211205095559.jpg


其中是 patch 的尺寸。

Pixel embedding 生成:对于每个 patch,进一步通过 pytorch unfold 操作将其转化成 m 个像素向量,然后用一个全连接层将 m 个像素向量映射为 m 个 pixel embedding:微信图片_20211205095602.jpg


其中微信图片_20211205095606.jpg微信图片_20211205095609.jpg,c 是 pixel embedding 的长度。N 个 patch 就有 n 个 pixel embedding 组:微信图片_20211205095612.jpg


Patch embedding 生成:初始化 n+1 个 patch embedding 来存储模型的特征,它们都初始化为零:

微信图片_20211205095615.jpg


其中第一个 patch embedding 又叫 class token。

Position encoding:对每个 patch embedding 加一个 patch position encoding: 

微信图片_20211205095618.jpg

微信图片_20211205095628.jpg


对每个 pixel embedding 加一个 pixel position encoding:

微信图片_20211205095632.jpg

微信图片_20211205095635.jpg


两种 Position encoding 在训练过程中都是可学习的参数。

微信图片_20211205095640.jpg

图 2:位置编码。


Transformer in Transformer 架构


TNT 网络主要由若干个 TNT block 堆叠构成,这里首先介绍 TNT block。TNT block 有 2 个输入,一个是 pixel embedding,一个是 patch embedding。对应地, TNT block 包含 2 个标准的 transformer block。

如下图 3 所示,研究者只展示了一个 patch 对应的 TNT block,其他 patch 是一样的操作。首先,该 patch 对应的 m 个 pixel embedding 输入到内 transformer block 进行特征处理,输出处理过的 m 个 pixel embedding。Patch embedding 输入到外 transformer block 进行特征处理。其中,这 m 个 pixel embedding 拼接起来构成一个长向量,通过一个全连接层映射到 patch embedding 所在的空间,加到 patch embedding 上。最终,TNT block 输出处理过后的 pixel embedding 和 patch embedding。

微信图片_20211205095647.jpg

图 3:Transformer in Transformer 架构。


通过堆叠 L 个 TNT block,构成了 TNT 网络结构,如下表 1 所示,其中 depth 是 block 个数,#heads 是 Multi-head attention 的头个数。


微信图片_20211205095650.jpg

表 1:TNT 网络结构参数。


实验


ImageNet 实验


研究者在 ImageNet 2012 数据集上训练和验证 TNT 模型。从下表 2 可以看出,在纯 transformer 的模型中,TNT 优于所有其他的纯 transformer 模型。TNT-S 达到 81.3% 的 top-1 精度,比基线模型 DeiT-S 高 1.5%,这表明引入 TNT 框架有利于在 patch 中保留局部结构信息。通过添加 SE 模块,进一步改进 TNT-S 模型,得到 81.6% 的 top-1 精度。与 CNNs 相比,TNT 的性能优于广泛使用的 ResNet 和 RegNet。不过,所有基于 transformer 的模型仍然低于使用特殊 depthwise 卷积的 EfficientNet,因此如何使用纯 transformer 打败 EfficientNet 仍然是一个挑战。

微信图片_20211205095653.jpg

表 2:TNT 与其他 SOTA 模型在 ImageNet 数据集上的对比。


在精度和 FLOPS、参数量的 trade-off 上,TNT 同样优于纯 transformer 模型 DeiT 和 ViT,并超越了 ResNet 和 RegNet 代表的 CNN 模型。具体表现如下图 4 所示:

微信图片_20211205095656.jpg

图 4:TNT 与其他 SOTA 模型在精度、FLOPS 和参数量指标上的变化曲线。


特征图可视化


研究者将学习到的 DeiT 和 TNT 特征可视化,以进一步探究该方法的工作机制。为了更好地可视化,输入图像的大小被调整为 1024x1024。此外,根据空间位置对 patch embedding 进行重排,形成特征图。第 1、6 和 12 个 block 的特征图如下图 5(a) 所示,其中每个块随机抽取 12 个特征图。与 DeiT 相比,TNT 能更好地保留局部信息。

研究者还使用 t-SNE 对输出特征进行可视化(图 5(b))。由此可见,TNT 的特征比 DeiT 的特征更为多样,所包含的信息也更为丰富。这要归功于内部 transformer block 的引入,能够建模局部特征。

微信图片_20211205095659.jpg

图 5:DeiT 和 TNT 特征图可视化。


迁移学习实验


为了证明 TNT 具有很强的泛化能力,研究者在 ImageNet 上训练的 TNT-S、TNT-B 模型迁移到其他数据集。更具体地说,他们在 4 个图像分类数据集上评估 TNT 模型,包括 CIFAR-10、CIFAR-100、Oxford IIIT Pets 和 Oxford 102 Flowers。所有模型微调的图像分辨率为 384x384。

下表 3 对比了 TNT 与 ViT、DeiT 和其他网络的迁移学习结果。研究者发现,TNT 在大多数数据集上都优于 DeiT,这表明在获得更好的特征时,对像素级关系进行建模具有优越性。

微信图片_20211205095702.jpg

表 3:TNT 在下游任务的表现。


总结


该研究提出了一种用于视觉任务的 transformer in transformer(TNT)网络结构。TNT 将图像均匀分割为图像块序列,并将每个图像块视为像素序列。本文还提出了一种 TNT block,其中外 transformer block 用于处理 patch embedding,内 transformer block 用于建模像素嵌入之间的关系。在线性层投影后,将像素嵌入信息加入到图像块嵌入向量中。通过堆叠 TNT block,构建全新 TNT 架构。与传统的视觉 transformer(ViT)相比,TNT 能更好地保存和建模局部信息,用于视觉识别。在 ImageNet 和下游任务上的大量实验都证明了所提出的 TNT 架构的优越性。

相关文章
|
SQL 人工智能 Dart
Android Studio的插件生态非常丰富
Android Studio的插件生态非常丰富
745 1
|
5月前
|
传感器 前端开发 开发者
《揭秘UMD:让模块在千种环境中找到归宿的逻辑》
本文聚焦前端领域的UMD规范,解析其作为跨环境模块化解决方案的核心价值。UMD并非简单拼接现有规范,而是基于对不同环境本质的洞察,构建动态适配逻辑。其核心在于通过多维度环境探测,识别运行时的“特征图谱”,进而匹配对应的模块导出策略,从严格规范环境到极简环境均能适配。 在跨环境适配中,UMD展现出对依赖管理的弹性处理,以及核心逻辑与适配层的分离设计,同时精细化处理全局对象、作用域隔离等细节,确保在各类环境(包括边缘环境)中稳定运行。其深层价值在于重构模块与环境的关系,为前端模块化提供了融合差异、连接多样生态的思维方式,至今仍具重要实践意义。
156 0
|
10月前
|
缓存 负载均衡 算法
有哪些方法可以提高硬件负载均衡设备的性能?
有哪些方法可以提高硬件负载均衡设备的性能?
352 58
技术经验解读:三维空间中直角坐标与球坐标的相互转换
技术经验解读:三维空间中直角坐标与球坐标的相互转换
699 0
|
安全 Android开发 开发者
安卓手机系统的优势和劣势分析
【2月更文挑战第8天】 安卓(Android)是目前全球最流行的移动操作系统之一,拥有强大的开源技术和丰富的应用生态系统。本文将从多个维度对安卓系统进行分析,并探讨其优势和劣势。
2085 2
|
NoSQL Redis 数据库
Redis 如何实现发布订阅?
Redis 如何实现发布订阅?
301 0
Redis 如何实现发布订阅?
|
存储 程序员 编译器
【C/C++ 堆栈以及虚拟内存分段 】C/C++内存分布/管理:代码区、数据区、堆区、栈区和常量区的探索
【C/C++ 堆栈以及虚拟内存分段 】C/C++内存分布/管理:代码区、数据区、堆区、栈区和常量区的探索
1101 0
|
机器学习/深度学习 PyTorch 算法框架/工具
股票预测-基金预测 pytorch搭建LSTM网络 黄金价格预测实战
股票预测-基金预测 pytorch搭建LSTM网络 黄金价格预测实战
542 0
股票预测-基金预测 pytorch搭建LSTM网络 黄金价格预测实战
|
弹性计算 虚拟化 异构计算
阿里云GPU服务器NVIDIA A100 GPU卡租用价格表
阿里云GPU服务器NVIDIA A100 GPU卡租用价格表,阿里云GPU服务器租用价格表包括包年包月价格、一个小时收费以及学生GPU服务器租用费用,阿里云GPU计算卡包括NVIDIA V100计算卡、T4计算卡、A10计算卡和A100计算卡,GPU云服务器gn6i可享受3折优惠,阿里云百科分享阿里云GPU服务器租用价格表、GPU一个小时多少钱以及学生GPU服务器收费价格表
15797 0
阿里云GPU服务器NVIDIA A100 GPU卡租用价格表
|
JSON 小程序 数据格式
小程序引入vant weapp组件的方法
小程序引入vant weapp组件的方法