神经网络的多任务学习方法,避免灾难性遗忘

2021-09-15 14:14:33 浏览数 (1)

神经网络非常擅长学习一件事。无论是下棋还是折叠蛋白质,只要有足够的数据和时间,神经网络都能取得惊人的效果。不幸的是,网络目前无法擅长一项以上的任务。你可以训练一个网络擅长某件事,但是一旦你试图教给网络其他东西,它就会忘记它在第一个任务中学到的东西。这被称为灾难性遗忘(catastrophic forgetting)。由于智能的标志之一是学习和存储多项任务,因此如何在多项任务上训练神经网络(并解决灾难性遗忘)的问题极为重要。

有一种想法是改变训练数据,使每项任务的训练样本相互交错。例如,假设我们有 3 个任务 A、B 和 C,示例分别标记为 a_i、b_i 和 c_i。这个想法将训练集排序为 a_1、b_1、c_1、a_2、b_2、c_2、a_3 等。重点是同时训练所有三个任务并同等重视,希望网络权重以这种方式学习将包含有关所有三个任务的信息。在实践中,这个想法是有效的——网络会同时在所有三个任务上慢慢变得更好,避免灾难性的遗忘。

但是,在我看来,这种方法是作弊的。真正的智能学习(比如人类的寻恶习方式)不需要以这种方式交错任务。事实上,通常任务是在大时段中顺序学习的(我们称之为“时段排程”)。典型的学校是这样安排的:首先你学习 1 小时数学,然后是 1 小时英语,然后再学习 1 小时历史,依此类推。而不是做数学题,写一篇文章,然后读一本历史书,这样的分裂时间。所以问题是:有没有办法让神经网络通过时段排程来学习多个任务?

2016 年,Deepmind 的研究人员发表了一篇解决这个问题的论文。我特别喜欢他们的方法,因为它并不复杂。他们真正做的就是对网络应用一种特殊类型的正则化。让我们仔细看看。

假设我们有两个任务,A 和 B。时段排程是首先用许多示例训练 A,然后我们切换到训练 B。Deepmind 研究人员建议先正常训练 A(即常规梯度下降/反向传播)。然后在 B 块期间我们保留从 A 学到的权重并继续梯度下降。唯一的区别是我们现在包含一个对每个权重都是唯一的二次正则化项。这个想法是使用这种按权重正则化来惩罚远离从 A 学习的权重。被认为对 A 更重要的权重将受到更重的惩罚。在数学上,我们在 B 训练时段排程的成本函数是 L(θ) = L_B(θ) Σ(k_i * (θAi — θi)²),其中 i 是所有网络权重的索引,θ_Ai 是权重A训练块完成后。L_B(θ) 是 B 的正常成本函数,可能是平方误差或对数损失。最后,k_i 是权重 i 对于预测 A 的重要性。

还有一种更直观的方式来考虑每个权重的正则化。想象一个物理弹簧,当拉动弹簧时,拉得越远弹簧拉回的力就越大。此外一些弹簧比其他弹簧更坚固。这与我们的算法有什么关系?你可以想象一个弹簧连接到神经网络中的每个权重。所有弹簧的相对强度是正则化的。某些弹簧(对 A 很重要)会非常强,所以在 B 的训练过程中,算法将不鼓励拉那些强弹簧,相应的权重不会有太大变化。因此,该算法将改为拉动较弱的弹簧,并且与这些弹簧相对应的权重将发生更多变化。

考虑这个算法的另一种方式是它是对 L2 正则化的改进。使用 L2 正则化,不鼓励权重改变太多,惩罚对应于权重的平方和。在 L2 正则化中,所有的权重都受到同等的惩罚。在我们的算法中,只有重要的权重被阻止改变。

好的——我们现在直观地理解了这个算法是如何工作的。通过保持 A 的重要权重相对恒定,我们可以在 A 上保持性能,同时在 B 上成功训练。但是,我们仍然没有解释如何确定 A 的“重要”权重。那么让我们问一个问题:什么使权重重要?一个合理的答案可能是:如果权重比其他权重对最终预测的影响更大,那么它就很重要。更具体地说,如果权重相对于最终预测的导数比其他权重导数具有更高的幅度,我们可以说权重很重要。不过,我们遗漏了一件事——因为神经网络中的权重会影响其他权重,它们的导数相互关联。换句话说,我们不能只考虑给定权重的导数;我们需要查看所有权重导数的协方差矩阵。一个更正式的版本被称为 Fisher 信息矩阵,这是研究人员最终使用的。

所以,总结一下:首先我们正常训练 A,然后我们用每个权重的二次正则化训练 B。这些按权重的正则化取决于权重对 A 的相对重要性,这可以通过 Fisher 信息矩阵找到。结果是一个对 A 和 B 都适用的神经网络。研究人员给这种方法起的名字是“弹性权重合并”(EWC)。

所有这些理论讨论都很棒,但它真的有效吗?研究人员在监督学习和强化学习环境中测试了这种新方法以找出答案。首先,为了测试监督学习,研究人员使用了流行的 MNIST 任务。为了从 MNIST 创建多个任务,他们获取输入的 MNIST 图像并通过几个固定常量对它们进行排列。这个想法是在一个排列(MNIST 图像 const1)上训练的分类器不会在另一个(MNIST 图像 const2)上工作,所以实际上,我们有不同的任务。研究人员随后比较了 EWC、规则梯度下降和 L2 正则化的结果。在第一个任务(MNIST 图像的第一次排列)中,所有三种方法都具有可比性。随着越来越多的任务被引入(其他排列),EWC 远远优于其他任务。

强化学习也有类似的结果。在这里,实验是学习十种不同的 Atari 游戏。比赛的顺序是随机的。同样,对于第一场比赛,代理使用 EWC 获得了与使用基线 (DQN) 相似的性能。随着更多游戏的推出,EWC 的表现也类似在监督学习的表现。

0 人点赞