ICLR 2023 | DIFFormer: 扩散过程启发的Transformer(1)

简介: ICLR 2023 | DIFFormer: 扩散过程启发的Transformer

ICLR 2023 | DIFFormer: 扩散过程启发的Transformer

机器之心 2023-04-29 13:01 发表于辽宁

机器之心专栏

机器之心编辑部

本⽂介绍⼀项近期的研究⼯作,试图建⽴能量约束扩散微分⽅程与神经⽹络架构的联系,从而原创性的提出了物理启发下的 Transformer,称作 DIFFormer。作为⼀种通⽤的可以灵活⾼效的学习样本间隐含依赖关系的编码器架构,DIFFormer 在各类任务上都展现了强大潜⼒。这项工作已被 ICLR 2023 接收,并在⾸轮评审就收到了四位审稿⼈给出的 10/8/8/6 评分(最终均分排名位于前 0.5%)。




简介
如何得到有效的样本表征是机器学习领域的⼀⼤核⼼基础问题,也是深度学习范式在各类下游任务能发挥作用的重要前提。传统的表征学习⽅法通常假设每个输⼊样本是独⽴的,即分别将每个样本输⼊进 encoder ⽹络得到其在隐空间中的表征,每个样本的前向计算过程互不干扰。然⽽这⼀假设通常与现实物理世界中数据的⽣成过程是违背的:由于显式的物理连接或隐含的交互关系,每个观测样本之间可能存在相互的依赖

这⼀观察也启发了我们去重新思考⽤于表征计算的 encoder ⽹络设计:是否能设计⼀种新型的 encoder ⽹络能够在前向计算中显式的利⽤样本间的依赖关系(尽管 这些依赖关系是未被观察到的)。在这个⼯作中,我们从两个物理学原理出发,将神经⽹络计算样本表征的前向过程看作给定初始状态的扩散过程,且随着时间的推移(层数加深)系统的整体能量不断下降(见下图)。

DIFFormer 模型主要思想的示意图:将模型计算样本表征的前向过程看作⼀个扩散过程,随着时间的推移,节点之间存在信号传递,且任意节点对之间信号传递的速率会随着时间适应性的变化,使得系统整体的能量最⼩化。通过扩散过程和能量约束,最终的样本表征能够吸收个体和全局的信息,更有助于下游任务。


通过试图建⽴扩散微分⽅程与神经⽹络架构的联系,我们阐释了能量约束扩散过程与各类信息传递网络(如 MLP/GNN/Transformers)的联系,并为新的信息传递设 计提供了⼀种理论参考。基于此,我们提出了⼀种新型的可扩展 Transformer 模型,称为 DIFFormer(diffusionbased Transformers)。它可以作为⼀种通⽤的 encoder,在前向计算中利⽤样本间隐含的依赖关系。⼤量实验表明在⼩ / ⼤图节点分类、图⽚ / ⽂本分类、时空预测等多个领域的实验任务上 DIFFormer 都展现了强⼤的应⽤潜⼒。在计算效率上,DIFFormer 只需要 3GB 显存就可以实现⼗万级样本间全联接的信息传递

动机与背景

我们⾸先回顾⼀个经典的热⼒学中的热传导过程:假设系统中有个节点,每个节点有初始的温度,两两节点之间都存在信号流动,随着时间的推移节点的温度会不断更新。上述物理过程事实上可以类⽐的看作深度神经网络计算样本表征(embedding)的前向过程。

将神经⽹络的前向计算过程看作⼀个扩散过程:每个样本视为流形上的固定位置节点,样本的表征为节点的信号,表征的更新视作节点信号的改变,样本间的信息传递看作节点之间的信号流动


具体的,考虑包含个样本的数据集,用表示样本 i 的输入特征,表示样本 i 的表征向量。⼀个 L 层的神经网络模型会把每个输⼊样本映射到⼀系列隐空间中的表征向量:

这⾥我们可以把每个样本看作⼀个离散空间中的节点,样本表征看作节点的信号。当模型结构考虑样本交互时(如信息传递),它可以被看作节点之间的信号流动,随着模型层数加深(即时间的推移),样本表征会不断被更新。

扩散过程的描述

⼀个经典的扩散过程可以由⼀个热传导⽅程(带初始条件的偏微分⽅程)来描述


这⾥的分别表示梯度(gradient)算⼦、散度 (divergence) 算⼦和扩散率(diffusivity)。对于由 N 个节点组成的离散化空间,以上三个概念的具体定义可以如下表示:

在离散空间中,梯度算⼦可以看作两两节点的信号差异,散度算子可以看作单个节点流出信号的总和,⽽扩散率(diffusivity)是⼀种对任意两两节点间信号流动速率的度量


由此我们可以写出描述 N 个节点每时每刻状态更新的扩散微分⽅程,它描述了每个状态下系统中每个节点信号的变化等于流向其他节点的信号总和:


这⾥的扩散率定义了在当前时刻任意两两节点之间的影响,即信号从节点流向的速率的⼀种度量。
由扩散方程导出的信息传递

我们进⼀步使⽤数值有限差分(具体的这⾥使⽤显式欧拉法)将上述的微分⽅程展开成迭代更新的形式,引⼊⼀个步⻓对连续时间进⾏离散化(再经过⽅程左右重新整理):


这⾥的第⼀项系数可以被视作⼀个常数(如果假设是经过沿⾏归⼀化的),于是上式就可以视为⼀个对其他样本表征的信息聚合(第⼆项)再加上⼀个对上⼀层⾃身表征的 residual 连接(第⼀项)。这⾥的扩散率是⼀个的矩阵,我们可以对其进⾏不同的假设,就可以得到不同模型的层间更新公式:

  • 如果是⼀个的单位矩阵:(1)式中每个样本的表征计算只取决于⾃⼰(与其他样本独⽴),此时给出的是 Multi-Layer Perceptron (MLP) 的更新公式,即每个样本被单独输⼊进 encoder 计算表征;
  • 如果在固定位置存在⾮零值(如输⼊图中存在连边的位置):(1)式中每个样本的表征更新会依赖于图中相邻的其他节点,此时给出的是 Graph Neural Networks (GNN) 的更新公式,其中是传播矩阵(propagation matrix),例如图卷积⽹络(GCN)模型采⽤归⼀化后的邻接矩阵
  • 如果在所有位置都允许有⾮零值,且每层的都可以发⽣变化:(1)式中每个样本的表征更新会依赖于其他所有节点,且每次更新两两节点间的影响也会适应性的变化,此时 (1) 式给出的是 Transformer 结构的更新公式,表示第层的 attention 矩阵。


下图概述了这三种信息传递模式:


我们研究最后⼀种信息传递⽅式,每层更新的样本表征会利⽤上⼀层所有其他样本的表征,在理论上模型的表达能⼒是最强的。但由此产⽣的⼀个问题是:要如何才能确定合适的每层任意两两节点之间的 diffusivity,使得模型能够产⽣理想的样本表征?

刻画⼀致性的能量函数
我们这⾥引⼊⼀个能量函数,来刻画每时每刻由系统中所有节点表征所定义的内在⼀致性,通过能量的最⼩化来引导扩散过程中节点信号的演 变⽅向。具体的,对于样本表征,其对应的能量定义为:


这⾥的第⼀项约束了每个节点对⾃身当前状态的局部⼀致性,第⼆项了约束了与系统中其他节点的全局⼀致性。其中是⼀个单调递增的凹函数(当差别较⼤时,会返回⼀个适中的能量值,即减⼩对差异较⼤的节点对的“惩罚”,这有助于提升样本表征的 diversity)。理想情况下,当系统的整体能量达到最⼩化,我们可以认为系统中的每⼀个个体都与整体取得了平衡,样本的表征同时吸收了局部和全局的信息。


能量约束的扩散过程
基于此,我们考虑⼀种带能量约束的扩散过程,每⼀步的扩散率被定义为⼀个待优化的隐变量,我们希望它给出的每⼀步的节点表征都能够使得系统整体的能量下降。带能量约束的扩散过程可以被形式化的描述为:


虽然直接求解⾮常复杂(因为他耦合了每⼀步能量下降的约束),不过本⽂通过理论分析建⽴了扩散⽅程数值迭代与能量优化梯度更新的等价性,从⽽得到了每⼀步扩散率的最优闭式解。



相关文章
|
SQL 存储 监控
水滴筹基于阿里云 EMR StarRocks 实战分享
水滴筹大数据部门的数据开发工程师韩园园老师为大家分享水滴筹基于阿里云EMR StarRocks的实战经验。
6857 3
水滴筹基于阿里云 EMR StarRocks 实战分享
|
存储 固态存储 安全
阿里云服务器最新价格参考,2024年阿里云服务器活动价格表及收费标准
进入2024年,阿里云服务器的活动价格又降价了,现在购买阿里云服务器年付最低仅需61元即可购买一台2核2G3M带宽的轻量应用服务器,而月付最低只需要30.06元即可购买一台2核4G3M带宽配置的云服务器,另外通用算力型u1实例2核4G、4核8G和8核16G等热门配置的活动价格在2024年也再次下降了,例如2核4G配置1M带宽20G ESSD Entry云盘,现在活动价格只要531.79元/1年了,选择5M带宽现在只要898.99元/1年了,下面是2024年阿里云服务器最新活动价格表。
阿里云服务器最新价格参考,2024年阿里云服务器活动价格表及收费标准
|
缓存 监控 定位技术
鸿蒙5开发宝藏案例分享---一多开发实例(音乐)
鸿蒙开发干货分享:涵盖动态布局、交互动效、服务卡片设计、内存优化、分布式开发及性能加速六大核心主题。从折叠屏适配到手势动画,从服务卡片最佳实践到内存泄漏检测,结合官方100+实战案例与高频痛点解决方案,助你解锁鸿蒙开发隐藏技巧,提升效率与用户体验。快来一起探讨开发中的那些“坑”吧!
|
12月前
|
存储 关系型数据库 分布式数据库
PolarDB开源数据库进阶课15 集成DeepSeek等大模型
本文介绍了如何在PolarDB数据库中接入私有化大模型服务,以实现多种应用场景。实验环境依赖于Docker容器中的loop设备模拟共享存储,具体搭建方法可参考相关系列文章。文中详细描述了部署ollama服务、编译并安装http和openai插件的过程,并通过示例展示了如何使用这些插件调用大模型API进行文本分析和情感分类等任务。此外,还探讨了如何设计表结构及触发器函数自动处理客户反馈数据,以及生成满足需求的SQL查询语句。最后对比了不同模型的回答效果,展示了deepseek-r1模型的优势。
683 3
|
前端开发 Java 应用服务中间件
【Tomcat源码分析 】"深入探索:Tomcat 类加载机制揭秘"
本文详细介绍了Java类加载机制及其在Tomcat中的应用。首先回顾了Java默认的类加载器,包括启动类加载器、扩展类加载器和应用程序类加载器,并解释了双亲委派模型的工作原理及其重要性。接着,文章分析了Tomcat为何不能使用默认类加载机制,因为它需要解决多个应用程序共存时的类库版本冲突、资源共享、类库隔离及JSP文件热更新等问题。最后,详细展示了Tomcat独特的类加载器设计,包括Common、Catalina、Shared、WebApp和Jsp类加载器,确保了系统的稳定性和安全性。通过这种设计,Tomcat实现了不同应用程序间的类库隔离与共享,同时支持JSP文件的热插拔。
【Tomcat源码分析 】"深入探索:Tomcat 类加载机制揭秘"
|
存储 负载均衡 容灾
MySQL数据库的分布式架构和数据分片方案
MySQL数据库的分布式架构和数据分片方案
|
SQL JavaScript 前端开发
Hive学习-lateral view 、explode、reflect和窗口函数
Hive学习-lateral view 、explode、reflect和窗口函数
880 4
BERT+PET方式模型训练(一)
• 本项目中完成BERT+PET模型搭建、训练及应用的步骤如下(注意:因为本项目中使用的是BERT预训练模型,所以直接加载即可,无需重复搭建模型架构): • 一、实现模型工具类函数 • 二、实现模型训练函数,验证函数 • 三、实现模型预测函数
|
前端开发 算法 容器
css【详解】grid布局—— 网格布局(栅格布局)(一)
css【详解】grid布局—— 网格布局(栅格布局)(一)
1928 1
|
安全 数据可视化 Java
BurpSuite
BurpSuite
1204 7

热门文章

最新文章