图神经网络版本的PyTorch来了,Facebook开源GTN框架,还可对图自动微分

简介: 近日,Facebook的AI研究院发表了一篇论文「DIFFERENTIABLE WEIGHTED FINITE-STATE TRANSDUCERS」,开源了用于图网络建模的GTN框架,操作类似于PyTorch这种传统的框架,也可以进行自动微分等操作,大大提高了对图模型建模的效率。

微信图片_20220109165314.png


图神经网络「GNN」是近年来最火爆的研究领域之一,常用于社交网络和知识图谱的构建,由于具有良好的可解释性,现在已经广泛使用在各个场景当中。

 

使用基于图的数据结构构建机器学习模型一直很困难,因为没有很多易于使用的框架。通过将图(或数据)从操作中分离出来,研究人员将有更多的自由和机会来尝试更多的结构化学习算法的设计。

 

微信图片_20220109165317.jpg


Facebook刚开源的工具,将帮助开发人员更快地开发图相关的算法。


图结构非常适合于编码有用的先验知识,通过在训练时使用这些图,整个系统仍然可以从数据中进行学习和改进。从长远来看,WFST与数据学习相结合有可能使机器学习模型更加精确、模块化和轻量化。

 

GTN框架:用WFSTs代替Tensor

 

Facebook近期开源了GTN(Graph Transformer Networks)框架,一个为了图的自动微分而设计的开源框架,支持功能强大、具有表达能力的图结构,称为加权有限状态转换器(WFSTs)

 

就像PyTorch 为张量的自动微分提供了一个框架一样,GTN 也为WFSTs提供了这样一个框架。AI研究人员和工程师可以使用 GTN 更有效地训练基于图的机器学习模型。


       微信图片_20220109165319.png


这个框架是用C++编写的,可以通过Python直接安装来使用。

 

WFST数据结构通常用于结合不同信息源的信息,如存在于语音识别、自然语言处理和手写识别等应用中的信息。

 

微信图片_20220109165322.png


一个标准的语音识别器可能包括一个声学模型和一个语言模型,前者可以预测一个语音片段中出现的字母,后者可以预测一个给定单词跟随另一个单词的可能性。

 

这些模型可以表示为一个 WFST ,通常会被单独训练并结合起来得到最佳的结果。我们新的 GTN 库使得不同类型的模型一起训练成为可能,从而提供更好的结果。

 

图比张量更具有结构性,这使得研究人员可以将关于任务的更有用的先验知识编码成一种学习算法。例如,在语音识别中,如果一个单词有几个可能的读音,则GTN 允许我们将该单词的读音编码成一个图,并将该图合并到学习算法中。

 

以前,在训练时使用单个图是不容易的,开发人员必须硬编码软件中的图结构。现在,使用这个框架,研究人员可以在训练时动态地使用 WFSTs,整个系统可以更有效地从数据中学习和改进。


       微信图片_20220109165324.png


上图显示使用Graph来构建ASG序列,在「p:r/w」标签中,p表示输入标签,r表示输出标签,w是权重。

 

GTN工作原理类似PyTorch,简单易上手

 

通过使用 GTN  ,研究人员可以轻松地构建WFST,并将其可视化,在其上执行操作。

 

通过简单调用「gtn.backward」,可以针对参与计算的任何图计算梯度。下面是一个例子:


       微信图片_20220109165326.png      


GTN 的编程风格与 PyTorch 这样的框架非常相似。命令式样式、 autograd API 和 autograd的实现都是基于类似的设计原则。

 

主要的区别是我们用 WFSTs 及其相应的操作来替换掉PyTorch中的Tensor。同时与很多框架一样,GTN 的目的是在不牺牲性能的情况下易于使用。

 

在论文中,作者给出了如何使用 GTN 实现算法的实例。

 

其中一个例子是使用 GTN 增加序列级的损失函数的能力,将短语分解变成word pieces。模型还可以自由选择如何将单词「The」分解为word pieces,例如,模型可以选择使用「th」和「 e」 ,或者「 t」、「 h」和「 e」。


       微信图片_20220109165329.png      


图:显示了一个简单的内置在 GTN中的WFST,它分解的「the」的word piece转换到单词本身

 

在机器翻译和语音识别中经常使用word pieces,但是这种分解是从任务无关的模型中选择的,而我们的新方法可以使得模型学习出给定任务的单词或短语的最佳分解方式。


微信图片_20220109165332.png


同时,GTN还使用了卷积WFST层,通过在IAM数据集上的实验,卷积核可以把字母转换成200个word piece。所有卷积核的宽度是5,步长为4,输入通道为80,输出通道是200。


微信图片_20220109165335.png


上图是WFST卷积层和传统卷积层的对比,可以看出,在参数量和时间复杂度都得到了大幅度降低的同时,性能得到了一定的提升。

 

如何使用GTN框架

 

环境要求:


      微信图片_20220109165338.png      


下面是使用GTN构建两个 WFSA的案例:


       微信图片_20220109165340.png


在图上构造简单的函数,进行前向计算和可视化,并反向求导计算它们的梯度:


       微信图片_20220109165342.png      


下图是使用GTN来计算ASG损失函数和梯度的例子,ASG函数的输入是所有的gtn.Graph对象。


       微信图片_20220109165345.png  


总体来说,这篇论文的贡献在于:

 

设计了一个框架通过使用WFSTs来对Graph自动求微分,同时支持C++和python。

 

GTN框架可以用来计算已有的序列级别的损失函数,同时设计了一个全新的序列级别损失函数。

 

提出了卷积WFST层可以把底层的表征映射到更高级别的表征。

 

通过实验阐述了使用WFSTs用于语音和手写识别的有效性。


参考链接:

https://ai.facebook.com/blog/a-new-open-source-framework-for-automatic-differentiation-with-graphs/

https://arxiv.org/pdf/2010.01003.pdf

相关实践学习
达摩院智能语音交互 - 声纹识别技术
声纹识别是基于每个发音人的发音器官构造不同,识别当前发音人的身份。按照任务具体分为两种: 声纹辨认:从说话人集合中判别出测试语音所属的说话人,为多选一的问题 声纹确认:判断测试语音是否由目标说话人所说,是二选一的问题(是或者不是) 按照应用具体分为两种: 文本相关:要求使用者重复指定的话语,通常包含与训练信息相同的文本(精度较高,适合当前应用模式) 文本无关:对使用者发音内容和语言没有要求,受信道环境影响比较大,精度不高 本课程主要介绍声纹识别的原型技术、系统架构及应用案例等。 讲师介绍: 郑斯奇,达摩院算法专家,毕业于美国哈佛大学,研究方向包括声纹识别、性别、年龄、语种识别等。致力于推动端侧声纹与个性化技术的研究和大规模应用。
相关文章
|
28天前
|
PyTorch Linux 算法框架/工具
pytorch学习一:Anaconda下载、安装、配置环境变量。anaconda创建多版本python环境。安装 pytorch。
这篇文章是关于如何使用Anaconda进行Python环境管理,包括下载、安装、配置环境变量、创建多版本Python环境、安装PyTorch以及使用Jupyter Notebook的详细指南。
217 1
pytorch学习一:Anaconda下载、安装、配置环境变量。anaconda创建多版本python环境。安装 pytorch。
|
18天前
|
机器学习/深度学习 人工智能
类人神经网络再进一步!DeepMind最新50页论文提出AligNet框架:用层次化视觉概念对齐人类
【10月更文挑战第18天】这篇论文提出了一种名为AligNet的框架,旨在通过将人类知识注入神经网络来解决其与人类认知的不匹配问题。AligNet通过训练教师模型模仿人类判断,并将人类化的结构和知识转移至预训练的视觉模型中,从而提高模型在多种任务上的泛化能力和稳健性。实验结果表明,人类对齐的模型在相似性任务和出分布情况下表现更佳。
45 3
|
1月前
|
网络协议 物联网 虚拟化
|
28天前
|
机器学习/深度学习 缓存 PyTorch
pytorch学习一(扩展篇):miniconda下载、安装、配置环境变量。miniconda创建多版本python环境。整理常用命令(亲测ok)
这篇文章是关于如何下载、安装和配置Miniconda,以及如何使用Miniconda创建和管理Python环境的详细指南。
319 0
pytorch学习一(扩展篇):miniconda下载、安装、配置环境变量。miniconda创建多版本python环境。整理常用命令(亲测ok)
|
1月前
|
并行计算 Ubuntu 算法
Ubuntu18 服务器 更新升级CUDA版本 pyenv nvidia ubuntu1804 原11.2升级到PyTorch要求12.1 全过程详细记录 apt update
Ubuntu18 服务器 更新升级CUDA版本 pyenv nvidia ubuntu1804 原11.2升级到PyTorch要求12.1 全过程详细记录 apt update
90 0
|
1月前
|
并行计算 开发工具 异构计算
在Windows平台使用源码编译和安装PyTorch3D指定版本
【10月更文挑战第6天】在 Windows 平台上,编译和安装指定版本的 PyTorch3D 需要先安装 Python、Visual Studio Build Tools 和 CUDA(如有需要),然后通过 Git 获取源码。建议创建虚拟环境以隔离依赖,并使用 `pip` 安装所需库。最后,在源码目录下运行 `python setup.py install` 进行编译和安装。完成后即可在 Python 中导入 PyTorch3D 使用。
146 0
|
2月前
|
JSON 监控 编译器
|
4月前
|
存储 Prometheus 监控
|
3月前
|
Rust 监控 Linux
这款开源网络监控工具(sniffnet),太实用了!
这款开源网络监控工具(sniffnet),太实用了!
|
4月前
|
网络协议 安全 Shell
`nmap`是一个开源的网络扫描工具,用于发现网络上的设备和服务。Python的`python-nmap`库允许我们在Python脚本中直接使用`nmap`的功能。
`nmap`是一个开源的网络扫描工具,用于发现网络上的设备和服务。Python的`python-nmap`库允许我们在Python脚本中直接使用`nmap`的功能。