TensorFlow or PyTorch, both!
本文介绍中山大学图学习团队开发的图神经网络基准模型库GraphGallery,支持多种深度学习框架(PyTorch与TensorFlow)以及两种图神经网络开发后端(PyG与DGL),能够帮助你快速训练和测试图神经网络模型。
1前言
图神经网络(Graph Neural Networks,GNN)是近几年兴起的新的研究热点,其借鉴了传统卷积神经网络等模型的思想,在图结构数据上定义了一种新的神经网络架构。如果作为初入该领域的科研人员,想要快速学习并验证自己的想法,需要花费一定的时间搜集数据集,定义模型的训练测试过程,寻找现有的模型进行比较测试,这无疑是繁琐且不必要的。GraphGallery 为科研人员提供了一个简单方便的框架,用于在一些常用的数据集上快速建立和测试自己的模型,并且与现有的基准模型进行比较。GraphGallery目前支持主流的两大机器学习框架:TensorFlow 和 PyTorch,以及两种图神经网络开发后端PyG与DGL,带你几行代码玩转图神经网络。
GraphGallery项目地址:https://github.com/EdisonLeeeee/GraphGallery
2GraphGallery项目概览
GraphGallery架构图
GraphGallery的架构主要包括输入数据流,模型构建,以及训练测试pipeline,用于对目前现有的GNN模型进行快速搭建。GraphGallery目前实现了节点分类任务主流的图神经网络模型(如GCN,GAT等),以及部分节点嵌入模型(如DeepWalk,Node2Vec等):
论文模型实现列表(截取部分)
3GraphGallery安装及使用
1安装
安装前需要用户自行安装所需版本的PyTorch,其余TensorFlow,PyTorch Geometric与DGL为可选安装项。
- 直接从源码安装(推荐使用)
# Recommended
git clone https://github.com/EdisonLeeeee/GraphGallery.git
cd GraphGallery
pip install -e . --verbose
- 从 Pypi 安装(版本更新相对滞后)
# Maybe outdated
pip install -U graphgallery
2快速上手
Dataset
以领域内常用的固定划分基准数据集Planetoid
为例:
from graphgallery.datasets import Planetoid
# set `verbose=False` to avoid informational messages
data = Planetoid('cora', verbose=False)
graph = data.graph
splits = data.split_nodes() # 使用节点固定的划分
>>> graph
Graph(adj_matrix(2708, 2708),
node_attr(2708, 1433),
node_label(2708,),
metadata=None, multiple=False)
目前包含 6 种数据集
代码语言:javascript复制>>> data.available_datasets()
Objects in BunchDict:
╒════════════╤═══════════════════════════╕
│ Names │ Objects │
╞════════════╪═══════════════════════════╡
│ citeseer │ citeseer citation dataset │
├────────────┼───────────────────────────┤
│ cora │ cora citation dataset │
├────────────┼───────────────────────────┤
│ pubmed │ pubmed citation dataset │
├────────────┼───────────────────────────┤
│ nell.0.1 │ NELL dataset │
├────────────┼───────────────────────────┤
│ nell.0.01 │ NELL dataset │
├────────────┼───────────────────────────┤
│ nell.0.001 │ NELL dataset │
╘════════════╧═══════════════════════════╛
graphgallery.datasets
模块还提供了相当多的数据集,具体可查看项目主页:
https://github.com/EdisonLeeeee/GraphGallery
Model Gallery
顾名思义,GraphGallery 是一个GNN模型的 Gallery。
GraphGallery 实现了一系列的面向不同下游任务的GNN模型,以最常见的GCN模型与节点分类任务为例
代码语言:javascript复制from graphgallery.gallery.nodeclas import GCN
trainer = GCN()
trainer.setup_graph(graph)
trainer.build()
trainer.fit(splits.train_nodes, splits.val_nodes)
results = trainer.evaluate(splits.test_nodes)
训练过程如下:
代码语言:javascript复制Training...
100/100 [==============================] - Total: 6.46s - 64ms/step - loss: 0.081 - accuracy: 0.986 - val_loss: 0.699 - val_accuracy: 0.788
Testing...
1/1 [====================] - Total: 14.41ms - 14ms/step - loss: 1.119 - accuracy: 0.815
上述代码究竟做了哪些事情呢?
- 第一步(初始化):
trainer = GCN()
初始化了一个GCN的训练模型,可以传入参数seed
和device
设定随机数种子和运行设备 - 第二步(数据处理):
trainer.setup_graph(graph)
对输入的图数据进行预处理,并转换为张量用于后续训练 - 第三步(模型构建):
train.build()
实现了模型搭建的步骤,build
方法可以指定包含隐藏层单元个数(层数),激活函数,学习率等参数 - 第四步(训练):
trainer.fit(splits.train_nodes, splits.val_nodes)
实现了对训练集节点的拟合,并利用验证集节点存储模型最优参数 - 第五步(测试):训练好后,调用
trainer.evaluate(splits.test_nodes)
在测试集节点上进行验证。result
保存了模型测试结果,输出如下:
>>> result
Objects in BunchDict:
╒══════════╤═══════════╕
│ Names │ Objects │
╞══════════╪═══════════╡
│ loss │ 1.11898 │
├──────────┼───────────┤
│ accuracy │ 0.815 │
╘══════════╧═══════════╛
至此,只需要几行代码即可完成对一个模型的调用和训练测试,并且当你切换不同的后端,调用的是不同后端实现的模型(甚至不需要更改上述调用代码),例如:
代码语言:javascript复制import graphgallery
# 修改为TensorFlow后端(需要提前安装好 TensorFlow)
>>> graphgallery.set_backend('tf')
# 修改为PyG后端(需要提前安装好 PyG)
>>> graphgallery.set_backend('pyg')
# 修改为DGL后端(需要提前安装好 DGL)
>>> graphgallery.set_backend('dgl')
当你切换不同的后端,GraphGallery后台会帮你切换模型对应的框架实现(如果有存在模型实现的话),并且不需要修改原先代码,上述的训练代码仍然可以无需修改直接使用:
代码语言:javascript复制from graphgallery.gallery.nodeclas import GCN
trainer = GCN()
# 预处理,模型构建,训练,测试代码都不需要改变
如果不清楚当前后端及任务所实现的模型列表,可以调用如下API查看(以节点分类任务为例):
代码语言:javascript复制>>> graphgallery.gallery.nodeclas.models()
Registry of PyTorch-Gallery (Node Classification):
╒════════════╤════════════════════════════════════════════════════════════════════════════╕
│ Names │ Objects │
╞════════════╪════════════════════════════════════════════════════════════════════════════╡
│ GCN │ <class 'graphgallery.gallery.nodeclas.pytorch.gcn.GCN'> │
├────────────┼────────────────────────────────────────────────────────────────────────────┤
│ DenseGCN │ <class 'graphgallery.gallery.nodeclas.pytorch.gcn.DenseGCN'> │
├────────────┼────────────────────────────────────────────────────────────────────────────┤
│ GAT │ <class 'graphgallery.gallery.nodeclas.pytorch.gat.GAT'> │
├────────────┼────────────────────────────────────────────────────────────────────────────┤
如上所示,输出的是节点分类任务以及PyTorch后端实现的模型(部分输出结果)。
其它模型
除了主流的基于不同框架实现的图神经网络,GraphGallery还实现了一些常用的无监督节点嵌入模型,如DeepWalk,Node2Vec等。GraphGallery使用Scipy Numpy实现,并采用Numba进行加速,在保证模型性能与原论文相近的同时,大大提高了该方法的速度:
代码语言:javascript复制from graphgallery.gallery.embedding import DeepWalk
model = DeepWalk()
model.fit(graph.adj_matrix)
embedding = model.get_embedding()
其中,graph.adj_matrix
是输入的邻接矩阵(以Scipy.sparse.csr_matrix
方式存储)。如上所示,只需几行代码就可以得到最终的结点嵌入。
4后续工作
在实现上,GraphGallery借鉴了许多优秀的开源项目,如:Pytorch Geometric, Stellargraph 和 DGL等。当前, GraphGallery 仍然处于开发阶段,还有许多工作需要完成:
- 实现更多的 GNN 模型(多种后端)
- 支持更多的任务(目前主要支持半监督的节点分类任务),未来会加入更多链路预测,图分类等下游任务
- 支持更多样的图数据结构(目前主要支持单一无向同构图),未来会考虑异构图,动态图等
- 为项目提供更好的项目文档和注释(完善中...)
最后,附上项目地址及论文:
[1] GraphGallery 项目主页:https://github.com/EdisonLeeeee/GraphGallery [2] Jintang Li, Kun Xu, Liang Chen*, Zibin Zheng and Xiao Liu, “GraphGallery: A Platform for Fast Benchmarking and Easy Development of Graph Neural Networks Based Intelligent Software”, 2021 IEEE/ACM 43rd International Conference on Software Engineering: Companion Proceedings (ICSE-Companion). IEEE, 2021: 13-16