引入特征空间,显著降低计算量:双边局部注意力ViT性能媲美全局注意力

简介: 引入特征空间,显著降低计算量:双边局部注意力ViT性能媲美全局注意力

这项研究中,来自百度研究院和香港大学的研究者重新思考了局部自注意力机制,提出了特征空间局部注意力(feature-space local attention简称FSLA)。


Vision Transformer 舍弃了 ConvNet 先验信息,通过引入自注意力机制对远距离特征依赖进行建模,提升了模型的表征能力。然而 Vision Transformer 的自注意力机制在图像分辨率较高时,计算复杂度过高。为了克服这个问题,研究人员使用局部窗口计算自注意力,在此称之为图像空间局部注意力(image-space local attention 或简称 ISLA)。尽管基于窗口的图像空间局部注意力显著提升了效率,但仍面临难以捕捉远距离特征依赖的问题。

在这项研究中,来自百度研究院和香港大学的研究者重新思考了局部自注意力机制,提出了特征空间局部注意力(feature-space local attention 或简称 FSLA)。这种局部注意力从图像内容出发,把特征相似的 token 聚成类,并且只在每类特征的内部计算自注意力,相比全局自注意力显著降低了计算量,同时基本保留了原始的全局自注意力机制对远距离特征依赖的建模能力。

为了将特征空间局部注意力与图像空间局部注意力相结合,本文作者进一步提出了双边局部注意力 ViT (简称 BOAT),把特征空间局部注意力模块加入到现有的基于窗口的局部注意力视觉 Transformer 模型中,作为图像空间局部注意力的补充,大大提升了针对远距离特征依赖的建模能力,在几个基准数据集上的大量实验表明结合了特征空间局部注意力的模型明显优于现有的 ConvNet 和 ViT 模型。



创新动机

为了保持更高的特征图分辨率,同时不会带来过高的运算复杂度,现有的图像空间局部注意力机制将一个图像划分为多个局部窗口,自注意力只在同一窗口的 token 间运算。这是一个合理的设计,因为一个 token 很可能与空间上邻近的 token 相关联。因此,局限于局部窗口的自注意力很可能不会显著降低性能,但是可以显著降低计算量。

本文重新思考了局部自注意力,从特征角度而非空间角度,实现了对局部窗口的划分。具体来说,图像空间局部自注意力的依据是:在空间上邻近的 token 很可能对彼此施加更大的影响(图 1 左);而本文提出的特征空间局部自注意力机制的依据是:即使在图像空间距离较远但在特征空间距离较近的 token 同样会对彼此有很大的影响,因此它在特征空间对 token 进行聚类,并且只在每类特征的内部如同空间局部窗口一样计算自注意力(图 1 右)。


本文提出的特征空间局部自注意力仅计算特征空间内最近邻的特征向量间的注意力,将距离较远的特征向量间的影响直接设为 0。这本质上定义了一个分段相似度函数,将相似度小的特征向量间的注意力近似为 0,降低了运算复杂度。与图像空间局部自注意力相比,特征空间局部自注意力在 ViT 模型中运用的较少。特征空间局部自注意力关注的是相似度较高的特征向量间的注意力,而不考虑空间上两者的邻近程度。因此,它是图像空间局部自注意力的很好补充,能对因跨越空间局部窗口而被遗漏的远距离特征依赖进行建模。

方法概述

本文中的 ViT 采用了和 Swin 和 CSWin 相同的层次化金字塔架构,由一个 patch embedding 模块和若干个双边局部注意力(bilateral local attention)模块组成。


本文与 Swin 和 CSWin 的主要区别是其中的 local attention 模块被替换成了下图所示的 bilateral local attention。而 patch embedding,position encoding 等设计皆和 Swin/CSWin 保持一致,所以接下来对 bilateral local attention 进行详细介绍。

Bilateral Local Attention

本文提出的 bilateral local attention 在基于窗口的图像空间局部注意力(ISLA)模型中添加了特征空间局部注意力(FSLA)模块。FSLA 模块根据 ISLA 模块的输出计算在特征空间彼此邻近的 token 之间的注意力:


最后,将 FSLA 模块的输出送入另一个归一化层和一个 MLP 模块进行处理,再通过一个短路连接得到整个 bilateral local attention 模块的输出:


FSLA 的重点是如何对特征进行聚类操作,并且在各个类内部计算自注意力。最直觉的方法是使用 K-means 聚类,但 K-means 聚类不能确保分组结果大小相同,这使得在 GPU 平台上难以有效地实现并行加速,同时也可能对自注意力计算的有效性产生负面影响。

因此本文提出均衡层次聚类,它进行 k 层聚类。在每一层,它进行均衡二分聚类,将上一层的各个类组均衡地划分为两个更小的类组。如下图所示,所有 token 分成了 token 数量相同的 8 个类组,然后在每组内部计算自注意力,具体的自注意力参数和图像空间局部注意力保持一致。


假如某个类组原先有 2m 个 token,均衡二分聚类后得到的每组的 token 数量为 m。与 K-means 类似,均衡二分聚类是一个迭代算法并且依赖于聚类中心。如以下算法所示,在每次迭代对所有 token 进行分组时,先计算每个 token 到两个聚类中心的距离比值,然后把所有 token 按距离比值的递减顺序排序,最后将排序列表前半部分 m 个 token 赋给第一组,后半部分 m 个 token 赋给第二组。


需要注意的是,这样进行无重叠的均衡二分聚类可能会导致两个处于排序列表中段位置的、特征比较相似的 token 被分配到两个不同的类组中,从而无法计算它们之间的相互影响。因此在实际计算中,为了避免遗漏邻近特征间的影响,会保留一定程度的类间重叠,也就是把排序列表的最前面 m+n 个 token 赋给第一组,最后 m+n 个 token 赋给第二组。这样两组之间就存在 2n 个重用的 token,这样的类间重叠会导致额外的运算,因此实际只在层次聚类的最后一层进行有重叠的均衡二分聚类。完成聚类以后,在每组 token 内部按照常规操作进行自注意力机制的计算即可。

值得注意的是,本文中的所有聚类都是临时计算的,不包含任何可学习的参数,因此不存在对聚类算法本身进行梯度回传的问题。此外,所有聚类运算都用 GPU 进行了加速,对模型的整体计算量影响不大。

实验结果

BOAT 遵循与其它 ViT 相同的训练策略。本文使用 ImageNet-1K 的训练集训练模型,输入图像使用 224×224 分辨率,并且没有外部数据。

具体来说,训练 300 个 epochs,使用 AdamW 优化器、余弦学习速率调度器和一个线性预热过程。BOAT 在多个数据集上都取得了 SOTA 的效果。比如,在 ImageNet-1K 测试集上,BOAT-CSWin-T 取得了 83.7 的 Top-1 分类准确率;在 ADE20K 语义分割测试集上,BOAT-CSWin-T 的 mIoU 达到了 50.5。


相关文章
|
文字识别 并行计算 语音技术
ModelScope问题之下载模型文件报错如何解决
ModelScope模型报错是指在使用ModelScope平台进行模型训练或部署时遇到的错误和问题;本合集将收集ModelScope模型报错的常见情况和排查方法,帮助用户快速定位问题并采取有效措施。
3829 3
|
存储 JSON 自然语言处理
手把手教你使用ModelScope训练一个文本分类模型
手把手教你使用ModelScope训练一个文本分类模型
|
机器学习/深度学习 数据采集 自然语言处理
ModelScope保姆式教程带你玩转语言生成模型
PALM预训练语言生成模型是针对实际场景中常见的文本生成需求所设计的一个模型。模型利用大量无监督数据,通过结合自编码和自回归任务进行预训练,更贴合下游生成任务所同时需要的理解和生成能力。
34943 4
ModelScope保姆式教程带你玩转语言生成模型
|
缓存 安全 关系型数据库
PolarDB 阿里云国产化数据库:linux系统下的详细安装步骤手册
PolarDB 阿里云国产化数据库:linux系统下的详细安装步骤手册
5829 0
PolarDB 阿里云国产化数据库:linux系统下的详细安装步骤手册
|
人工智能 自然语言处理 物联网
llama factory 从数据集起步 跑通 qwen系列开源生成式大模型 微调
`dataset_info.json` 文件用于管理 llama factory 中的所有数据集,支持 `alpaca` 和 `sharegpt` 格式。通过配置此文件,可以轻松添加自定义数据集。数据集的相关参数包括数据源地址、数据集格式、样本数量等,支持 Hugging Face 和 ModelScope 两个平台的数据集仓库。针对不同格式的数据集,提供了详细的配置示例,如 `alpaca` 格式的指令监督微调数据集、偏好数据集等,以及 `sharegpt` 格式的多模态数据集等。今天我们通过自定义数据集的方式来进行qwen2.5_14B_instruct模型进行微调
7300 7
|
8月前
|
机器学习/深度学习 监控 算法
基于mediapipe深度学习的手势数字识别系统python源码
本内容涵盖手势识别算法的相关资料,包括:1. 算法运行效果预览(无水印完整程序);2. 软件版本与配置环境说明,提供Python运行环境安装步骤;3. 部分核心代码,完整版含中文注释及操作视频;4. 算法理论概述,详解Mediapipe框架在手势识别中的应用。Mediapipe采用模块化设计,包含Calculator Graph、Packet和Subgraph等核心组件,支持实时处理任务,广泛应用于虚拟现实、智能监控等领域。
|
8月前
|
存储 Web App开发 缓存
清理C盘空间的6种方法,附详细操作步骤
释放C盘空间并不难。只要掌握合适的方法,哪怕你是电脑小白,也能轻松清理出几十GB空间。下面就为大家介绍6种实用、安全、细致的清理方法,并附上操作步骤。
|
人工智能 文字识别 自然语言处理
文档图像多模态大模型最新技术探索
文档图像多模态大模型最新技术探索
1082 0