GraphGallery:几行代码玩转图神经网络

2021-10-12 10:47:47 浏览数 (1)

TensorFlow or PyTorch, both!

本文介绍中山大学图学习团队开发的图神经网络基准模型库GraphGallery,支持多种深度学习框架(PyTorch与TensorFlow)以及两种图神经网络开发后端(PyG与DGL),能够帮助你快速训练和测试图神经网络模型。

1前言

图神经网络(Graph Neural Networks,GNN)是近几年兴起的新的研究热点,其借鉴了传统卷积神经网络等模型的思想,在图结构数据上定义了一种新的神经网络架构。如果作为初入该领域的科研人员,想要快速学习并验证自己的想法,需要花费一定的时间搜集数据集,定义模型的训练测试过程,寻找现有的模型进行比较测试,这无疑是繁琐且不必要的。GraphGallery 为科研人员提供了一个简单方便的框架,用于在一些常用的数据集上快速建立和测试自己的模型,并且与现有的基准模型进行比较。GraphGallery目前支持主流的两大机器学习框架:TensorFlow 和 PyTorch,以及两种图神经网络开发后端PyG与DGL,带你几行代码玩转图神经网络。

GraphGallery项目地址:https://github.com/EdisonLeeeee/GraphGallery

2GraphGallery项目概览

GraphGallery架构图

GraphGallery的架构主要包括输入数据流,模型构建,以及训练测试pipeline,用于对目前现有的GNN模型进行快速搭建。GraphGallery目前实现了节点分类任务主流的图神经网络模型(如GCN,GAT等),以及部分节点嵌入模型(如DeepWalk,Node2Vec等):

论文模型实现列表(截取部分)

3GraphGallery安装及使用

1安装

安装前需要用户自行安装所需版本的PyTorch,其余TensorFlow,PyTorch Geometric与DGL为可选安装项。

  • 直接从源码安装(推荐使用)
代码语言:javascript复制
# Recommended
git clone https://github.com/EdisonLeeeee/GraphGallery.git
cd GraphGallery
pip install -e . --verbose
  • 从 Pypi 安装(版本更新相对滞后)
代码语言:javascript复制
# Maybe outdated
pip install -U graphgallery

2快速上手

Dataset

以领域内常用的固定划分基准数据集Planetoid为例:

代码语言:javascript复制
from graphgallery.datasets import Planetoid
# set `verbose=False` to avoid informational messages 
data = Planetoid('cora', verbose=False)
graph = data.graph
splits = data.split_nodes() # 使用节点固定的划分
>>> graph
Graph(adj_matrix(2708, 2708),
      node_attr(2708, 1433),
      node_label(2708,),
      metadata=None, multiple=False)

目前包含 6 种数据集

代码语言:javascript复制
>>> data.available_datasets()
Objects in BunchDict:
╒════════════╤═══════════════════════════╕
│ Names      │ Objects                   │
╞════════════╪═══════════════════════════╡
│ citeseer   │ citeseer citation dataset │
├────────────┼───────────────────────────┤
│ cora       │ cora citation dataset     │
├────────────┼───────────────────────────┤
│ pubmed     │ pubmed citation dataset   │
├────────────┼───────────────────────────┤
│ nell.0.1   │ NELL dataset              │
├────────────┼───────────────────────────┤
│ nell.0.01  │ NELL dataset              │
├────────────┼───────────────────────────┤
│ nell.0.001 │ NELL dataset              │
╘════════════╧═══════════════════════════╛

graphgallery.datasets模块还提供了相当多的数据集,具体可查看项目主页:

https://github.com/EdisonLeeeee/GraphGallery

Model Gallery

顾名思义,GraphGallery 是一个GNN模型的 Gallery

GraphGallery 实现了一系列的面向不同下游任务的GNN模型,以最常见的GCN模型与节点分类任务为例

代码语言:javascript复制
from graphgallery.gallery.nodeclas import GCN
trainer = GCN()
trainer.setup_graph(graph)
trainer.build()
trainer.fit(splits.train_nodes, splits.val_nodes)
results = trainer.evaluate(splits.test_nodes)

训练过程如下:

代码语言:javascript复制
Training...
100/100 [==============================] - Total: 6.46s - 64ms/step - loss: 0.081 - accuracy: 0.986 - val_loss: 0.699 - val_accuracy: 0.788
Testing...
1/1 [====================] - Total: 14.41ms - 14ms/step - loss: 1.119 - accuracy: 0.815

上述代码究竟做了哪些事情呢?

  • 第一步(初始化):trainer = GCN()初始化了一个GCN的训练模型,可以传入参数seeddevice设定随机数种子和运行设备
  • 第二步(数据处理):trainer.setup_graph(graph)对输入的图数据进行预处理,并转换为张量用于后续训练
  • 第三步(模型构建):train.build()实现了模型搭建的步骤,build方法可以指定包含隐藏层单元个数(层数),激活函数,学习率等参数
  • 第四步(训练):trainer.fit(splits.train_nodes, splits.val_nodes)实现了对训练集节点的拟合,并利用验证集节点存储模型最优参数
  • 第五步(测试):训练好后,调用trainer.evaluate(splits.test_nodes)在测试集节点上进行验证。result保存了模型测试结果,输出如下:
代码语言:javascript复制
>>> result
Objects in BunchDict:
╒══════════╤═══════════╕
│ Names    │   Objects │
╞══════════╪═══════════╡
│ loss     │   1.11898 │
├──────────┼───────────┤
│ accuracy │   0.815   │
╘══════════╧═══════════╛

至此,只需要几行代码即可完成对一个模型的调用和训练测试,并且当你切换不同的后端,调用的是不同后端实现的模型(甚至不需要更改上述调用代码),例如:

代码语言:javascript复制
import graphgallery
# 修改为TensorFlow后端(需要提前安装好 TensorFlow)
>>> graphgallery.set_backend('tf')
# 修改为PyG后端(需要提前安装好 PyG)
>>> graphgallery.set_backend('pyg')
# 修改为DGL后端(需要提前安装好 DGL)
>>> graphgallery.set_backend('dgl')

当你切换不同的后端,GraphGallery后台会帮你切换模型对应的框架实现(如果有存在模型实现的话),并且不需要修改原先代码,上述的训练代码仍然可以无需修改直接使用:

代码语言:javascript复制
from graphgallery.gallery.nodeclas import GCN
trainer = GCN()
# 预处理,模型构建,训练,测试代码都不需要改变

如果不清楚当前后端及任务所实现的模型列表,可以调用如下API查看(以节点分类任务为例):

代码语言:javascript复制
>>> graphgallery.gallery.nodeclas.models()
Registry of PyTorch-Gallery (Node Classification):
╒════════════╤════════════════════════════════════════════════════════════════════════════╕
│ Names      │ Objects                                                                    │
╞════════════╪════════════════════════════════════════════════════════════════════════════╡
│ GCN        │ <class 'graphgallery.gallery.nodeclas.pytorch.gcn.GCN'>                    │
├────────────┼────────────────────────────────────────────────────────────────────────────┤
│ DenseGCN   │ <class 'graphgallery.gallery.nodeclas.pytorch.gcn.DenseGCN'>               │
├────────────┼────────────────────────────────────────────────────────────────────────────┤
│ GAT        │ <class 'graphgallery.gallery.nodeclas.pytorch.gat.GAT'>                    │
├────────────┼────────────────────────────────────────────────────────────────────────────┤

如上所示,输出的是节点分类任务以及PyTorch后端实现的模型(部分输出结果)。

其它模型

除了主流的基于不同框架实现的图神经网络,GraphGallery还实现了一些常用的无监督节点嵌入模型,如DeepWalk,Node2Vec等。GraphGallery使用Scipy Numpy实现,并采用Numba进行加速,在保证模型性能与原论文相近的同时,大大提高了该方法的速度:

代码语言:javascript复制
from graphgallery.gallery.embedding import DeepWalk
model = DeepWalk()
model.fit(graph.adj_matrix)
embedding = model.get_embedding()

其中,graph.adj_matrix是输入的邻接矩阵(以Scipy.sparse.csr_matrix方式存储)。如上所示,只需几行代码就可以得到最终的结点嵌入。

4后续工作

在实现上,GraphGallery借鉴了许多优秀的开源项目,如:Pytorch Geometric, Stellargraph 和 DGL等。当前, GraphGallery 仍然处于开发阶段,还有许多工作需要完成:

  • 实现更多的 GNN 模型(多种后端)
  • 支持更多的任务(目前主要支持半监督的节点分类任务),未来会加入更多链路预测,图分类等下游任务
  • 支持更多样的图数据结构(目前主要支持单一无向同构图),未来会考虑异构图,动态图等
  • 为项目提供更好的项目文档和注释(完善中...)

最后,附上项目地址及论文:

[1] GraphGallery 项目主页:https://github.com/EdisonLeeeee/GraphGallery [2] Jintang Li, Kun Xu, Liang Chen*, Zibin Zheng and Xiao Liu, “GraphGallery: A Platform for Fast Benchmarking and Easy Development of Graph Neural Networks Based Intelligent Software”, 2021 IEEE/ACM 43rd International Conference on Software Engineering: Companion Proceedings (ICSE-Companion). IEEE, 2021: 13-16

0 人点赞