深度学习中的样本遗忘问题 (ICLR-2019)

2022-03-28 18:26:12 浏览数 (2)

  • 标题:An Empirical Study of Example Forgetting during Deep Neural Network Learning
  • 会议:ICLR-2019
  • 机构:CMU,MSR,MILA

一句话总结: 学了忘,忘了学。你我如此,神经网络也如此。在深度模型训练过程中,可能发生了大量的、反复的样本遗忘现象。

本论文的写法很特别,跟大家常读的八股文不同,本文更像一个实验报告,标题也说了是一个Empirical Study,我觉得是一个很好的写Empirical Study的范本,值得收藏。

一、跟“灾难性遗忘”的关系

灾难性遗忘(catastrophic forgetting),是一个在深度学习中常被提起的概念,也是lifelong learning, continual learning中研究的主要问题之一。

灾难性遗忘,描述的是在一个任务上训练出来的模型,如果在一个新任务上进行训练,就会大大降低原任务上的泛化性能,即之前的知识被严重遗忘了。 在论文Attention-Based Selective Plasticity中的一幅图很形象地描述了这个概念:

灾难性遗忘,来源:论文Attention-Based Selective Plasticity

而本文提出的样本遗忘(example forgetting),则是受到灾难性遗忘现象的启发而提出的,即在同一个任务的训练过程中,也可能会有遗忘现象,一个样本可能在训练过程中反复地学了忘,忘了学。

实际上,如果我们把任务的概念放宽,那么我每一个mini-batch都可以看做一个小task,所以这里的example forgetting,就是更微观视角的catastrophic forgetting.

二、概念定义

1. Forgetting & Learning events

当一个样本本来预测对的,现在预测错了,就是一次forgetting event;相反的就是learning event.

我们会初始化一开始每个样本的预测都是不对的,但是在经过训练后(比如一个batch之后)进行上述的检查。

2. Classification margin

分类边际,被定义为:正确的类别对应的logit,跟其他类别中最大的logit的差。

3. Forgettable & Unforgettable examples

  • 被遗忘至少一次的,就叫forgettable example
  • 在某时刻被学习到了,然后从此就没有被遗忘过的样本,就叫unforgettable example
  • 从未被学习到的(即自始至终都预测是错的),不能算作unforgettable(但是,自始至终预测都是对的,就算)

三、实验设置&统计流程:

统计算法

上述统计算法更加清晰地告诉我们本文是如何进行对forgetting events进行统计的,即我们是在每个batch训练完之后统计一次。

本文使用了三个数据集:MNIST, permuted-MNIST(MNIST的像素重排版)和CIFAR-10,这三个数据集的学习难度是递增的。

四、☆ 实验观察

这一部分就是本论文的主要部分了,没有太多的理论,主要就是通过一系列的实验来向我们展示训练过程中发生了什么,但真的都挺有意思的,能给人带来很多启发和思考。

1. 遗忘次数的统计

number of forgetting events

从上图可以看出,随着数据集的复杂度和多样性(complexity & diversity)的增加,样本遗忘的情况越来越多。简单的数据集,有大量的unforgettable examples. 作者统计如下:

dataset

# unforgettable examples

MNIST

91.7%

permuted-MNIST

75.3%

CIFAR-10

31.3%

另外,有些样本遗忘,可能是随机发生的,就是模型自己随便更新都可能造成遗忘,所以作者们专门做了一个统计,让模型用随机的梯度来更新,看看遗忘的情况:

forgetting by chance

可见随机遗忘的分布,跟真实遗忘的分布还是有很大差别的,而且随机遗忘的次数会很少,一般在2次以内。

2. 何时被第一次学到

一个样本究竟出现几次才会被模型学到?这是一个很有意思的问题,作者分别对unforgettable和forgettable的样本进行了统计:

first learning event

从上图可以发现,大部分的样本,在出现5次以内就可以被学习到。相比而言,unforgettable样本更早被学到。

3. 遗忘次数跟misclassification margin的关系

前面定义了classification margin,而misclassification margin这里定义为一个样本在所有forgetting events中的平均classification margin,所以这个的绝对值越大,就代表分类的模糊程度越大。

misclassification margin

上图是一个2D的直方图,代表了所有样本是如何分布的。总体上看,forgetting次数多的样本,其misclassification margin也很大。

4. 发现噪音样本

我们很自然可以想到,能否利用遗忘次数,来判断一个样本是否是噪音(标签错误)呢?作者从数据集中随机挑选了20%的样本改变其标签,然后做了如下统计:

noise detection

发现,噪音样本跟正常样本在遗忘次数上,分布十分不一样,遗忘次数会显著多于正常样本。因此我们可以利用这个特点,来帮助我们对数据集去噪,例如最近的文章DataCLUE: A Benchmark Suite for Data-centric NLP中就使用了这种方法。

上面展示的是label noise的结果,作者在附录部分还附上了对input添加noise的实验,也挺有意思的:

pixel noise

发现,对样本(图片)添加的noise越大,这个forgetting的统计就越接近一个正态分布,这也一定程度上反映了分类任务越难,样本遗忘的情况就越严重。

5. 微观视角的灾难性遗忘

这是一个很有意思的实验。

上面的很多分析都验证了神经网络确实会有遗忘,即使在同一个任务的训练中。为了跟经典的灾难性遗忘进行对照,作者仿照经典的continual learning的实验方法来设计了实验:将样本分两批,使用模型依次进行训练,并记录模型在两批样本上的分类准确率

continual learning

上图最左边,是使用一个数据集中随机挑选的两部分来轮流训练。我们发现,即使两个task都来自同一个数据分布,灾难性遗忘也可能发生!模型太健忘了。

右边的两个图,则是使用unforgettable和forgettable样本作为两个数据集来依次训练,可以发现两个结论:

  • 在容易遗忘的样本上训练完之后,再去难忘的样本上训练,灾难性遗忘很严重(刚刚把易遗忘的样本学会,就一下子忘记了)
  • 在难忘的样本上训练完之后,再去易遗忘的样本上训练,灾难性遗忘的现象很轻微。

6. 我们可以丢掉很多样本,还能保持泛化性能

在上面的实验我们可以看出,学习forgettable examples对于unforgettable examples上的泛化性能似乎影响不大,而反过来就影响很大。借助开头的那个图来理解一下:

这意味着forgettable examples的分布能够比较好地涵盖unforgettable examples的分布,这样才会使得学习新的样本对原来的decision boundary不会有太大改变。

所以,从这个角度看的话,forgettable examples比unforgettable examples蕴含了更多的信息,样本在训练中被遗忘的次数越多,它对分类任务的作用可能越大

因此,我们可以大胆假设,是不是我把unforgettable examples丢掉一大批,都不会怎么影响模型的性能呢?作者做了如下的样本丢弃实验:

removing unforgettable examples

左图中的绿线和蓝线分别代表按照被遗忘次数排序的样本和随机排列的样本,不断增大丢弃比例后的结果。可以发现,在CIFAR-10数据集中,我们可以把前35%最少遗忘的样本丢掉,只损失0.2%的准确率

右图则是同样去除5000个样本,但是改变这5000个样本中平均被遗忘次数。可以发现,大体上,包含的forgettable examples越多,效果越差。但是有意思的是存在一个明显的拐点,当forgettable examples达到一定比例时,效果又会抬升一点。作者解释,这说明数据集中可能存在某些异常点或者错误标注的样本(outliers or mislabeled examples),把他们去掉了对模型有好处,但这些样本往往被遗忘次数也很多。

removing unforgettable examples

这个图则对比了三个数据集,对比发现MNIST,permuted-MNIST和CIFAR-10可以分别移除高达80%,50%,30%的训练样本且几乎不影响性能。

7. 样本遗忘现象的稳定性

我们肯定还会关心这种样本遗忘现象,换了随机种子,换了模型,结果会不会差别很大,还是说,(不)容易遗忘的样本,换了模型和种子都依然(不)容易遗忘?

作者对此都做了实验探究,首先,使用了10个不同seed,对所有样本的number of forgetting进行统计,然后彼此之间计算排序的Pearson相关系数,发现高达89.2%,所以不同seed下,样本的遗忘现象是十分类似的。

然后,作者探究了在不同的训练阶段(不同的epoch时)的遗忘情况的差异,见下图最左边,实验表明,训练到75轮以后,样本遗忘的情况就基本稳定了。

(中间那个图我看不懂,就不讲了)

最右边那个图,是使用ResNet18来对forgetting events进行统计,然后使用这个统计结果,不断删减训练样本,在更大型的模型WideResNet上进行训练的结果,发现依然可以删除30%的数据还能保持性能基本不变,这说明,我们可以使用轻量的模型进行遗忘现象的统计,来辅助重型模型的训练。

总之,你如何训练(超参数、模型架构等)对遗忘现象的统计结果的影响不大,遗忘现象反映的是数据集本身的特点。

五、总结& 思考

写作上:

读到这里,我们应该可以发现,这就是一个对模型训练过程中的一些现象进行了一系列简单的统计,并没有什么技术含量,但是读完的感觉,却让我们大呼过瘾,原来深度学习这个黑箱子里还发生了这么多有趣的事情!

这篇文章,让我看到了搞深度学习的科研的另一种可能,我们不一定要设计复杂的模型,要提出什么深刻的数学理论才能做出好的研究,像本文这种对模型的行为的观察、对数据集特点的分析,也可以做出好的研究,并给后续的研究者提供很多经验和思考。

本文虽然像一个实验报告,使用的统计手段也很简单,但是本文设计实验的方法、如何从各种角度去对一个现象进行观测,是很值得我们学习的。

样本遗忘带来的启发

样本遗忘,以及灾难性遗忘,告诉我们神经网络本身存在一定缺陷,没法将学到的知识进行比较好的保留,知识很容易被覆盖。这明显跟人类学习过程不太一样,新的知识一般不会对曾经学过的知识进行巨大冲击,而是融合。所以这对于我们设计神经网络,设计训练方法,应该有很大启示,在continual learning领域应该已经有丰富的工作来试图解决这方面问题。

另外,样本遗忘现象本身,也可以帮助我们认识数据集,这对于Data-centric AI领域的研究应该也有很大帮助。


如果觉得有所收获的话

大家就点一个吧 :)

2022年的第3/52篇原创笔记 和我一起挖掘有趣的AI研究吧!

0 人点赞