MetaAI 提出CRINGE损失方法,引入badcase 提升模型训练效果

2022-12-06 15:32:19 浏览数 (1)

引言

模型训练既要让模型知道该做什么,也要让模型知道不该做什么!目前绝大数的语言模型都是通过标注好的数据进行训练,并希望模型输出结果像标注的正例数据一样好,然而却忽略了负例数据的重要性。因为模型训练仍然需要少量的负面数据来提高模型效果。今天给大家分享的这篇文章就从这个角度出发,构建了一个新的损失函数将正例数据和负例数据融到一起进行模型训练。

背景介绍

 近年来,随着大型 Transformer 的兴起,语言模型和会话机器人的功能都变得更加强大,并且已经达到了日常交互的水平。然而,标准语言模型仍然会存在很多问题,比如不能识别用户意图,生成句子缺乏连贯性、而且也有可能生成句子具有偏见等。目前,越来越多的研究人员正在研究超出标准语言建模目标的模型训练方法,即通过将失败案例的信息纳入训练目标,进而提高模型效果。

 为此,在这项工作中,本文研究了这样一种方法CRINGE ,其中训练集包括一组给定的正例序列(通常用于语言模型训练)和一组负例序列(给出模型不应该生成的prompt的补全)

模型方法

CRINGE损失

 本文提出的CRINGE (contrast Iterative Negative GEneration)损失,「它是一种对同时包含正负序列的数据进行训练的方法」。对于正的例子,使用极大似然方法。通过将序列中的每个标记与语言模型的预测之一进行对比来训练反例,该方法易于实现与现有的方法相比性能良好。

 上图展示了单个负序列的CRINGE损失的概念示意图。其中,CRINGE 损失通过惩罚负样本的输出序列(以红色显示)来发挥作用。对于每个负输出标记,从语言模型中采样一个正预测以与之对比。负序列要么来自(i)人为注释,要么来自(ii)分类器(例如,从人为注释数据训练而来),该分类器可用于迭代地标记模型生成,并将 CRINGE 损失应用于这些示例。正序列使用常用的语言建模目标进行训练。

 更确切的说,最终优化目标由两个项组成:正序列的CrossEntropy项和负序列的CRINGE项。前者作为标准使用,例如,对于来自正序列

x

的Token

x_t

则有:

 对于负序列,将序列中的每个Token与一个正的Token进行对比。在训练过程中,通常会被提供一个负序列,但不知道序列中给定的负令牌是什么。因此,本文方法从模型的当前top-k预测中进行采样(如果它在top-k中,则省略负向标记,以便不选择相同的负向标记作为正例)。在这里,根据通过softmax在模型预测的top-k逻辑上构建的分类分布进行采样。因此,对比损失为:

 最后,为了训练正例和负例的序列数据,取两个损失的加权和如下所示:

「迭代训练」 本文提出的 CRINGE 损失函数使我们能够有效地在正例和负例上训练模型。这开辟了「通过从其自身生成的分类中学习并应用相同的损失来迭代改进模型」的可能性。我们遵循一个简单的策略,完成模型训练,在训练集上标记模型的生成,然后使用增强的训练集重复该过程。虽然模型生成标签可能通过人工审查在持续的人在环方法中获得,但我们提出在原始正面和负面示例上训练分类器,并将其用于自动标记示例,类似于在强化学习中使用奖励模型。因此,本文使用以下过程:

  • (1) 使用数据集 D 微调模型;
  • (2)根据原始训练示例上下文生成附加序列;
  • (3) 标记模型的生成(正面或负面)并将它们作为附加训练示例添加到数据集D;
  • (5) 使用更新的数据集重复该过程。 以上这种方法可以应用多轮。实验结果发现即使只应用了两次训练迭代,它也可以带来显着的性能提升。

实验结果

 在安全生成、矛盾规避和开放域对话任务的三个不同实验中展示了这种方法的有效性。我们的模型优于多个强大的基线,并且在概念上很简单,易于训练和实施。

1、测试安全生成和矛盾规避任务的性能。

2、开放域对话任务上的实验结果。

论文&&源码

Paper:https://arxiv.org/pdf/2211.05826.pdf

0 人点赞