图神经网络版本的PyTorch来了,Facebook开源GTN框架,还可对图自动微分

2020-10-29 10:27:42 浏览数 (1)


新智元报道

编辑:QJP

【新智元导读】近日,Facebook的AI研究院发表了一篇论文「DIFFERENTIABLE WEIGHTED FINITE-STATE TRANSDUCERS」,开源了用于图网络建模的GTN框架,操作类似于PyTorch这种传统的框架,也可以进行自动微分等操作,大大提高了对图模型建模的效率。

图神经网络「GNN」是近年来最火爆的研究领域之一,常用于社交网络和知识图谱的构建,由于具有良好的可解释性,现在已经广泛使用在各个场景当中。

使用基于图的数据结构构建机器学习模型一直很困难,因为没有很多易于使用的框架。通过将图(或数据)从操作中分离出来,研究人员将有更多的自由和机会来尝试更多的结构化学习算法的设计。

Facebook刚开源的工具,将帮助开发人员更快地开发图相关的算法。

图结构非常适合于编码有用的先验知识,通过在训练时使用这些图,整个系统仍然可以从数据中进行学习和改进。从长远来看,WFST与数据学习相结合有可能使机器学习模型更加精确、模块化和轻量化。

GTN框架:用WFSTs代替Tensor

Facebook近期开源了GTN(Graph Transformer Networks)框架,一个为了图的自动微分而设计的开源框架,支持功能强大、具有表达能力的图结构,称为加权有限状态转换器(WFSTs)

就像PyTorch 为张量的自动微分提供了一个框架一样,GTN 也为WFSTs提供了这样一个框架。AI研究人员和工程师可以使用 GTN 更有效地训练基于图的机器学习模型。

这个框架是用C 编写的,可以通过Python直接安装来使用。

WFST数据结构通常用于结合不同信息源的信息,如存在于语音识别、自然语言处理和手写识别等应用中的信息。

一个标准的语音识别器可能包括一个声学模型和一个语言模型,前者可以预测一个语音片段中出现的字母,后者可以预测一个给定单词跟随另一个单词的可能性。

这些模型可以表示为一个 WFST ,通常会被单独训练并结合起来得到最佳的结果。我们新的 GTN 库使得不同类型的模型一起训练成为可能,从而提供更好的结果。

图比张量更具有结构性,这使得研究人员可以将关于任务的更有用的先验知识编码成一种学习算法。例如,在语音识别中,如果一个单词有几个可能的读音,则GTN 允许我们将该单词的读音编码成一个图,并将该图合并到学习算法中。

以前,在训练时使用单个图是不容易的,开发人员必须硬编码软件中的图结构。现在,使用这个框架,研究人员可以在训练时动态地使用 WFSTs,整个系统可以更有效地从数据中学习和改进。

上图显示使用Graph来构建ASG序列,在「p:r/w」标签中,p表示输入标签,r表示输出标签,w是权重。

GTN工作原理类似PyTorch,简单易上手

通过使用 GTN ,研究人员可以轻松地构建WFST,并将其可视化,在其上执行操作。

通过简单调用「gtn.backward」,可以针对参与计算的任何图计算梯度。下面是一个例子:

GTN 的编程风格与 PyTorch 这样的框架非常相似。命令式样式、 autograd API 和 autograd的实现都是基于类似的设计原则。

主要的区别是我们用 WFSTs 及其相应的操作来替换掉PyTorch中的Tensor。同时与很多框架一样,GTN 的目的是在不牺牲性能的情况下易于使用。

在论文中,作者给出了如何使用 GTN 实现算法的实例。

其中一个例子是使用 GTN 增加序列级的损失函数的能力,将短语分解变成word pieces。模型还可以自由选择如何将单词「The」分解为word pieces,例如,模型可以选择使用「th」和「 e」 ,或者「 t」、「 h」和「 e」。

图:显示了一个简单的内置在 GTN中的WFST,它分解的「the」的word piece转换到单词本身

在机器翻译和语音识别中经常使用word pieces,但是这种分解是从任务无关的模型中选择的,而我们的新方法可以使得模型学习出给定任务的单词或短语的最佳分解方式。

同时,GTN还使用了卷积WFST层,通过在IAM数据集上的实验,卷积核可以把字母转换成200个word piece。所有卷积核的宽度是5,步长为4,输入通道为80,输出通道是200。

上图是WFST卷积层和传统卷积层的对比,可以看出,在参数量和时间复杂度都得到了大幅度降低的同时,性能得到了一定的提升。

如何使用GTN框架

环境要求:

下面是使用GTN构建两个 WFSA的案例:

在图上构造简单的函数,进行前向计算和可视化,并反向求导计算它们的梯度:

下图是使用GTN来计算ASG损失函数和梯度的例子,ASG函数的输入是所有的gtn.Graph对象。

总体来说,这篇论文的贡献在于:

设计了一个框架通过使用WFSTs来对Graph自动求微分,同时支持C 和python。

GTN框架可以用来计算已有的序列级别的损失函数,同时设计了一个全新的序列级别损失函数。

提出了卷积WFST层可以把底层的表征映射到更高级别的表征。

通过实验阐述了使用WFSTs用于语音和手写识别的有效性。

参考链接:

https://ai.facebook.com/blog/a-new-open-source-framework-for-automatic-differentiation-with-graphs/

https://arxiv.org/pdf/2010.01003.pdf

0 人点赞