今天学习的是斯坦福大学的同学 2018 年的工作《Hierarchical Graph Representation Learning with Differentiable Pooling》,目前共有 140 多次引用。
目前,GNN 在图分类任务中的处理方法本质上是平面的(Flat),无法学习图形的层次化表达。对于一个包含多个标签的图来说,传统的方法都是为图中每个节点生成一个 Embedding 向量,然后利用这些 Embedding 向量来做全局池化或者输入到 MLP 中来预测图标签,但这种方法忽视了图的层次结构。
为此,作者提出了一个可微分的图池化模块——Diff Pool,用于完成图的层次化表达,并可以端到端的方式与目前多种模型相结合。
Diff Pool 为每一层的节点都学习了一个可微的软聚类,聚类后的集群作为下一层的输入,最后通过连接 MLP 完成图分类任务。
与现有方法相比,Diff Pool 在五个基准数据集中取得了 SOTA 的成绩。
1.Introduction
目前通用的做法是采用 GNN 来提取用于图分类的特征,基于消息传递的 GNN 架构为:
其中,
为第 k 层节点的 Embedding 特征;M 为消息传递函数,与邻接矩阵和参数
相关。
消息传递函数由多种可能的实现,比如说 Kipf 大佬的 GCN 使用的是线性变换和 ReLU 激活函数的组合:
但这种实现其本质是平面的,信息只能够通过边进行传播,无法以分层方式进行推断和汇总信息。
比如说,由于图分类的目标是预测与整个图相关的标签,学者们通常做法是通过编码原子和键编码有机分子的图结构,然后在进行分类。但这种方法会忽视图的层次化结构,损失大量相关信息,从而影响模型效果。
相比于普通的图粗化任务来说,为 GNN 设计这样的池化层具有非常大的挑战,因为其目标不再是简单的为图中的节点进行聚类,而是提供一个通用方法来对图中具有不同连接方式的节点进行分层池化。也就是说,我们需要一种可以自适应的不同图结构池化策略。
2.Diff Pool
本文作者提出了可微的池化模块(Diff Pool),可应用于不同 GNN 模型中。Diff Pool 与 CNN 中的池化不同的是,前者不包含空间局部的概念,且每次 pooling 所包含的节点数和边数都不相同。
Diff Pool 在 GNN 的每一层上都会基于节点的 Embedding 向量进行软聚类,通过反复堆叠(Stacking)建立深度 GNN。因此,Diff Pool 的每一层都能使得图越来越粗糙,从而可以产生输入图的层级表征。具体过程如下图所示:
作者使用分配矩阵来进行池化。具体来说,设 l 层的聚类分配矩阵为
,矩阵行表示 l 层的中
个节点,列表示 l 1 层中的
个节点,
表示从 l 层到 l 1 层的图节点的软聚类结果。
考虑 l 层的邻接矩阵
和节点特征矩阵
,Diff Pool 每一层的粗化图为:
其中,
为下一层粗化图的节点输入特征。
我们将上式分成两部:
第一个式子是特征聚类,第二个式子是转换邻接矩阵。
接下来,我们讨论下如何完成上述目标。
首先,我们会使用两个独立的 GNN 分别学习不同的任务,一个是 Embedding GNN 用于学习 l 层的节点特征:
另一个是池化 GNN 用于学习 l 层到 l 1 层的聚类分配矩阵:
softmax 每个输出矩阵都会应用到,用于学习聚类分配矩阵。
注意,这两个 GNN 输入都相同,但是参数不同,并且任务也不同。
其次,置换不变形对于图分类任务来说非常重要,池化层应该满足置换不变形。对于 Diff Pool 来说,作者表明只要 GNN 组件满足置换不变形,那么整体就会满足,即,若:
那么:
再者,在训练过程中,如果仅从图分类的角度来训练池化 GNN 是非常困难的,因为这是一个非凸优化问题。为了缓解这个问题,作者在训练过程中加上了一个链接预测的目标,从而促使邻近节点一起池化。
每一层中,我们最小化:
其中
表示 Frobenius 范式。
最后,预测簇时的 softmax 结果应该更接近与 one-hot 向量,这样才能使得聚类结果更加清晰。所以我们还需要正则化聚类簇的熵:
其中,H 为熵函数。
训练时,将
和
加入到分类损失中,收敛速度会变慢,但是效果会变好,并且聚类的结果也更有解释性。
3.Experiment
简单看下实验。
下表为不同数据集下不同模型的实验结果:
在 Struct2Vec 的基础上应用 Diff Pool:
可视化 Diff Pool 的聚类结果:
4.Conclusion
作者引入了一种可微的池化方法,该方法能够基于网络图自适应的学习提取复杂的层次结构。Diff Pool 可以现有的 GNN 模型结合使用,并在多个基准数据集上取得了不错的成绩。
时间复杂度为
5.Reference
- Ying Z, You J, Morris C, et al. Hierarchical graph representation learning with differentiable pooling[C]//Advances in neural information processing systems. 2018: 4800-4810.