监督学习是机器学习 (ML) 的一种流行方法,其中使用已针对手头任务进行适当标记的数据来训练模型。普通监督学习训练独立同分布(IID)。
所有的训练样本都来自一组固定的类。该模型可以在整个训练阶段访问它们。另一方面,连续学习通过依次呈现不同的分类任务来解决在变化的数据分布上训练单个模型的问题。这对于自治代理处理和解释现实世界场景中的连续信息流尤其重要。
考虑两个任务来展示监督学习和持续学习之间的区别:
(1)对猫和狗进行分类
(2)对熊猫和考拉进行分类。该模型从两个任务中获得训练数据,并将其视为监督学习中的单个 4 类分类问题,它采用 IID。但是,在持续学习中,这两个任务是按顺序呈现的,模型只能访问当前任务的训练数据。因此,此类模型在以前的任务上容易出现性能下降,称为灾难性遗忘。
主流解决方案通过将以前的数据存储在“预演缓冲区”中并将其与当前数据相结合来训练模型来解决灾难性遗忘问题。
但是,这些解决方案的性能在很大程度上取决于缓冲区大小,并且在某些情况下,由于数据隐私问题,这可能是不可能的。另一条工作线创建特定于任务的组件,以避免干扰其他任务。然而,这些方法经常假设测试时的任务是已知的,但情况并非总是如此,并且它们需要大量参数。这些方法的局限性提出了终身学习的基本问题。除了简单地缓冲以前的数据之外,是否有可能拥有一个更高效、更紧凑的内存系统?是否可以在不知道任务身份的情况下为随机样本选择相关的知识组件?
“Learning to Prompt”是一种新颖的持续学习框架,灵感来自自然语言处理提示技术(L2P)。不是重新学习每个顺序任务的所有模型权重,而是提供可学习的任务相关“指令”,即提示,以使用可学习的提示参数池引导预训练的主干模型通过顺序训练。L2P 适用于各种具有挑战性的持续学习设置,并且在所有基准测试中始终优于以前的最先进方法。它在性能方面优于基于排练的方法,同时内存效率也更高。最重要的是,L2P最先在持续学习的背景下提出了提示的概念。
与使用预演缓冲区按顺序使整个或部分模型权重适应任务的传统方法相比,L2P 使用单个冻结的主干模型并学习提示池来有条件地指示模型。术语“模型 0”表示主干模型在开始时是固定的。
“基于提示的学习”使用给定预训练 Transformer 模型的固定模板修改原始输入。假设给情绪分析任务“我喜欢这只猫”的信息。基于提示的方法会将输入更改为“我喜欢这只猫。看起来X”,其中“X”是要预测的空槽(例如,“nice”、“可爱”等),“看起来X”是所谓的提示。在输入中添加提示可以调节预训练模型以解决许多下游任务。在迁移学习设置下,提示调优会在输入嵌入之前添加一组可学习的提示,以指示预训练的主干学习单个下游任务,而设计固定提示则需要先验知识和反复试验。
L2P 在持续学习场景中维护了一个可学习的提示池,其中提示可以灵活地分组为子集以协同工作。每个提示都与通过减少匹配输入查询特征之间的余弦相似度损失而发现的键相关联。然后,查询函数使用这些键根据输入特征动态查找任务相关提示的子集。查询函数在测试时将输入映射到提示池中最接近的前 N 个键,然后将相关的提示嵌入馈送到模型的其余部分以生成输出预测。在训练期间使用交叉熵损失来优化快速池和分类头。
直观地说,相似的输入示例倾向于选择相似的提示集,反之亦然。因此,经常共享的提示编码更多通用知识,而其他提示编码更多特定于任务的知识。此外,提示存储高级指令,同时冻结低级预训练表示,即使没有排练缓冲区也能减少灾难性遗忘。实例查询机制消除了了解任务身份或边界的需要,允许这种方法解决任务不可知的持续学习的研究不足的问题。
在具有代表性的基准上,使用 ImageNet 预训练的视觉转换器 (ViT) 在各种基线方法中评估了 L2P 的有效性。简单基线,在下图中称为 Sequential,是指在所有任务上按顺序训练单个模型。EWC 模型包含一个正则化项以减少遗忘,而 Rehearsal 模型将先前的示例存储在缓冲区中,以便与当前数据进行混合训练。准确度和平均差异是在训练期间达到的最佳准确度与所有任务的最终准确度之间测量的,以评估整体持续学习性能,称为遗忘。L2P 在这两个指标上都优于 Sequential 和 EWC 方法。
值得注意的是,L2P 优于 Rehearsal 方法,后者采用额外的缓冲区来保存以前的数据。因为 L2P 方法与 Rehearsal 是正交的,所以如果它也使用排练缓冲区,它的性能可能会得到更大的提高。在准确性和遗忘方面,L2P 优于基线方法。准确率是所有任务的平均准确率,而遗忘是训练期间达到的最佳准确率与所有任务的最终准确率之间的平均差。
提示选择结果是根据实例查询策略在两个不同的基准上绘制的,一个具有相似的任务,另一个具有混合任务。根据研究结果,L2P 通过使用更多共享提示来鼓励相似任务之间更多的知识共享,而通过使用更多特定于任务的提示来鼓励不同任务之间的知识共享更少。
L2P 是一种解决持续学习中关键挑战的新方法。L2P 不需要预演缓冲区或已知的任务标识即可在测试时实现高性能。此外,它可以处理各种复杂的持续学习场景,包括有问题的与任务无关的设置。
论文:
https://arxiv.org/pdf/2112.08654.pdf
Github:
https://github.com/google-research/l2p
来源:
https://ai.googleblog.com/2022/04/learning-to-prompt-for-continual.html