作者:一元,原文摘自Google AI
The Deep Bootstrap Framework: Good Online Learners are Good Offline Generalizers(ICLR21)
理解深度学习的泛化性是一个尚未解决的基本问题。为什么在有限的训练数据集上优化模型能在一个hold-out的测试集中取得良好的性能?这一问题在机器学习中已经被研究了将近50多年。现在存在非常多数学工具可以帮助研究人员理解某些模型中的泛化性能。但是很不幸的是,大多数现有理论在应用于现代深网络时都失败了——它们在现实环境中既空洞又不具有预测性。
对于过长度参数化的模型,理论与实践之间的差距是巨大的,对于理论上有能力过拟合其训练集的模型,但在实践中却往往没有。
我们提出了一个新的框架,通过将泛化性与在线优化领域相结合来解决这个问题。在一个典型的设置中,一个模型在一组有限的样本上训练,这些样本被多个epoch重复使用。但在在线优化中,模型可以访问无限的样本流,并且可以在处理该样本流时进行迭代更新。
我们发现在无限数据上快速训练的模型与在有限数据上训练的模型具有相同的泛化能力。这种联系为实践中的设计选择带来了新的视角,并为从理论角度理解泛化奠定了路线图。
1. 关于Deep Bootstrap框架
Deep-Bootstrap框架的核心思想是将存在有限训练数据的现实世界与存在无限数据的“理想世界”进行比较。我们将其定义为:
- Real World(N,T): 在某个分布中的N个训练样本上训练模型,对于T个minibatch随机梯度下降(SGD)步,在多个epoch上重复使用相同的N个样本。这相当于在经验损失(训练数据损失)上运行SGD,属于监督学习中的标准训练过程。
- Ideal World(T): 在T步中训练相同的模型,但是在每个SGD步中使用来自分布的全新样本。也就是说,我们运行完全相同的训练代码(相同的优化器、学习速率、batch-size大小等),但在每个epoch中随机采样一个新的训练集,而不是重用样本。在这个理想的世界环境中,有一个有效的无限“训练集”,训练误差和测试误差之间没有区别。
在先验上,人们可能会认为现实世界和理想世界彼此无关,因为在现实世界中,模型看到的是有限数量的分布示例,而在理想世界中,模型看到的是整个分布。但在实际应用中,我们发现真实模型和理想模型实际上存在着相似的检验误差。
为了量化这一观察结果,我们通过创建一个新的数据集(我们称之为CIFAR-5m)来模拟一个理想的世界环境。我们在CIFAR-10上训练了一个生成模型,然后用它生成了约600万张图像。选择数据集的规模是为了确保从模型的角度来看它“实际上是无限的”,这样模型就不会对相同的数据进行重采样。也就是说,在理想世界中,模型看到的是一组全新的样本。
下图显示了几种模型的测试误差,比较了它们在真实环境(即重复使用的数据)和理想环境(“新”数据)中接受CIFAR-5m数据训练时的性能。蓝色实线显示了现实世界中的ResNet模型,该模型使用标准CIFAR-10超参数在50K样本上训练100个epoch。蓝色虚线显示了理想世界中的相应模型,在一次过程中对500万个样本进行了训练。令人惊讶的是,这些世界有着非常相似的测试错误——在某种意义上,模型“不在乎”它看到的是重复使用的样本还是新的样本。
这也适用于其它的架构,例如MLP、Vision Transformer,以及架构、优化器、数据分布和样本大小的许多其他设置。这些实验为泛化提供了一个新的视角:快速优化(在无限数据上)和良好的泛化(在有限数据上)模型。例如,ResNet模型比MLP模型在有限数据上的泛化效果更好,但这“是因为”即使在无限数据上,它的优化速度也更快。
2. 从优化行为理解模型的泛化性
我们核心的观察结果是,真实世界和理想世界的模型在测试误差中始终保持接近,直到真实世界收敛(<1%的训练误差)。因此,人们可以通过研究模型在理想世界中的相应行为来研究现实世界中的模型。
这也意味着模型的泛化可以从两个框架下的优化性能来理解:
- 在线优化:理想世界测试误差减少的速度有多快;
- 离线优化:真实世界的训练误差收敛速度有多快;
因此,为了研究泛化,我们可以等价地研究上述两个术语,这在概念上可能更简单,因为它们只涉及优化问题。基于这一观察,好的模型和训练过程是:
- 在理想世界中快速优化;
- 在现实世界中不会太快地优化模型;
深度学习中的所有设计选择都可以通过它们对这两个terms的影响来看待。例如,一些进展,如卷积,skpi连接和预训练主要通过加速理想世界的优化来进行帮助,而其它的进步,如正则化和数据增强,则主要通过减速现实世界的优化来帮助。
3. Applying the Deep Bootstrap Framework
研究人员可以使用Deep-Bootstrap框架来研究和指导深度学习中的设计选择。其原理是:当一个人在现实世界中做出影响泛化的改变(结构、学习率等),他应该考虑它对:
- 测试误差的理想世界优化(越快越好);
- 训练误差的现实世界优化(越慢越好)。
例如,在实践中经常使用预训练来帮助在小数据区域中的模型泛化。然而,人们对预训练能带来帮助的原因仍知之甚少。
我们可以使用Deep Bootstrap框架来研究这一点,方法是观察上述(1)和(2)项的预训练效果。我们发现预训练的主要效果是改善理想世界的优化,
- 预训练使网络成为在线优化的“快速学习者”。
因此,在理想世界中,预训练模型的改进泛化几乎被其改进的优化所准确捕获。下图显示了在CIFAR-10上训练的Vision-Transformers (ViT)的情况,比较了从头开始的训练和在ImageNet上的预训练。
我们也可以使用这个框架来研究数据扩充。理想世界中的数据扩充对应于将每个新样本都会扩充一次,而不是将同一样本扩充多次。这个框架意味着良好的数据扩充是指
- 不会显著损害理想世界优化(即,扩充样本看起来不会太“偏离分布”);
- 抑制真实世界优化速度(因此真实世界需要更长时间来适应其训练集合)。
数据扩充的主要好处是通过第二项,延长了实际优化时间。至于第一项,一些激进的数据扩充(混合/剪切)实际上会损害理想世界,但这种效果与第二项相形见绌。
4. 小结
Deep-Bootstrap框架为深度学习中的泛化现象和经验现象提供了一个新的视角。希望它可以应用到理解未来深度学习的其它方面。特别有趣的是,泛化可以通过纯粹的优化考虑来表征,这与理论上许多流行的方法形成了对比。最关键的是,我们可以同时考虑在线和离线优化,它们都不是非常充分,但它们共同决定了泛化。
Deep-Bootstrap框架还可以解释为什么Deep-learning对许多设计选择相当鲁棒:许多类型的体系结构、损失函数、优化器、规范化和激活函数都可以很好地泛化。这个框架提出了一个统一的原则:从本质上讲,任何在在线优化环境下运行良好的选择,也会在离线环境下得到很好的泛化。
最后,现代神经网络既可以参数化过度(例如,针对小数据任务训练的大型网络),也可能参数化不足(例如,OpenAI的GPT-3、Google的T5或Facebook的ResNeXt WSL等等)。Deep-Bootstrap框架则表明在线优化是两种模式成功的关键因素。
参考文献
- A New Lens on Understanding Generalization in Deep Learning
- The Deep Bootstrap Framework: Good Online Learners are Good Offline Generalizers(ICLR21)https://arxiv.org/pdf/2010.08127.pdf