【GNN】MPNN:消息传递神经网络

2020-07-21 11:24:45 浏览数 (1)

今天学习的是谷歌大脑的同学 2017 年的工作《Neural Message Passing for Quantum Chemistry》,也就是我们经常提到的消息传递神经网络(Message Passing Neural Network,MPNN),目前引用数超过 900 次。

严格来说,MPNN 不是一个模型,而是一个框架。作者在这篇论文中主要将现有模型抽象其共性并提出成 MPNN 框架,同时利用 MPNN 框架在分子分类预测中取得了一个不错的成绩。

1.Introduction

深度学习被广泛应用于图像、音频、NLP 等领域,但在化学任务(分子分类等)中仍然使用中机器学习 特征工程的方式,其主要原因在于目前尚未有工作证明深度学习在这个领域能取得很大的成功。

近年来,随着量子化学计算和分子动力学模拟等实验的展开产生了巨大的数据量,大多数经典的技术都无法有效利用目前的大数据集。而原子系统的对称性表明,能够应用于网络图中的神经网络也能够应用于分子模型。所以,找到一个更加强大的模型来解决目前的化学任务可以等价于找到一个适用于网络的模型。

在这篇论文中,作者的目标是证明:「能够应用于化学预测任务的模型可以直接从分子图中学习到分子的特征,并且不受到图同构的影响」。为此,作者将应用于图上的监督学习框架称之为消息传递神经网络(MPNN),这种框架是从目前比较流行的支持图数据的神经网络模型中抽象出来的一些共性,抽象出来的目的在于理解它们之间的关系。

鉴于目前已经有很多类似 MPNN 框架的模型,所以作者呼吁学者们应该将这个方法应用到实际的应用中,并且通过实际的应用来提出模型的改进版本,尽可能的去推广模型的实际应用。

本文给出的一个例子是利用 MPNN 框架代替计算代价昂贵的 DFT 来预测有机分子的量子特性:

2.MPNN

本节内容分为两块,一块是看下作者如何从现有模型中抽象出 MPNN 框架,另一块是看下作者如何利用 MPNN 框架去解决实际问题。

2.1 MPNN framework

我们先来介绍下 MPNN 这一通用框架,并通过八篇文献来举例验证 MPNN 框架的通配性。

简单起见,我们考虑无向图 G,节点 v 的特征为

x_v

,边的特征为

e_{vw}

。前向传递有两个阶段:一个是「消息传递阶段」(Message Passing),另一个是「读出阶段」(Readout)。考虑消息传递阶段,消息函数定义为

M_t

,顶点更新函数定义为

U_t

,t 为运行的时间步。在消息传递过程中,隐藏层节点 v 的状态

h_v^t

可以被基于

m_v^{t 1}

进行更新:

begin{aligned} m_v^{t 1} &= sum_{win N(v)}M_t(h_v^t, h_w^t,e_{vw}) \ h_v^{t 1} &= U_t(h_v^t, m_v^{t 1}) end{aligned} \

其中,

N(v)

表示图 G 中节点 v 的邻居。

读出阶段使用一个读出函数 R 来计算整张图的特征向量:

hat y = R({h_v^T | v in G}) \

消息函数

M_t

,向量更新函数

U_t

和读出函数

R

都是可微函数。

R

作用于节点的状态集合,同时对节点的排列不敏感,这样才能保证 MPNN 对图同构保持不变。

此外,我们也可以通过引入边的隐藏层状态来学习图中的每一条边的特征,并且同样可以用上面的等式进行学习和更新。

接下来我们看下如何通过定义「消息函数」「更新函数」「读出函数」来适配不同种模型。

「Paper 1」 : Convolutional Networks for Learning Molecular Fingerprints, Duvenaud et al. (2015)

这篇论文中消息函数为:

M(h_v, h_w,e_{vw}) = (h_w,e_{vw}) \

其中

(.,.)

表示拼接(concat);

节点的更新函数为:

U_t(h_v^t,m_v^{t 1}) = sigma(H_t^{deg(v)}m_v^{t 1}) \

其中

sigma

为 sigmoid 函数,

deg(v)

表示节点 v 的度,

H_t^v

是一个可学习的矩阵,t 为时间步,N 为节点度;

读出函数 R 将先前所有隐藏层的状态

h_v^t

进行连接:

R = f(sum_{v,t}softmax(W_th_v^t)) \

其中 f 是一个神经网络,

W_t

是一个可学习的读出矩阵。

这种消息传递阶段可能会存在一些问题,比如说最终的消息向量分别对连通的节点和连通的边求和

m_v^{t 1}=(sum h_w^t,sum e_{vw})

。由此可见,该模型实现的消息传递无法识别节点和边之间的相关性。

「Paper 2」 : Gated Graph Neural Networks (GG-NN), Li et al. (2016)

这篇论文比较有名,作者后续也是在这个模型的基础上进行改进的。

GG-NN 使用的消息函数为:

M_t(h_v^t,h_w^t,e_{vw})=A_{e_{vw}}h_w^t \

其中

A_{e_{vw}}

e_{vw}

的一个可学习矩阵,每条边都会对应那么一个矩阵;

更新函数为:

U_t(h_v^t,m_v^{t 1}) = GRU(h_v^t, m_v^{t 1}) \

其中

GRU

为门控制单元(Gate Recurrent Unit)。该工作使用了权值捆绑,所以在每一个时间步 t 下都会使用相同的更新函数;

读出函数 R 为:

R=sum_{vin V} sigma(i(h_v^{(T)}),h_v^0); odot ; (j(h_v^{(T)})) \

其中 i 和 j 为神经网络,

odot

表示元素相乘。

「Paper 3」 : Interaction Networks, Battaglia et al. (2016)

这篇论文考虑图中的节点和图结构,同时也考虑每个时间步下的节点级的影响。这种情况下更新函数的输入会多一些

(h_v,x_v,m_v)

,其中

x_v

是一个外部向量,表示对顶点 v 的一些外部影响。

这篇论文的消息函数

M(h_v,h_w,e_{vw})

是一个以

(h_v,h_w,e_{vw})

为输入的神经网络,节点更新函数

U(h_v,x_v,m_v)

是一个以

(h_v,x_v,m_v)

为输入的神经网络,最终会有一个图级别的输出

R=f(sum_{vin G}h_v^T)

,其中 f 是一个神经网络,输入是最终的隐藏层状态的和。在原论文中

T=1

「Paper 4」 : Molecular Graph Convolutions, Kearnes et al. (2016)

这篇论文与其他 MPNN 稍微有些不同,主要区别在于考虑了边表示

e_{v,w}^t

,并且在消息传递阶段会进行更新。

消息传递函数用的是节点的消息:

M_t(h_v^t,h_w^t,e_{vw}^t)=e_{vw}^t

节点的更新函数为:

U_t(h_v^t,m_v^{t 1}) = alpha(W_1(alpha(W_0h_v^t),m_v^{t 1}))

其中

(.,.)

表示拼接(concat),

alpha

为 ReLU 激活函数,

W_0,W_1

为可学习权重矩阵;

边状态的更新定义为:

begin{aligned} e_{vw}^{t 1} &= U_t^{'}(e_{vw}^t, h_v^t, h_w^t) \ &= alpha(W_4(alpha (W_2,e_{vw}^t), alpha(W_3(h_v^t,h_w^t)))) end{aligned} \

其中,

W_i

为可学习权重矩阵。

「Paper 5」 : Deep Tensor Neural Networks, Schutt et al. (2017)

消息函数为:

M_t = tanh(W^{fc}((W^{cf}h_w^t b_1) odot(W^{df}e_{vw} b_2))) \

其中

W^{fc},W^{cf},W^{df}

为矩阵,

b_1,b_2

为偏置向量;

更新函数为:

U_t(h_v^t,m_v^{t 1}) = h_v^t m_v^{t 1} \

读出函数通过单层隐藏层接受每个节点并且求和后输出:

R = sum_v NN(h_v^T) \

「Paper 6」 : Laplacian Based Methods, Bruna et al. (2013); Defferrard et al. (2016); Kipf & Welling (2016)

基于拉普拉斯矩阵的方法将图像中的卷积运算扩展到网络图 G 的邻接矩阵 A 中。

在 Bruna et al. (2013); Defferrard et al. (2016); 的工作中,消息函数为:

M_t(h_v^t,h_w^t) = C_{vw}^t h_w^t \

其中,矩阵

C_{vw}^t

为拉普拉斯矩阵 L 的特征向量组成的矩阵;

节点的更新函数为:

U_t(h_v^t, m_v^{t 1}) = sigma(m_v^{t 1}) \

其中,

sigma

为非线性的激活函数,比如说 ReLU。

在 Kipf & Welling (2016) 的工作中,消息函数为:

M_t(h_v^t,h_w^t) = C_{vw} h_w^t \

其中,

C_{vw} = (deg(v)deg(w))^{-1/2}A_{vw}

节点的更新函数为:

U_v^t(h_v^t, m_v^{t 1}) = ReLU(W^t m_v^{t 1}) \

可以看到以上模型都是 MPNN 框架的不同实例,所以作者呼吁大家应该致力于将这一框架应用于某个实际应用,并根据不同情况对关键部分进行修改,从而引导模型的改进,这样才能最大限度的发挥模型的能力。

2.2 MPNN Variants

本节来介绍下作者将 MPNN 框架应用于分子预测领域,提出了 MPNN 的变种,并以 QM9 数据集为例进行了实验。

QM9 数据集中的分子大多数由碳氢氧氮等元素组成,并组成了约 134k 个有机分子,可以划分为四大类(具体类别不介绍了),任务是根据分子结构预测分子所属类别。

作者主要是基于 GG-NN 来探索 MPNN 的多种改进方式(不同的消息函数、输出函数等),之所以用 GG-NN 是因为这是一个很强的 baseline。

2.2.1 Message Functions

首先来看下消息函数,可以以 GG-NN 中使用的消息函数开始,GG-NN 用的是矩阵乘法:

M(h_v,h_w,e_{vw}) = A_{e_{vw}}h_w \

为了兼容边特征,作者提出了新的消息函数:

M(h_v,h_w,e_{vw}) = A(e_{vw})h_w \

其中,

A(e_{vw})

是将边的向量

e_{vw}

映射到 d×d 维矩阵的神经网络。

矩阵乘法有一个特点,从节点 w 到节点 v 的函数仅与隐藏层状态

h_w

和边向量

e_{vw}

有关,而和隐藏状态

h_v^t

无关。理论上来说,如果节点消息同时依赖于源节点 w 和目标节点 v 的话,网络的消息通道将会得到更有效的利用。所以也可以尝试去使用一种消息函数的变种:

m_{vw} = f(h_w^t, h_v^t, e_{vw}) \

其中,f 为神经网络。

2.2.2 Virtual Graph Elements

其次看来下消息传递,作者探索了两种不同的消息传递方式。

最简单的修改就是为没有连接的节点添加一个虚拟的边,这样消息便具有了更长的传播距离;

此外,作者也尝试了使用潜在的“主”节点(master node),这个节点可以通过特殊的边来连接到图中任意一个节点。主节点充当了一个全局的暂存空间,每个节点都会在消息传递过程中通过主节点进行读取和写入。同时允许主节点具有自己的节点维度,以及内部更新函数(GRU)的单独权重。其目的同样是为了在传播阶段传播很长的距离。

2.2.3 Readout Functions

然后来看下读出函数,作者同样尝试了两种读出函数:

首先是 GG-NN 中的读出函数:

R=sum_{vin V} sigma(i(h_v^{(T)}),h_v^0); odot ; (j(h_v^{(T)})) \

此外也考虑 set2set 模型。set2set 模型是专门为在集合运算而设计的,并且相比简单累加节点的状态来说具有更强的表达能力。模型首先通过线性映射将数据映射到元组

(h_v^t, x_v)

,并将投影元组作为输入

T={(h_v^T,x_v) }

,然后经过 M 步计算后,set2set 模型会生成一个与节点顺序无关的 Graph-level 的 embeedding 向量,从而得到我们的输出向量。

2.2.4 Multiple Towers

最后考虑下 MPNN 的伸缩性。

对一个稠密图来说,消息传递阶段的每一个时间步的时间复杂度为

O(n^2d^2)

,其中 n 为节点数,d 为向量维度,可以看到时间复杂度还是非常高的。

为了解决这个问题作者将向量维度 d 拆分成 k 份,就变成了 k 个 d/k 维向量,并在传播过程中每个子向量分别进行传播和更新,最后再进行合并。此时的子向量时间复杂度为

O(n^2(d/k)^2)

,考虑 k 个子向量的时间复杂度为

O(n^2d^2/k)

2.3 Input Representation

这一节主要介绍 GNN 的输入。

对于分子来说有很多可以提取的特征,比如说原子组成、化学键等,详细的特征列表如下图所示:

对于邻接矩阵,作者模型尝试了三种边表示形式:

「化学图」(Chemical Graph):在不考虑距离的情况下,邻接矩阵的值是离散的键类型:单键,双键,三键或芳香键;

「距离分桶」(Distance bins):基于矩阵乘法的消息函数的前提假设是「边信息是离散的」,因此作者将键的距离分为 10 个 bin,比如说 [2,6] 中均匀划分 8 个 bin,[0,2] 为 1 个 bin,[6, ∞] 为 1 个 bin;

「原始距离特征」(Raw distance feature):也可以同时考虑距离和化学键的特征,这时每条边都有自己的特征向量,此时邻接矩阵的每个实例都是一个 5 维向量,第一维是距离,其余思维是四种不同的化学键。

4.Experiment

来看一下实验结果,以 QM-9 数据集为例,共包含 130462 个分子,以 MAE 为评估指标。

下图为现有算法和作者改进的算法之间的对比:

下图为不考虑空间信息的结果:

下图为考虑多塔模型和结果:

5.Conclusion

总结:作者从诸多模型中抽离出了 MPNN 框架,并且通过实验表明,具有消息函数、更新函数和读出函数的 MPNN 具有良好的归纳能力,可以用于预测分析特性,优于目前的 Baseline,并且无需进行复杂的特征工程。此外,实验结果也揭示了全局主节点和利用 set2set 模型的重要性,多塔模型也使得 MPNN 更具伸缩性,方便应用于大型图中。

6.Reference

  1. 《Neural Message Passing for Quantum Chemistry》

关注公众号跟踪最新内容:「阿泽的学习笔记」

0 人点赞