ECCV 2022 | 在视觉Transformer上进行递归,不增参数,计算量还少

简介: ECCV 2022 | 在视觉Transformer上进行递归,不增参数,计算量还少
今天跟大家分享一篇来自CMU等机构的论文《Sliced Recursive Transformer》,该论文已被 ECCV 2022 接收。

目前 vision transformer 在不同视觉任务上如分类、检测等都展示出了强大的性能,但是其巨大的参数量和计算量阻碍了该模型进一步在实际场景中的应用。基于这个考虑,本文重点研究了如何在不增加额外参数量的前提下把模型的表达能力挖掘到极致,同时还要保证模型计算量在合理范围内,从而可以在一些存储容量小,计算能力弱的嵌入式设备上部署。


基于这个动机,Zhiqiang Shen、邢波等研究者提出了一个 SReT 模型,通过循环递归结构来强化每个 block 的特征表达能力,同时又提出使用多个局部 group self-attention 来近似 vanilla global self-attention,在显著降低计算量 FLOPs 的同时,模型没有精度的损失。





总结而言,本文主要有以下两个创新点:

  1. 使用类似 RNN 里面的递归结构(recursive block)来构建 ViT 主体,参数量不涨的前提下提升模型表达能力;
  2. 使用 CNN 中 group-conv 类似的 group self-attention 来降低 FLOPs 的同时保持模型的高精度;


此外,本文还有其他一些小的改动:

  1. 网络最前面使用三层连续卷积,卷积核为 3x3,结构直接使用了研究者之前 DSOD 里面的 stem 结构;
  2. Knowledge distillation 只使用了单独的 soft label,而不是 DeiT 里面 hard 形式的 label 加 one-hot ground-truth,因为研究者认为 soft label 包含的信息更多,更有利于知识蒸馏;
  3. 使用可学习的 residual connection 来提升模型表达能力;


如下图所示,本文所提出的模型在参数量(Params)和计算量(FLOPs)方面相比其他模型都有明显的优势:


下面我们来解读这篇文章:

1.ViT 中的递归模块

递归操作的基本组成模块如下图:


该模块非常简单明了,类似于 RNN 结构,将模块当前 step 的输出作为下个 step 的输入重新输进该模块,从而增强模型特征表达能力。

研究者展示了将该设计直接应用在 DeiT 上的结果,如下所示:


可以看到在加入额外一次简单递归操作之后就可以得到将近 2% 的精度提升。

当然具体到全局网络结构层面还有不同的递归构建方法,如下图:


其中 NLL 层(Non-linear Projection Layer)是用来保证每个递归模块输入输出不完全一致。论文提出使用这个模块的主要原因是发现在上述 Table 1 里面更多次数的递归操作并没有进一步提升性能,说明网络可能学到了一个比较简单的状态,而 NLL 层可以强制模型输入输出不一致从而缓解这种情况。同时,研究者从实验结果发现上图 (1) internal loop 相比 external loop 设计拥有更好的 accuracy-FLOPs 结果。

2. 分组的 Group Self-attention 模块

如下图所示,研究者提出了一种分组的 group self-attention 策略来降低模型的 FLOPs,同时保证 self-attention 的全局注意力,从而使得模型没有明显精度损失:


Group Self-attention 模块具体形式如下:


Group self-attention 的缺点是只有局部区域会相互作用,研究者提出通过使用 Permutation 操作来近似全局 self-attention 的机制,同时通过 Inverse Permutation 来复原和保留 tokens 的次序信息,针对这个部分的消融实验如下所示:


其中 P 表示加入 Permutation,I 表示加入 Inverse  Permutation,-L 表示如果 group 数为 1,就不使用 P 和 I(比如模型最后一个 stage)。根据上述表格的结果,研究者最后采用了 [8, 2][4,1][1,1] 这种分组设计。

3. 其他设计

可学习的残差结构 (LRC):


研究者尝试了上图三种结构,图(3)结果最佳。具体而言,研究者在每个模块里面添加了 6 个额外参数(4+2,2 个在 NLL 层),这些参数会跟模型其他参数一起学习,从而使网络拥有更强的表达能力,参数初始化都为 1,在训练过程 6 个参数的数值变化情况如下所示:




Stem 结构组成:


如上表所示,Stem 由三个 3x3 的连续卷积组成,每个卷积 stride 为 2。

整体网络结构:

研究者进一步去掉了 class token 和 distillation token,并且发现精度有少量提升。


消融实验:


模型混合深度训练:

研究者进一步发现分组递归设计还有一个好处就是:可以支持模型混合深度训练,这种训练方式可以大大降低深度网络结构优化复杂度,研究者展示了 108 层不同模型结构优化过程的 landscape 可视化,如下图所示,可以很明显的看到混合深度结构优化过程困难程度显著低于另外两种结构。


最后,分组 group self-attention 算法 PyTorch 伪代码如下:


更多方法和实验细节可以阅读原论文和 GitHub 代码。


相关文章
|
5月前
|
IDE 安全 Java
Lombok 在企业级 Java 项目中的隐性成本:便利背后的取舍之道
Lombok虽能简化Java代码,但其“魔法”特性易破坏封装、影响可维护性,隐藏调试难题,且与JPA等框架存在兼容风险。企业级项目应优先考虑IDE生成、Java Records或MapStruct等更透明、稳健的替代方案,平衡开发效率与系统长期稳定性。
254 1
|
6月前
|
存储 固态存储 IDE
移动硬盘盒,机械硬盘和固态硬盘通用吗?
移动硬盘盒能否同时支持机械硬盘(HDD)和固态硬盘(SSD)?本文详解硬盘盒的兼容性问题,涵盖接口类型(如SATA、NVMe)、尺寸规格(2.5英寸、3.5英寸、M.2)及使用体验差异,助你正确选择适配的硬盘盒,确保兼容与性能兼顾。
|
9月前
|
算法 Java 调度
Java多线程基础
本文主要讲解多线程相关知识,分为两部分。第一部分涵盖多线程概念(并发与并行、进程与线程)、Java程序运行原理(JVM启动多线程特性)、实现多线程的两种方式(继承Thread类与实现Runnable接口)及其区别。第二部分涉及线程同步(同步锁的应用场景与代码示例)及线程间通信(wait()与notify()方法的使用)。通过多个Demo代码实例,深入浅出地解析多线程的核心知识点,帮助读者掌握其实现与应用技巧。
159 1
|
10月前
|
设计模式 机器学习/深度学习 缓存
基于PySide6的聚合翻译软件设计与实现
本项目基于PySide6框架构建多引擎聚合智能翻译系统,解决传统工具单一API依赖、切换繁琐及定制化不足的问题。系统采用分层架构,包含UI层、业务逻辑层和API层,运用策略模式、工厂模式等设计模式提升灵活性。核心功能包括翻译引擎抽象、智能路由选择与异步处理,支持无感切换、动态权重调整及非阻塞交互。优化策略涵盖LRU缓存与三级容错机制,确保高性能与稳定性。系统跨平台发布,具备插件化扩展能力,未来将探索机器学习质量预估与OCR支持等功能,适配企业级需求。
302 11
|
内存技术
除了智能照明系统,PWM 还可以应用在哪些领域
脉冲宽度调制(PWM)技术不仅适用于智能照明系统,还广泛应用于电机控制、电源管理、音频处理和通信系统等领域,以实现高效能的信号和功率控制。
1194 11
|
机器学习/深度学习 算法 测试技术
【博士每天一篇文献-算法】iCaRL_ Incremental Classifier and Representation Learning
本文介绍了iCaRL算法,一种增量分类器和表示学习系统,它能够逐步从数据流中学习新概念,通过使用最近均值示例规则、基于牧羊的样本选择和知识蒸馏等方法,在CIFAR-100和ImageNet数据集上展示了其优越的逐步学习能力和对灾难性遗忘的有效抵抗。
576 0
|
存储 缓存 关系型数据库
MariaDB 和 GreatSQL 性能差异背后的真相
【10月更文挑战第22天】本文介绍了 MariaDB 和 GreatSQL 两款数据库系统的背景、性能差异因素及实际应用场景。MariaDB 是 MySQL 的分支,具有良好的社区支持和丰富的插件生态系统;GreatSQL 是国产的 MySQL 兼容数据库,专注于企业级应用场景。文章详细对比了两者的存储引擎优化、查询优化器差异、缓存机制和并发处理能力,并分析了它们在 OLTP 和 OLAP 场景中的性能表现。
674 3
|
监控 数据挖掘 数据安全/隐私保护
ERP系统中的应收应付管理与风险控制解析
【7月更文挑战第25天】 ERP系统中的应收应付管理与风险控制解析
798 2
|
API 数据库 开发者
【独家揭秘】Django ORM高手秘籍:如何玩转数据模型与数据库交互的艺术?
【8月更文挑战第31天】本文通过具体示例详细介绍了Django ORM的使用方法,包括数据模型设计与数据库操作的最佳实践。从创建应用和定义模型开始,逐步演示了查询、创建、更新和删除数据的全过程,并展示了关联查询与过滤的技巧,帮助开发者更高效地利用Django ORM构建和维护Web应用。通过这些基础概念和实践技巧,读者可以更好地掌握Django ORM,提升开发效率。
262 0
|
存储 编译器 程序员
【C/C++ 虚函数以及替代方案】C++ 虚函数的使用开销以及替代方案(二)
【C/C++ 虚函数以及替代方案】C++ 虚函数的使用开销以及替代方案
357 0

热门文章

最新文章