0. PGL图学习之图神经网络GraphSAGE、GIN图采样算法系列七
本项目链接:https://aistudio.baidu.com/aistudio/projectdetail/5061984?contributionType=1
相关项目参考:更多资料见主页
关于图计算&图学习的基础知识概览:前置知识点学习(PGL)系列一 https://aistudio.baidu.com/aistudio/projectdetail/4982973?contributionType=1
图机器学习(GML)&图神经网络(GNN)原理和代码实现(前置学习系列二):https://aistudio.baidu.com/aistudio/projectdetail/4990947?contributionType=1
在图神经网络中,使用的数据集可能是亿量级的数据,而由于GPU/CPU资源有限无法一次性全图送入计算资源,需要借鉴深度学习中的mini-batch思想。
传统的深度学习mini-batch训练每个batch的样本之间无依赖,多层样本计算量固定;而在图神经网络中,每个batch中的节点之间互相依赖,在计算多层时会导致计算量爆炸,因此引入了图采样的概念。
GraphSAGE也是图嵌入算法中的一种。在论文Inductive Representation Learning on Large Graphs 在大图上的归纳表示学习中提出。github链接和官方介绍链接。
与node2vec相比较而言,node2vec是在图的节点级别上进行嵌入,GraphSAGE则是在整个图的级别上进行嵌入。之前的网络表示学习的transductive,难以从而提出了一个inductive的GraphSAGE算法。GraphSAGE同时利用节点特征信息和结构信息得到Graph Embedding的映射,相比之前的方法,之前都是保存了映射后的结果,而GraphSAGE保存了生成embedding的映射,可扩展性更强,对于节点分类和链接预测问题的表现也比较突出。
0.1提出背景
现存的方法需要图中所有的顶点在训练embedding的时候都出现;这些前人的方法本质上是transductive,不能自然地泛化到未见过的顶点。文中提出了GraphSAGE,是一个inductive的框架,可以利用顶点特征信息(比如文本属性)来高效地为没有见过的顶点生成embedding。GraphSAGE是为了学习一种节点表示方法,即如何通过从一个顶点的局部邻居采样并聚合顶点特征,而不是为每个顶点训练单独的embedding。
这个算法在三个inductive顶点分类benchmark上超越了那些很强的baseline。文中基于citation和Reddit帖子数据的信息图中对未见过的顶点分类,实验表明使用一个PPI(protein-protein interactions)多图数据集,算法可以泛化到完全未见过的图上。
0.2 回顾GCN及其问题
在大型图中,节点的低维向量embedding被证明了作为各种各样的预测和图分析任务的特征输入是非常有用的。顶点embedding最基本的基本思想是使用降维技术从高维信息中提炼一个顶点的邻居信息,存到低维向量中。这些顶点嵌入之后会作为后续的机器学习系统的输入,解决像顶点分类、聚类、链接预测这样的问题。
- GCN虽然能提取图中顶点的embedding,但是存在一些问题:
- GCN的基本思想: 把一个节点在图中的高纬度邻接信息降维到一个低维的向量表示。
- GCN的优点: 可以捕捉graph的全局信息,从而很好地表示node的特征。
- GCN的缺点: Transductive learning的方式,需要把所有节点都参与训练才能得到node embedding,无法快速得到新node的embedding。1.图采样算法1.1 GraphSage: Representation Learning on Large Graphs
图采样算法:顾名思义,图采样算法就是在一张图中进行采样得到一个子图,这里的采样并不是随机采样,而是采取一些策略。典型的图采样算法包括GraphSAGE、PinSAGE等。
文章码源链接:
https://cs.stanford.edu/people/jure/pubs/graphsage-nips17.pdf
https://github.com/williamleif/GraphSAGE
前面 GCN 讲解的文章中,我使用的图节点个数非常少,然而在实际问题中,一张图可能节点非常多,因此就没有办法一次性把整张图送入计算资源,所以我们应该使用一种有效的采样算法,从全图中采样出一个子图 ,这样就可以进行训练了。
GraphSAGE与GCN对比:
既然新增的节点,一定会改变原有节点的表示,那么为什么一定要得到每个节点的一个固定的表示呢?何不直接学习一种节点的表示方法。去学习一个节点的信息是怎么通过其邻居节点的特征聚合而来的。 学习到了这样的“聚合函数”,而我们本身就已知各个节点的特征和邻居关系,我们就可以很方便地得到一个新节点的表示了。
GCN等transductive的方法,学到的是每个节点的一个唯一确定的embedding; 而GraphSAGE方法学到的node embedding,是根据node的邻居关系的变化而变化的,也就是说,即使是旧的node,如果建立了一些新的link,那么其对应的embedding也会变化,而且也很方便地学到。
在了解图采样算法前,我们至少应该保证采样后的子图是连通的。例如上图图中,左边采样的子图就是连通的,右边的子图不是连通的。
GraphSAGE的核心:GraphSAGE不是试图学习一个图上所有node的embedding,而是学习一个为每个node产生embedding的映射。 GraphSage框架中包含两个很重要的操作:Sample采样和Aggregate聚合。这也是其名字GraphSage(Graph SAmple and aggreGatE)的由来。GraphSAGE 主要分两步:采样、聚合。GraphSAGE的采样方式是邻居采样,邻居采样的意思是在某个节点的邻居节点中选择几个节点作为原节点的一阶邻居,之后对在新采样的节点的邻居中继续选择节点作为原节点的二阶节点,以此类推。
文中不是对每个顶点都训练一个单独的embeddding向量,而是训练了一组aggregator functions,这些函数学习如何从一个顶点的局部邻居聚合特征信息(见图1)。每个聚合函数从一个顶点的不同的hops或者说不同的搜索深度聚合信息。测试或是推断的时候,使用训练好的系统,通过学习到的聚合函数来对完全未见过的顶点生成embedding。
GraphSAGE 是Graph SAmple and aggreGatE的缩写,其运行流程如上图所示,可以分为三个步骤:
- 对图中每个顶点邻居顶点进行采样,因为每个节点的度是不一致的,为了计算高效, 为每个节点采样固定数量的邻居
- 根据聚合函数聚合邻居顶点蕴含的信息
- 得到图中各顶点的向量表示供下游任务使用
邻居采样的优点:
- 极大减少计算量
- 允许泛化到新连接关系,个人理解类似dropout的思想,能增强模型的泛化能力
采样的阶段首先选取一个点,然后随机选取这个点的一阶邻居,再以这些邻居为起点随机选择它们的一阶邻居。例如下图中,我们要预测 0 号节点,因此首先随机选择 0 号节点的一阶邻居 2、4、5,然后随机选择 2 号节点的一阶邻居 8、9;4 号节点的一阶邻居 11、12;5 号节点的一阶邻居 13、15
聚合具体来说就是直接将子图从全图中抽离出来,从最边缘的节点开始,一层一层向里更新节点
上图展示了邻居采样的优点,极大减少训练计算量这个是毋庸置疑的,泛化能力增强这个可能不太好理解,因为原本要更新一个节点需要它周围的所有邻居,而通过邻居采样之后,每个节点就不是由所有的邻居来更新它,而是部分邻居节点,所以具有比较强的泛化能力。
1.1.1 论文角度看GraphSage
聚合函数的选取
在图中顶点的邻居是无序的,所以希望构造出的聚合函数是对称的(即也就是对它输入的各种排列,函数的输出结果不变),同时具有较高的表达能力。 聚合函数的对称性(symmetry property)确保了神经网络模型可以被训练且可以应用于任意顺序的顶点邻居特征集合上。
a. Mean aggregator :
mean aggregator将目标顶点和邻居顶点的第$k−1$层向量拼接起来,然后对向量的每个维度进行求均值的操作,将得到的结果做一次非线性变换产生目标顶点的第$k$层表示向量。
卷积聚合器Convolutional aggregator:
文中用下面的式子替换算法1中的4行和5行得到GCN的inductive变形:
$$
mathbf{h}{v}^{k} leftarrow sigmaleft(mathbf{W} cdot operatorname{MEAN}left(left{mathbf{h}{v}^{k-1}right} cupleft{mathbf{h}_{u}^{k-1}, forall u in mathcal{N}(v)right}right)right) $$
原始算法1中的第4,5行是
$$begin{array}{l}
mathbf{h}{mathcal{N}(v)}^{k} leftarrow operatorname{AGGREGATE}{k}left(left{mathbf{h}_{u}^{k-1}, forall u in mathcal{N}(v)right}right)
mathbf{h}{v}^{k} leftarrow sigmaleft(mathbf{W}^{k} cdot operatorname{CONCAT}left(mathbf{h}{v}^{k-1}, mathbf{h}_{mathcal{N}(v)}^{k}right)right)
end{array}$$
论文提出的均值聚合器Mean aggregator:
$$begin{array}{l}
mathbf{h}{v}^{k} leftarrow sigmaleft(mathbf{W} cdot operatorname{MEAN}left(left{mathbf{h}{v}^{k-1}right} cupleft{mathbf{h}_{u}^{k-1}, forall u in mathcal{N}(v)right}right)right)
mathbf{h}{v}^{k} leftarrow sigmaleft(mathbf{W}^{k} cdot operatorname{CONCAT}left(mathbf{h}{v}^{k-1}, mathbf{h}_{mathcal{N}(v)}^{k}right)right)
end{array}$$
- 均值聚合近似等价在transducttive GCN框架中的卷积传播规则
- 这个修改后的基于均值的聚合器是convolutional的。但是这个卷积聚合器和文中的其他聚合器的重要不同在于它没有算法1中第5行的CONCAT操作——卷积聚合器没有将顶点前一层的表示$mathbf{h}^{k-1}{v}$聚合的邻居向量$mathbf{h}^k{mathcal{N}(v)}$拼接起来
- 拼接操作可以看作一个是GraphSAGE算法在不同的搜索深度或层之间的简单的skip connectionIdentity mappings in deep residual networks的形式,它使得模型的表征性能获得了巨大的提升
- 举个简单例子,比如一个节点的3个邻居的embedding分别为1,2,3,4,2,3,4,5,3,4,5,6按照每一维分别求均值就得到了聚合后的邻居embedding为2,3,4,5
b. LSTM aggregator
文中也测试了一个基于LSTM的复杂的聚合器Long short-term memory。和均值聚合器相比,LSTMs有更强的表达能力。但是,LSTMs不是对称的(symmetric),也就是说不具有排列不变性(permutation invariant),因为它们以一个序列的方式处理输入。因此,需要先对邻居节点随机顺序,然后将邻居序列的embedding作为LSTM的输入。
- 排列不变性(permutation invariance):指输入的顺序改变不会影响输出的值。
c. Pooling aggregator
pooling聚合器,它既是对称的,又是可训练的。Pooling aggregator 先对目标顶点的邻居顶点的embedding向量进行一次非线性变换,之后进行一次pooling操作(max pooling or mean pooling),将得到结果与目标顶点的表示向量拼接,最后再经过一次非线性变换得到目标顶点的第k层表示向量。
一个element-wise max pooling操作应用在邻居集合上来聚合信息:
$$begin{aligned}
mathbf{h}{mathcal{N}(v)}^{k}=& mathrm{AGGREGATE}{k}^{text {pool}}=max left(left{sigmaleft(mathbf{W}{text {pool}} mathbf{h}{u}^{k-1} mathbf{b}right), forall u in mathcal{N}(v)right}right)
&mathbf{h}{v}^{k} leftarrow sigmaleft(mathbf{W}^{k} cdot operatorname{CONCAT}left(mathbf{h}{v}^{k-1}, mathbf{h}_{mathcal{N}(v)}^{k}right)right)
end{aligned}$$
其中
- $max$表示element-wise最大值操作,取每个特征的最大值
- $sigma$是非线性激活函数
- 所有相邻节点的向量共享权重,先经过一个非线性全连接层,然后做max-pooling
- 按维度应用 max/mean pooling,可以捕获邻居集上在某一个维度的突出的/综合的表现。
(有监督和无监督)参数学习
在定义好聚合函数之后,接下来就是对函数中的参数进行学习。文章分别介绍了无监督学习和监督学习两种方式。
基于图的无监督损失
基于图的损失函数倾向于使得相邻的顶点有相似的表示,但这会使相互远离的顶点的表示差异变大:
$$J mathcal{G}left(mathbf{z}{u}right)=-log left(sigmaleft(mathbf{z}{u}^{T} mathbf{z}{v}right)right)-Q cdot mathbb{E}{v{n} sim P{n}(v)} log left(sigmaleft(-mathbf{z}{u}^{T} mathbf{z}{v_{n}}right)right)$$
其中
- $mathbf{z}_{u}$为节点$u$通过GraphSAGE生成的embedding
- 节点$v$是节点$u$随机游走到达的“邻居”
- $sigma$是sigmoid函数
- $P_{n}$是负采样的概率分布,类似word2vec中的负采样
- $Q$是负样本的数目
- embedding之间相似度通过向量点积计算得到
文中输入到损失函数的表示$mathbf{z}_u$是从包含一个顶点局部邻居的特征生成出来的,而不像之前的那些方法(如DeepWalk),对每个顶点训练一个独一无二的embedding,然后简单进行一个embedding查找操作得到。
基于图的有监督损失
无监督损失函数的设定来学习节点embedding 可以供下游多个任务使用。监督学习形式根据任务的不同直接设置目标函数即可,如最常用的节点分类任务使用交叉熵损失函数。
参数学习
通过前向传播得到节点$u$的embedding $z_u$,然后梯度下降(实现使用Adam优化器) 进行反向传播优化参数$mathbf{W}^{k}$和聚合函数内的参数。
新节点embedding的生成
这个$mathbf{W}^{k}$ 就是所谓的dynamic embedding的核心,因为保存下来了从节点原始的高维特征生成低维embedding的方式。现在,如果想得到一个点的embedding,只需要输入节点的特征向量,经过卷积(利用已经训练好的 $mathbf{W}^{k}$ 以及特定聚合函数聚合neighbor的属性信息),就产生了节点的embedding。
有了GCN为啥还要GraphSAGE?
代码语言:txt复制GCN灵活性差、为新节点产生embedding要求 额外的操作 ,比如“对齐”:
GCN是 直推式(transductive) 学习,无法直接泛化到新加入(未见过)的节点;
GraphSAGE是 归纳式(inductive) 学习,可以为新节点输出节点特征。
GCN输出固定:
GCN输出的是节点 唯一确定 的embedding;
GraphSAGE学习的是节点和邻接节点之间的关系,学习到的是一种 映射关系 ,节点的embedding可以随着其邻接节点的变化而变化。
GCN很难应用在超大图上:
无论是拉普拉斯计算还是图卷积过程,因为GCN其需要对 整张图 进行计算,所以计算量会随着节点数的增加而递增。
GraphSAGE通过采样,能够形成 minibatch 来进行批训练,能用在超大图上
GraphSAGE有什么优点?
代码语言:txt复制采用 归纳学习 的方式,学习邻居节点特征关系,得到泛化性更强的embedding;
采样技术,降低空间复杂度,便于构建minibatch用于 批训练 ,还让模型具有更好的泛化性;
多样的聚合函数 ,对于不同的数据集/场景可以选用不同的聚合方式,使得模型更加灵活。
采样数大于邻接节点数怎么办?
代码语言:txt复制设采样数量为K:
若节点邻居数少于K,则采用 有放回 的抽样方法,直到采样出K个节点。
若节点邻居数大于K,则采用 无放回 的抽样。
训练好的GraphSAGE如何得到节点Embedding?
代码语言:txt复制假设GraphSAGE已经训练好,我们可以通过以下步骤来获得节点embedding,具体算法请看下图的算法1。
训练过程则只需要将其产生的embedding扔进损失函数计算并反向梯度传播即可。
对图中每个节点的邻接节点进行 采样 ,输入节点及其n阶邻接节点的特征向量
根据K层的 聚合函数 聚合邻接节点的信息
就产生了各节点的embedding
minibatch的子图是怎么得到的?
那和DeepWalk、Node2vec这些有什么不一样?
代码语言:txt复制DeepWalk、Node2Vec这些embedding算法直接训练每个节点的embedding,本质上依然是直推式学习,而且需要大量的额外训练才能使他们能预测新的节点。同时,对于embedding的正交变换(orthogonal transformations),这些方法的目标函数是不变的,这意味着生成的向量空间在不同的图之间不是天然泛化的,在再次训练(re-training)时会产生漂移(drift)。
与DeepWalk不同的是,GraphSAGE是通过聚合节点的邻接节点特征产生embedding的,而不是简单的进行一个embedding lookup操作得到。
论文仿真结果:
实验对比了四个基线:随机分类,基于特征的逻辑回归(忽略图结构),DeepWalk算法,DeepWork 特征;同时还对比了四种GraphSAGE,其中三种在3.3节中已经说明,GraphSAGE-GCN是GCNs的归纳版本。具体超参数为:K=2,s1=25,s2=10。程序使用TensorFlow编写,Adam优化器。
对于跨图泛化的任务,需要学习节点角色而不是训练图的结构。使用跨各种生物蛋白质-蛋白质相互作用(PPI)图,对蛋白质功能进行分类。在20个图表上训练算法,2个图用于测试,2个图用于验证,平均每图包含2373个节点,平均度为28.8。从实验结果可以看出LSTM和池化方法比Mean和GCN效果更好。
对比不同聚合函数:
如表-1所示,LSTM和POOL方法效果最好,与其它方法相比有显著差异,LSTM和POOL之间无显著差异,但LSTM比POOL慢得多(≈2x),使POOL聚合器在总体上略有优势。
1.1.2 更多问题
代码语言:txt复制采样
为什么要采样?
采样数大于邻接节点数怎么办?
采样的邻居节点数应该选取多大?
每一跳采样需要一样吗?
适合有向边吗?
采样是随机的吗?
聚合函数
聚合函数的选取有什么要求?
GraphSAGE论文中提供多少种聚合函数?
均值聚合的操作是怎样的?
pooling聚合的操作是怎样的?
使用LSTM聚合时需要注意什么?
均值聚合和其他聚合函数有啥区别?
max-和mean-pooling有什么区别?
这三种聚合方法,哪种比较好?
一般聚合多少层?层数越多越好吗?
什么时候和GCN的聚合形式“等价”?
无监督学习
GraphSAGE怎样进行无监督学习?
GraphSAGE如何定义邻近和远处的节点?
如何计算无监督GraphSAGE的损失函数?
GraphSAGE是怎么随机游走的?
GraphSAGE什么时候考虑边的权重了?
训练
如果只有图、没有节点特征,能否使用GraphSAGE?
训练好的GraphSAGE如何得到节点Embedding?
minibatch的子图是怎么得到的?
增加了新的节点来训练,需要为所有“旧”节点重新输出embeding吗?
GraphSAGE有监督学习有什么不一样的地方吗?
参考链接:https://zhuanlan.zhihu.com/p/184991506
https://blog.csdn.net/yyl424525/article/details/100532849
1.2 PinSAGE
采样时只能选取真实的邻居节点吗?如果构建的是一个与虚拟邻居相连的子图有什么优点?PinSAGE 算法将会给我们解答,PinSAGE 算法通过多次随机游走,按游走经过的频率选取邻居,上图右侧为进行随机游走得到的节点序列,统计序列的频数可以发现节点5,10,11的频数为2,其余为1,当我们希望采样三个节点时,我们选取5,10,11作为0号节点的虚拟邻居。之后如果希望得到0号节点的二阶虚拟邻居则在已采样的节点继续进行随机游走即可。
回到上述问题,采样时选取虚拟邻居有什么好处?这种采样方式的好处是我们能更快的聚合到远处节点的信息。。实际上如果是按照 GraphSAGE 算法的方式生成子图,在聚合的过程中,非一阶邻居的信息可以通过消息传递逐渐传到中心,但是随着距离的增大,离中心越远的节点,其信息在传递过程中就越困难,甚至可能无法传递到;如果按照 PinSAGE 算法的方式生成子图,有一定的概率可以将非一阶邻居与中心直接相连,这样就可以快速聚合到多阶邻居的信息
1.2.1论文角度看PinSAGE
和GraphSAGE相比,PinSAGE改进了什么?
- 采样 :使用重要性采样替代GraphSAGE的均匀采样;
- 聚合函数 :聚合函数考虑了边的权重;
- 生产者-消费者模式的minibatch构建 :在CPU端采样节点和构建特征,构建计算图;在GPU端在这些子图上进行卷积运算;从而可以低延迟地随机游走构建子图,而不需要把整个图存在显存中。
- 高效的MapReduce推理 :可以分布式地为百万以上的节点生成embedding,最大化地减少重复计算。 这里的计算图,指的是用于卷积运算的局部图(或者叫子图),通过采样来形成;与TensorFlow等框架的计算图不是一个概念。
PinSAGE使用多大的计算资源?
代码语言:txt复制训练时,PinSAGE使用32核CPU、16张Tesla K80显卡、500GB内存;
推理时,MapReduce运行在378个d2.8xlarge Amazon AWS节点的Hadoop2集群。
PinSAGE和node2vec、DeepWalk这些有啥区别?
代码语言:txt复制node2vec,DeepWalk是无监督训练;PinSAGE是有监督训练;
node2vec,DeepWalk不能利用节点特征;PinSAGE可以;
node2vec,DeepWalk这些模型的参数和节点数呈线性关系,很难应用在超大型的图上;
PinSAGE的单层聚合过程是怎样的?
和GraphSAGE一样,PinSAGE的核心就是一个 局部卷积算子 ,用来学习如何聚合邻居节点信息。
如下图算法1所示,PinSAGE的聚合函数叫做CONVOLVE。主要分为3部分:
- 聚合 (第1行):k-1层邻居节点的表征经过一层DNN,然后聚合(可以考虑边的权重),是聚合函数符号,聚合函数可以是max/mean-pooling、加权求和、求平均;
- 更新 (第2行): 拼接 第k-1层目标节点的embedding,然后再经过另一层DNN,形成目标节点新的embedding;
- 归一化 (第3行): 归一化 目标节点新的embedding,使得训练更加稳定;而且归一化后,使用近似最近邻居搜索的效率更高。
PinSAGE是如何采样的?
代码语言:txt复制如何采样这个问题从另一个角度来看就是:如何为目标节点构建邻居节点。
和GraphSAGE的均匀采样不一样的是,PinSAGE使用的是重要性采样。
PinSAGE对邻居节点的定义是:对目标节点 影响力最大 的T个节点。
PinSAGE的邻居节点的重要性是如何计算的?
代码语言:txt复制其影响力的计算方法有以下步骤:
从目标节点开始随机游走;
使用 正则 来计算节点的“访问次数”,得到重要性分数;
目标节点的邻居节点,则是重要性分数最高的前T个节点。
这个重要性分数,其实可以近似看成Personalized PageRank分数。
关于随机游走,可以阅读《Pixie: A System for Recommending 3 Billion Items to 200 Million Users in Real-Time》
重要性采样的好处是什么?
代码语言:txt复制和GraphSAGE一样,可以使得 邻居节点的数量固定 ,便于控制内存/显存的使用。
在聚合邻居节点时,可以考虑节点的重要性;在PinSAGE实践中,使用的就是 加权平均 (weighted-mean),原文把它称作 importance pooling 。
采样的大小是多少比较好?
代码语言:txt复制从PinSAGE的实验可以看出,随着邻居节点的增加,而收益会递减;
并且两层GCN在 邻居数为50 时能够更好的抓取节点的邻居信息,同时保持运算效率。
PinSage论文中还介绍了落地过程中采用的大量工程技巧。
- 负样本生成:首先是简单采样:在每个minibatch包含节点的范围之外随机采样500个item作为minibatch所有正样本共享的负样本集合。但考虑到实际场景中模型需要从20亿的物品item集合中识别出最相似的1000个,即要从2百万中识别出最相似的那一个,只是简单采样会导致模型分辨的粒度过粗,分辨率只到500分之一,因此增加一种“hard”负样本,即对于每个 对,和物品q有些相似但和物品i不相关的物品集合。这种样本的生成方式是将图中节点根据相对节点q的个性化PageRank分值排序,随机选取排序位置在2000~5000的物品作为“hard”负样本,以此提高模型分辨正负样本的难度。
- 渐进式训练(Curriculum training:如果训练全程都使用hard负样本,会导致模型收敛速度减半,训练时长加倍,因此PinSage采用了一种Curriculum训练的方式,这里我理解是一种渐进式训练方法,即第一轮训练只使用简单负样本,帮助模型参数快速收敛到一个loss比较低的范围;后续训练中逐步加入hard负样本,让模型学会将很相似的物品与些微相似的区分开,方式是第n轮训练时给每个物品的负样本集合中增加n-1个hard负样本。
- 样本的特征信息:Pinterest的业务场景中每个pin通常有一张图片和一系列的文字标注(标题,描述等),因此原始图中每个节点的特征表示由图片Embedding(4096维),文字标注Embedding(256维),以及节点在图中的度的log值拼接而成。其中图片Embedding由6层全连接的VGG-16生成,文字标注Embedding由Word2Vec训练得到。
- 基于random walk的重要性采样:用于邻居节点采样,这一技巧在上面的算法理解部分已经讲解过,此处不再赘述。
- 基于重要性的池化操作:这一技巧用于上一节Convolve算法中的 函数中,聚合经过一层dense层之后的邻居节点Embedding时,基于random walk计算出的节点权重做聚合操作。据论文描述,这一技巧在离线评估指标中提升了46%。
- on-the-fly convolutions:快速卷积操作,这个技巧主要是相对原始GCN中的卷积操作:特征矩阵与全图拉普拉斯矩阵的幂相乘。涉及到全图的都是计算量超高,这里GraphSage和PinSage都是一致地使用采样邻居节点动态构建局部计算图的方法提升训练效率,只是二者采样的方式不同。
- 生产者消费者模式构建minibatch:这个点主要是为了提高模型训练时GPU的利用率。保存原始图结构的邻居表和数十亿节点的特征矩阵只能放在CPU内存中,GPU执行convolve卷积操作时每次从CPU取数据是很耗时的。为了解决这个问题,PinSage使用re-index技术创建当前minibatch内节点及其邻居组成的子图,同时从数十亿节点的特征矩阵中提取出该子图节点对应的特征矩阵,注意提取后的特征矩阵中的节点索引要与前面子图中的索引保持一致。这个子图的邻接列表和特征矩阵作为一个minibatch送入GPU训练,这样一来,convolve操作过程中就没有GPU与CPU的通信需求了。训练过程中CPU使用OpenMP并设计了一个producer-consumer模式,CPU负责提取特征,re-index,负采样等计算,GPU只负责模型计算。这个技巧降低了一半的训练耗时。
- 多GPU训练超大batch:前向传播过程中,各个GPU等分minibatch,共享一套参数,反向传播时,将每个GPU中的参数梯度都聚合到一起,执行同步SGD。为了适应海量训练数据的需要,增大batchsize从512到4096。为了在超大batchsize下快速收敛保证泛化精度,采用warmup过程:在第一个epoch中将学习率线性提升到最高,后面的epoch中再逐步指数下降。
- 使用MapReduce高效推断:模型训练完成后生成图中各个节点的Embedding过程中,如果直接使用上述PinSage的minibatch算法生Embedding,会有大量的重复计算,如计算当前target节点的时候,其相当一部分邻居节点已经计算过Embedding了,而当这些邻居节点作为target节点的时候,当前target节点极有可能需要再重新计算一遍,这一部分的重复计算既耗时又浪费。
1.2.2更多问题
代码语言:txt复制聚合函数
PinSAGE的单层聚合过程是怎样的?
为什么要将邻居节点的聚合embedding和当前节点的拼接?
采样
PinSAGE是如何采样的?
PinSAGE的邻居节点的重要性是如何计算的?
重要性采样的好处是什么?
采样的大小是多少比较好?
MiniBatch
PinSAGE的minibatch和GraphSAGE有啥不一样?
batch应该选多大?
训练
PinSAGE使用什么损失函数?
PinSAGE如何定义标签(正例/负例)?
PinSAGE用什么方法提高模型训练的鲁棒性和收敛性?
负采样
PinSAGE如何进行负采样?
训练时简单地负采样,会有什么问题?
如何解决简单负采样带来的问题?
如果只使用“hard”负样本,会有什么问题?
如何解决只使用“hard”负采样带来的问题?
如何区分采样、负采样、”hard“负采样?
推理
直接为使用训练好的模型产生embedding有啥问题?
如何解决推理时重复计算的问题?
下游任务如何应用PinSAGE产生的embedding?
如何为用户进行个性化推荐?
工程性技巧
pin样本的特征如何构建?
board样本的特征如何构建?
如何使用多GPU并行训练PinSAGE?
PinSAGE为什么要使用生产者-消费者模式?
PinSAGE是如何使用生产者-消费者模式?
https://zhuanlan.zhihu.com/p/195735468
https://zhuanlan.zhihu.com/p/133739758?utm_source=wechat_session&utm_id=0
1.3 小结
学习大图、不断扩展的图,未见过节点的表征,是一个很常见的应用场景。GraphSAGE通过训练聚合函数,实现优化未知节点的表示方法。之后提出的GAN(图注意力网络)也针对此问题优化。
论文中提出了:传导性问题和归纳性问题,传导性问题是已知全图情况,计算节点表征向量;归纳性问题是在不完全了解全图的情况下,训练节点的表征函数(不是直接计算向量表示)。
图工具的处理过程每轮迭代( 一次propagation)一般都包含:收集信息、聚合、更新,从本文也可以更好地理解,其中聚合的重要性,及优化方法。
GraohSage主要贡献如下:
- 针对问题:大图的节点表征
- 结果:训练出的模型可应用于表征没见过的节点
- 核心方法:改进图卷积方法;从邻居节点中采样;考虑了节点特征,加入更复杂的特征聚合方法
一般情况下一个节点的表式通过聚合它k跳之内的邻近节点计算,而全图的表示则通过对所有节点的池化计算。GIN使用了WL-test方法,即图同构测试,它是一个区分网络结构的强效方法,也是通过迭代聚合邻居的方法来更新节点,它的强大在于使用了injective(见后)聚合更新方法。而这里要评测GNN是否能达到类似WL-test的效果。文中还使用了多合集multiset的概念,指可能包含重复元素的集合。
GIN主要贡献如下:
- 展示了GNN模型可达到与WL-test类似的图结构区分效果
- 设计了聚合函数和Readout函数,使GNN能达到更好的区分效果
- 发现GCN及GraphSAGE无法很好表达图结构,而GNN可以
- 开发了简单的网络结构GIN(图同构网络),它的区分和表示能力与WL-test类似。
2.邻居聚合
在图采样之后,我们需要进行邻居聚合的操作。经典的邻居聚合函数包括取平均、取最大值、求和。
评估聚合表达能力的指标——单射(一对一映射),在上述三种经典聚合函数中,取平均倾向于学习分布,取最大值倾向于忽略重复值,这两个不属于单射,而求和能够保留邻居节点的完整信息,是单射。单射的好处是可以保证对聚合后的结果可区分。
2.1 GIN模型的聚合函数
Graph Isomorphic Net(GIN)的聚合部分是基于单射的。
如上图所示,GIN的聚合函数使用的是求和函数,它特殊的一点是在中心节点加了一个自连边(自环),之后对自连边进行加权。
这样做的好处是即使我们调换了中心节点和邻居节点,得到的聚合结果依旧是不同的。所以带权重的自连边能够保证中心节点和邻居节点可区分。
2.2其他复杂的聚合函数
2.3 令居聚合语义场景
3.数据集介绍
数据源:http://snap.stanford.edu/graphsage/ 斯坦福
3.1 Citation数据集
使用科学网引文数据集,将学术论文分类为不同的主题。数据集共包含302424个节点,平均度9.15,使用2000-2004年数据作为训练集,2005年数据作为测试集。使用节点的度以及论文摘要的句嵌入作为特征。
3.2 Reddit数据集
https://aistudio.baidu.com/aistudio/datasetdetail/177810
将Reddit帖子归类为属于不同社区。数据集包含232965个帖子,平均度为492。使用现成的300维GloVe Common Crawl单词向量;对于每个帖子,使用特征包含:(1) 帖子标题的平均嵌入 (2) 帖子所有评论的平均嵌入 (3) 帖子的分数 (4)帖子的评论数量
为了对社区进行抽样,根据 2014 年的评论总数对社区进行了排名,并选择了排名 11,50(含)的社区。省略了最大的社区,因为它们是大型的通用默认社区,大大扭曲了类分布。选择了在这些社区的联合上定义的图中最大的连通分量。
更多数据资料见:
http://files.pushshift.io/reddit/comments/
https://github.com/dingidng/reddit-dataset
最新数据已经更新到2022.10了
3.3 PPI(Protein–protein interactions)蛋白质交互作用
https://aistudio.baidu.com/aistudio/datasetdetail/177807
PPI 网络是蛋白质相互作用(Protein-Protein Interaction,PPI)网络的简称,在GCN中主要用于节点分类任务
PPI是指两种或以上的蛋白质结合的过程,通常旨在执行其生化功能。
一般地,如果两个蛋白质共同参与一个生命过程或者协同完成某一功能,都被看作这两个蛋白质之间存在相互作用。多个蛋白质之间的复杂的相互作用关系可以用PPI网络来描述。
PPI数据集共24张图,每张图对应不同的人体组织,平均每张图有2371个节点,共56944个节点818716条边,每个节点特征长度为50,其中包含位置基因集,基序集和免疫学特征。基因本体基作为label(总共121个),label不是one-hot编码。
- alid_feats.npy文件保存节点的特征,shape为(56944, 50)(节点数目,特征维度),值为0或1,且1的数目稀少
- ppi-class_map.json为节点的label文件,shape为(121, 56944),每个节点的label为121维
- ppi-G.json文件为节点和链接的描述信息,节点:{"test": true, "id": 56708, "val": false}, 表示节点id为56708的节点是否为test集或者val集,链接:"links": [{"source": 0, "target": 372}, {"source": 0, "target": 1101}, 表示节点id为0的节点和为1101的节点之间有links,
- ppi-walks.txt文件中为链接信息
- ppi-id_map.json文件为节点id信息
参考链接:
https://blog.csdn.net/ziqingnian/article/details/112979175
4 基于PGL算法实践
4.1 GraphSAGE
GraphSAGE是一个通用的归纳框架,它利用节点特征信息(例如,文本属性)为以前看不见的数据有效地生成节点嵌入。GraphSAGE 不是为每个节点训练单独的嵌入,而是学习一个函数,该函数通过从节点的本地邻域中采样和聚合特征来生成嵌入。基于PGL,我们重现了GraphSAGE算法,在Reddit Dataset中达到了与论文同等水平的指标。此外,这是PGL中子图采样和训练的一个例子。
超参数
代码语言:txt复制epoch: Number of epochs default (10)
normalize: Normalize the input feature if assign normalize.
sample_workers: The number of workers for multiprocessing subgraph sample.
lr: Learning rate.
symmetry: Make the edges symmetric if assign symmetry.
batch_size: Batch size.
samples: The max neighbors for each layers hop neighbor sampling. (default: [25, 10])
hidden_size: The hidden size of the GraphSAGE models.
代码语言:txt复制parser = argparse.ArgumentParser(description='graphsage')
parser.add_argument(
"--normalize", action='store_true', help="normalize features") # normalize:归一化节点特征
parser.add_argument(
"--symmetry", action='store_true', help="undirect graph") # symmetry:聚合函数的对称性
parser.add_argument("--sample_workers", type=int, default=5) # sample_workers:多线程数据读取器的线程个数
parser.add_argument("--epoch", type=int, default=10)
parser.add_argument("--hidden_size", type=int, default=128)
parser.add_argument("--batch_size", type=int, default=128)
parser.add_argument("--lr", type=float, default=0.01)
parser.add_argument('--samples', nargs=' ', type=int, default=[25, 10]) # samples_1:第一级邻居采样时候选择的最大邻居个数(默认25)#,samples_2:第而级邻居采样时候选择的最大邻居个数(默认10)
部分结果展示:
代码语言:txt复制[INFO] 2022-11-18 16:45:44,177 [ train.py: 63]: Batch 800 train-Loss [0.5213774] train-Acc [0.9140625]
[INFO] 2022-11-18 16:45:45,783 [ train.py: 63]: Batch 900 train-Loss [0.65641916] train-Acc [0.875]
[INFO] 2022-11-18 16:45:47,385 [ train.py: 63]: Batch 1000 train-Loss [0.57411766] train-Acc [0.921875]
[INFO] 2022-11-18 16:45:48,977 [ train.py: 63]: Batch 1100 train-Loss [0.68337256] train-Acc [0.890625]
[INFO] 2022-11-18 16:45:50,434 [ train.py: 160]: Runing epoch:9 train_loss:[0.58635516] train_acc:[0.90786038]
[INFO] 2022-11-18 16:45:57,836 [ train.py: 165]: Runing epoch:9 val_loss:0.55885834 val_acc:0.9139818
[INFO] 2022-11-18 16:46:05,259 [ train.py: 169]: Runing epoch:9 test_loss:0.5578749 test_acc:0.91468066
100%|███████████████████████████████████████████| 10/10 [06:02<00:00, 36.29s/it]
[INFO] 2022-11-18 16:46:05,260 [ train.py: 172]: Runs 0: Model: graphsage Best Test Accuracy: 0.918849
目前官网最佳性能是95.7%,我这里没有调参
Aggregator | Accuracy_me_10 epochs | Accuracy_200 epochs | Reported in paper_200 epochs |
---|---|---|---|
Mean | 91.88% | 95.70% | 95.0% |
其余聚合器下官网和论文性能对比:
Aggregator | Accuracy_200 epochs | Reported in paper_200 epochs |
---|---|---|
Meanpool | 95.60% | 94.8% |
Maxpool | 94.95% | 94.8% |
LSTM | 95.13% | 95.4% |
4.2 Graph Isomorphism Network (GIN)
图同构网络(GIN)是一个简单的图神经网络,期望达到Weisfeiler-Lehman图同构测试的能力。基于 PGL重现了 GIN 模型。
超参数
- data_path:数据集的根路径
- dataset_name:数据集的名称
- fold_idx:拆分的数据集折叠。这里我们使用10折交叉验证
- train_eps:是否参数是可学习的。
parser.add_argument('--data_path', type=str, default='./gin_data')
parser.add_argument('--dataset_name', type=str, default='MUTAG')
parser.add_argument('--batch_size', type=int, default=32)
parser.add_argument('--fold_idx', type=int, default=0)
parser.add_argument('--output_path', type=str, default='./outputs/')
parser.add_argument('--use_cuda', action='store_true')
parser.add_argument('--num_layers', type=int, default=5)
parser.add_argument('--num_mlp_layers', type=int, default=2)
parser.add_argument('--feat_size', type=int, default=64)
parser.add_argument('--hidden_size', type=int, default=64)
parser.add_argument(
'--pool_type',
type=str,
default="sum",
choices=["sum", "average", "max"])
parser.add_argument('--train_eps', action='store_true')
parser.add_argument('--init_eps', type=float, default=0.0)
parser.add_argument('--epochs', type=int, default=350)
parser.add_argument('--lr', type=float, default=0.01)
parser.add_argument('--dropout_prob', type=float, default=0.5)
parser.add_argument('--seed', type=int, default=0)
args = parser.parse_args()
GIN github代码复现含数据集下载:How Powerful are Graph Neural Networks? https://github.com/weihua916/powerful-gnns
https://github.com/weihua916/powerful-gnns/blob/master/dataset.zip
论文使用 9 个图形分类基准:4 个生物信息学数据集(MUTAG、PTC、NCI1、PROTEINS) 和 5 个社交网络数据集(COLLAB、IMDB-BINARY、IMDB-MULTI、REDDITBINARY 和 REDDIT-MULTI5K)(Yanardag & Vishwanathan,2015)。 重要的是,我目标不是让模型依赖输入节点特征,而是主要从网络结构中学习。因此,在生物信息图中,节点具有分类输入特征,但在社交网络中,它们没有特征。 对于社交网络,按如下方式创建节点特征:对于 REDDIT 数据集,将所有节点特征向量设置为相同(因此,这里的特征是无信息的); 对于其他社交图,我们使用节点度数的 one-hot 编码。
社交网络数据集。
- IMDB-BINARY 和 IMDB-MULTI 是电影协作数据集。每个图对应于每个演员/女演员的自我网络,其中节点对应于演员/女演员,如果两个演员/女演员出现在同一部电影中,则在两个演员/女演员之间绘制一条边。每个图都是从预先指定的电影类型派生的,任务是对其派生的类型图进行分类。
- REDDIT-BINARY 和 REDDIT-MULTI5K 是平衡数据集,其中每个图表对应一个在线讨论线程,节点对应于用户。如果其中至少一个节点回应了另一个节点的评论,则在两个节点之间绘制一条边。任务是将每个图分类到它所属的社区或子版块。
- COLLAB 是一个科学协作数据集,源自 3 个公共协作数据集,即高能物理、凝聚态物理和天体物理。每个图对应于来自每个领域的不同研究人员的自我网络。任务是将每个图分类到相应研究人员所属的领域。
生物信息学数据集。
- MUTAG 是一个包含 188 个诱变芳香族和杂芳香族硝基化合物的数据集,具有 7 个离散标签。
- PROTEINS 是一个数据集,其中节点是二级结构元素 (SSE),如果两个节点在氨基酸序列或 3D 空间中是相邻节点,则它们之间存在一条边。 它有 3 个离散标签,代表螺旋、薄片或转弯。
- PTC 是一个包含 344 种化合物的数据集,报告了雄性和雌性大鼠的致癌性,它有 19 个离散标签。
- NCI1 是由美国国家癌症研究所 (NCI) 公开提供的数据集,是经过筛选以抑制或抑制一组人类肿瘤细胞系生长的化学化合物平衡数据集的子集,具有 37 个离散标签。 部分结果展示:[INFO] 2022-11-18 17:12:34,203 [ main.py: 98]: eval: epoch 347 | step 2082 | | loss 0.448468 | acc 0.684211 [INFO] 2022-11-18 17:12:34,297 [ main.py: 98]: eval: epoch 348 | step 2088 | | loss 0.393809 | acc 0.789474 [INFO] 2022-11-18 17:12:34,326 [ main.py: 92]: train: epoch 349 | step 2090 | loss 0.401544 | acc 0.8125 [INFO] 2022-11-18 17:12:34,391 [ main.py: 98]: eval: epoch 349 | step 2094 | | loss 0.441679 | acc 0.736842 [INFO] 2022-11-18 17:12:34,476 [ main.py: 92]: train: epoch 350 | step 2100 | loss 0.573693 | acc 0.7778 [INFO] 2022-11-18 17:12:34,485 [ main.py: 98]: eval: epoch 350 | step 2100 | | loss 0.481966 | acc 0.789474 [INFO] 2022-11-18 17:12:34,485 [ main.py: 103]: best evaluating accuracy: 0.894737
结果整合:(这里就不把数据集一一跑一遍了)
MUTAG | COLLAB | IMDBBINARY | IMDBMULTI | |
---|---|---|---|---|
PGL result | 90.8 | 78.6 | 76.8 | 50.8 |
paper reuslt | 90.0 | 80.0 | 75.1 | 52.3 |
原论文所有结果: