◆ 动机
图神经网络(Graph Neural Networks)在图表示学习任务中获得了空前的成功。然而和深度学习的领域相比,图神经网络一个显著的特征是,网络在浅层的时候(层数只有2-3层)就取得了最好的表现。如果我们继续加深图神经网络,那么其表现反而会快速下降。这和深度学习中的内核“深度”二字相违背。
训练集和测试集准确率v.s.模型深度
为了探究为什么图神经网络会表现出这样的行为,以及设计出新的算法来提升深度图神经网络的表现。我们从网络的可训练性(trainability)角度来探究深度图神经网络背后的秘密,最终从理论上证明在一定条件下,图神经网络的可训练性以指数的速率下降。结合理论分析,我们受到统计物理中逾渗(percolation)启发提出来Critical DropEdge的算法,一种连通性感知和图自适应的采样方法,从根本上解决可训练性指数衰减问题。
论文链接:https://arxiv.org/abs/2103.03113
ICLR 2022 Poster: https://iclr.cc/virtual/2022/poster/6585
◆ Graph Neural Tangent Kernel
我们知道无限宽神经网络在梯度下降算法下的动力学由Neural Tangent Kernel (NTK)来描述。由于图神经网络也是一种神经网络,因此将图神经网络无限宽化,其动力学就由Graph Neural Tangent Kernel (GNTK)来描述。
具体而言,网络的损失函数的收敛速度将由NTK的最小特征值来决定,而当一个网络对应的NTK变成奇异矩阵的时候,那么这个网络的损失函数将会无法收敛,从而散失了可训练性。
根据以上的背景,我们将GNTK作为一个理论工具,来刻画图神经网络的可训练性和网络深度的关系。我们希望观察随着图神经网络的深度加深,其对应的GNTK会有怎样的行为。
◆ 理论结果
我们首先研究一个普通版本的图神经网络,其结构如下:
网络由L个传播单位,L也即网络的深度,其中一个单元由一次聚合操作和R次MLP所组成。
经过理论推导,我们获得了第一个定理:
这个定理告诉我们在数据图是连通的情况下,GNTK矩阵会随着深度的增加而趋于一个常数矩阵(矩阵中所有元素都是一样的值),而且这个收敛的速率是指数的。这意味着深度图网络会以一个恐怖的速度丢失可训练性,非常可怕。
接下来我们用理论框架进一步分析了带有残差连接的图网络结构,发现指数衰减无法避免,好消息是指数衰减的速度会比没有残差连接的结构要慢
具体而言就是二者对应的概率转移矩阵的第二大特征值会不一样。而第二个特征值和指数衰减因子息息相关。
我们最终通过数值模拟可以验证上述的定理:
其中,第二排第一个图表明了GNTK的指数衰减速率,第二排第二个图表明了残差连接相对会减缓衰减速度,但是其依然是指数衰减。
◆ Critical DropEdge
为了从根本上解决可训练性随着网络深度增加而出现指数衰减的问题,我们从理论推导中分析发现聚合操作在GNTK的递推中对应着概率转移矩阵。而概率转移矩阵就意味了马尔可夫过程,这就引发了后面的指数衰减。我们加入的残差连接会增大概率转移矩阵的第二大特征值从而减缓衰减速度。不过终归还是一个概率转移矩阵。所以我们需要换一个角度,从根本上破坏这个概率转移矩阵。
我们发现聚合操作对应成概率转移矩阵的一个必要条件是图是连通的。所以突破口就在于破坏图的连通性。刚好统计物理中有一个很好的模型告诉我们当一个随机图中边的连接概率取一个特定值的时候,那么图会展现出一种临界现象:从整体看,整张图依然具有连通性(存在一个和图尺寸有关的大集团),而同时信息在图上面传递的速率是多项式的。这样我们既不会因为连接概率过小而导致整个图变得支离破碎,进而没法学习图的有效信息,也不会连接概率过大,依然会存有指数性传播的缺点。最终我们提出来Critical DropEdge算法,一种连通性感知和图自适应的采样方法。
其实验效果由以下表格显示:
我们在点分类任务上测试,其中C-DropEdg是我们提出来的方法,GCN,DropEdge和DGN都是用于对比的方法。可以看出来在网络很深的时候,Critical DropEdge依然可以获得很高的表现。
来源:
https://www.toutiao.com/article/7089623667306414603/?log_from=e4f1dcdc8bb8_1651217350891
“IT大咖说”欢迎广大技术人员投稿,投稿邮箱:aliang@itdks.com
来都来了,走啥走,留个言呗~
IT大咖说 | 关于版权
由“IT大咖说(ID:itdakashuo)”原创的文章,转载时请注明作者、出处及微信公众号。投稿、约稿、转载请加微信:ITDKS10(备注:投稿),茉莉小姐姐会及时与您联系!
感谢您对IT大咖说的热心支持!
- 相关推荐
- 推荐文章
- 2 万字详解,彻底讲透 Elasticsearch
- 一款 IDEA 插件帮你优雅转化 DTO、VO、BO、PO、DO
- 「开源」数据同步ETL工具,支持多数据源间的增、删、改数据同步
- 如何使用 SSHGUARD 阻止 SSH 暴力攻击
- 实时时间序列异常检测
- [开源]一套BS架构,支持PC、H5端的开源知识管理系统、知识库系统
- 后端开发常见层式结构设计:跳表、时间轮、LSM-Tree
- 16 个有用的带宽监控工具来分析 Linux 中的网络使用情况
- Redis 中的过期删除策略和内存淘汰机制
- 一个可以测试并发数和运行次数的压力测试代码
- linux远程桌面管理工具xrdp