优于GNN嵌入基线,阿尔伯塔大学等用RL做图关系推理:关系预测任务新SOTA

2022-02-24 08:20:16 浏览数 (1)

机器之心报道

机器之心编辑部

具备强系统性、对噪声数据具有稳健性,阿尔伯塔大学和蒙特利尔大学 MILA 研究所联合推出了一个基于 RL 的图关系推理框架,并在多个数据集上实现了 SOTA。

智能的一个重要组成部分是推理,即观察数据中不同事物之间的关系,并归纳总结出这些关系之间的推理规则,以进行可解释和可泛化的逻辑推理 。

本文中,阿尔伯塔大学和蒙特利尔大学 MILA 研究所的研究者提出了一种基于强化学习的图关系推理框架 R5,可以从观察到的数据中明确挖掘潜在的组合逻辑规则。该方法将图数据抽象为关系与关系之间的组合,利用配备蒙特卡罗树搜索的策略价值网络执行循环关系预测,并利用回溯重写机制来挖掘显式的规则。R5 在关系预测任务上优于各种基于图神经网络嵌入和规则归纳的基线,同时在发现事实规则方面实现了高召回率

论文链接:https://openreview.net/forum?id=2eXhNpHeW6E

系统性(Systematicity),即重组已知部分和规则以形成新序列同时对关系数据进行推理的能力,对于机器智能至关重要。具有强系统性的模型能够训练小规模任务并推广到大规模任务。作为一种基于强化学习的关系推理框架,R5 对关系图数据进行推理,并从观察中明确挖掘潜在的组合逻辑规则。

R5 系统性强,对噪声数据具有稳健性。它由配备蒙特卡罗树搜索(MCTS)的策略价值网络执行循环关系预测,并由用于规则挖掘的回溯重写(Backtrack Rewritting)机制组成。通过交替地应用这两个组件,R5 逐步从数据中学习一组显式规则,并执行可解释和可概括的关系预测。

研究者对多个数据集进行了广泛的评估。实验结果表明,R5 在关系预测任务上优于各种基于嵌入和规则归纳的基线,同时在发现基本事实规则(ground truth rules)方面实现了高召回率。论文已被 ICLR 2022 接收为 Spotlight,相关代码即将开源

引言

本文研究的问题是关系预测,如下图 c 所示,给定一个关系图 G,和一对 queried nodes q=(Mary, Ann),研究者需要预测 Mary 和 Ann 之间的关系 r。

本文所关注的系统性(Systematicity)要求具有强系统性的模型能够训练小规模任务并推广到大规模任务,如图所示,仅在 a&b 上训练的强系统性模型,应当能够解决 c 中的问题。举例来说,从 a 和 b 中可分别习得规则 (1) Mother (X,Y) ← Mother (X,Z), Sister (Z,Y),和规则 (2) Grandma (X,Y) ←Mother (X,Z), Father (Z,Y)。则在 c 中,我们能根据规则 (1) 得到 Mary 是 Peter 的 Mother,然后根据规则 (2) 得到 Mary 是 Ann 的 Grandma。

方法

本文提出了一个新颖的图推理(graph reasoning)框架 R5,即 Rule discovery with Reinforced and Recurrent Relational Reasoning。R5 将关系推理想象成顺序决策问题,并通过配备动态规则记忆(dynamic memory)的深度强化学习进行规则提取和逻辑推理。

更具体地说,R5 学习的是形式为 u ← pi ∧ pj 的短定句(short definite clause)。由于长 Horn 子句可以分解为短 Horn 子句,因此预测任务中需要用到的长确定子句(long definite clause,如 Figure 1c)“outcome ←p0 ∧ p1...pi ∧ pj ...pk” 可以在执行决策过程中用一系列短确定子句表示,比如 pi ∧ pj 可以被 u 代替。

下图为 R5 的基本框架。本文假设关系推理与 node identity 无关,因此对于每张输入图,本文穷举或采样出其中的关系路径,且不再考虑 nodes,接下来的框架只在采样出来的这些路径上做文章。模型可分为两部分(1)由 MCTS 和策略价值网络组成的深度强化学习模型;(2)可重写的动态规则记忆。

循环关系预测

策略价值网络以当前状态 s 作为输入,输出动作概率分布 ρ 和一个状态值 ν。MCTS 利用策略网络 fθ(s) 来指导其模拟,并输出一个向量 π,表示可用操作的改进搜索概率。然后训练策略网络 (ρ, ν) = fθ(s) 以最小化预测状态值 ν 与情节后收到的奖励 z 之间的误差,并最大化两个概率分布 ρ 和 π 之间的相似性。损失函数为

  • State:本文将策略网络中的 state 想象为关系与关系之间的组合(relation pair)。如果数据中有 m 种关系,和 n 种预先定义并分配的 unknown/inverted relation,则它们两两结合总共可以有 (m n)*(m n) 种组合,如上图的 “Relation pairs to state”。接下来考虑为刚才采样出的 paths 中的所有相连的关系组合(比如上图 Path 1 中的 (r1,r2) 或 (r2,r5))填入相关的特征,不在 paths 中的组合所有特征将被设置为 0。特征可包括:该关系组合最常出现在 path 的第几位、出现了几次、这个关系组合是否已经存在于动态记忆里等。本文共采用了 8 个特征,具体可参照下图 4 中的例子和论文原文;
  • Action:本文的 action 是在 path 中合并相连的两个关系,并将其合并为一个新的关系。比如在上图中,第一步 action 为将 (r1,r2) 合并为 r3,第二步为将 (r5,r3) 合并为 r21。迭代地执行这样的操作,直至 query (X,Y) 之间得到一个单独的关系,则将其作为对 (X,Y) 之间的关系预测;
  • Reward:reward 在 episode 结束时计算,如果准确预测了 (X,Y) 的关系则 reward=1,没有预测出合理的关系则 reward=0,如果预测错了则 reward=-1。(详见论文原文)。

可重写的动态规则记忆

如上图 2 所示,每次执行 action 的时候,策略网络和 MCTS 只能给出规则左边的关系组合,却无法给出要合并成什么关系,因此本文引入了动态规则记忆(Dynamic rule memory)。每采取一个 action,模型都会以所选择的关系组合为 key,在记忆中查找要合并的关系。如果关系组合不在记忆中,则合并为随机一个未被占用的 invented 关系,并将其作为规则记录到记忆区。

在训练中,当一个 episode 结束时,检查最终得到的关系是否和数据中的答案一致,若不一致,则需要根据答案更新记忆区,更新方法如下算法 1 所示。举例来说,如果模型给出要预测的关系是一个 invented 关系 u,而实际训练数据指向一个已知的关系 r6,则记忆区中的所有 u 都会被替换为 r6,下次再查找的时候,模型就知道这条规则的关系组合应该被合并为 r6。

但只这样操作的话,预先分配的 invented 关系很快就会被占满,我们需要一个方式把用不到的 invented 关系合理释放,因此本文又引入了规则打分机制,并将累计分值存储在记忆区,当 invented 关系被占满的时候,就从分值低的开始释放。打分标准可包括:是否已经在记忆里、有没有成功帮助 query 关系的预测等,详见论文原文。

实验

在 CLUTRR 和 GraphLog 数据集上,R5 均取得了 SOTA 的表现,其中 CLUTRR 的数据较为干净,而 GraphLog 则包含较多错误数据。可以看到,R5 可以在非常小的数据上进行训练(每张图仅包含 2~4 个 node),却可以将学习到的规则泛化到较大的图上,它们的推理路径可能长达 15 步。值得一提的是,GraphLog 数据集给出了生成每个小数据集所用到的规则,虽然在训练时不使用这些信息,但可以用其验证模型对规则提取的召回率,R5 在大部分实验的数据集上可达到接近 100% 的召回率。

本文还做了一些关于策略价值网络,和 invented 关系的消融实验。实验结果表明,这两个模块对 R5 都是必要的,策略价值网络保证准确性,invented 关系加速收敛。

© THE END

转载请联系本公众号获得授权

0 人点赞