论文标题:SSMix: Saliency-Based Span Mixup for Text Classification
论文链接:https://arxiv.org/pdf/2106.08062.pdf
论文代码:https://github.com/clovaai/ssmix
论文作者:{soyoungyoon etc.}
论文摘要
数据增强已证明对各种计算机视觉任务是有效的。尽管文本取得了巨大的成功,但由于文本由可变长度的离散标记组成,因此将混合应用于NLP任务一直存在障碍。在这项工作中,我们提出了SSMix,一种新的混合方法,其中操作是对输入文本执行的,而不是像以前的方法那样对隐藏向量执行的。SSMix通过基于跨度的混合,综合一个句子,同时保留两个原始文本的位置,并依赖于显著性信息保留更多与预测相关的标记。通过大量的实验,我们实证验证了我们的方法在广泛的文本分类基准上优于隐藏级混合方法,包括文本隐含、情感分类和问题类型分类。
数据增强的效果已经在各种计算机视觉任务中被证实是有效的。尽管数据增强非常有效,由于文本是由变长的离散字符组成的,所以将mixup应用与NLP任务一直存在障碍。在本篇论文,作者提出了SSMix算法,一种针对输入文本增强的mixup算法,而非之前针对隐藏向量的方法。SSMix通过跨度混合( span-based mixing)在保留原始两个文本的条件下合成一个句子,同时保留两个原始文本的位置,并依赖于显著性信息保留更多与预测相关的标记。通过大量的实验,论文验证了该算法在广泛的文本分类基准上优于隐藏级混合方法,包括文本推断、情感分类和问题类型分类任务。
算法简介
由于数据收集与标志的昂贵成本,数据增强在自然语言处理(NLP)中越来越重要。其中一些已往研究包括基于简单的规则和模型来生成类似的文本。比如通过标准方法或先进的训练方法与原始样本联合进行训练,也有基于混淆(mixup)插值文本和标签进行增强。
Mixup及其变体训练算法成为计算机视觉中常用的正则化方法,用来提高神经网络的泛化能力。混合方法分为输入级混合和隐藏级混合( hidden-level mixup),两者取决于混合操作的位置。输入级混合是一种比隐藏级混合更普遍的方法,因为它的简单性和能够捕获局部性,从而具有更好的准确性。
由于文本数据的离散性和可变的序列长度,在NLP中应用mixup比在计算机视觉中更具有挑战性和难度。因此,之前大多数关于文本混合的尝试将mixup应用于嵌入向量,如嵌入或中间表示。然而根据计算机视觉的增强直观感受,输入级混合一般比隐藏级混合有优势。这一动机鼓励作者对探究文本数据的输入级混淆方法。
在这项工作中,作者提出了SSMix(图1),一种新的输入级结合跨度(Span)的显著性混合数据增强法算法。首先,作者通过用另一个文本中的跨度替换连续的标记来进行混淆,这一灵感来自CutMixarXiv,在混合文本中保留两个源文本的位置。其次,选择一个要替换的跨度,并基于显著性信息进行替换,以使混合文本包含与输出预测更相关的标记,这在语义上可能很重要。文本的输入级方法不同于隐级混合方法,当当前的隐级混合方法线性插值原始隐向量,我们的方法在输入级上混合文本字符,产生非线性输出。同时,利用显著性值从每个句子中选择跨度,并离散地定义跨度的长度和混合比,这是与隐藏级别混合增强区别的地方。
SSMix已经通过大量的文本分类基准实验被证明是有效的。特别强调的是,论文证明了输入级混合方法一般要优于隐层混合方法。论文还展示了在进行文本混合增强的同时,在跨度水平上使用显著性信息和限制标记选择的重要性。
SSMix算法
SSMix基本原理为:给定两个文本
和
,通过将文本
的片段
替换为来自另一文本
的显著信息片段
生成得到新的文本
。同时,对于新文本
,基于两个文本标签
和
重新为新文本
设置一个新的标签
。最后可以使用这个生成的增强虚拟样本(
,
)来进行训练模型。
Saliency:显著性信息
Saliency衡量了文本数据的每个字符对最终结果预测的影响。在以往研究中基于梯度的方法被广泛用于显著性计算,文本同样计算了分类损失
相对于输入嵌入
的梯度,并使用其大小作为显著性:
。文中应用l2范数来获得一个梯度向量的大小,代表着每个字符的类似于PuzzleMix的显著性。
Mixing Text:文本合成
之前提到过,Mixing Text主要是是指两个文本序列
和
如何合成新的文本。大致思路是根据梯度显著性计算方法得到两个文本中每个字符的显著性分数,然后在文本
中选取一个显著性最低的片段
,长度为
,在文本
中选取一个显著性最低的片段
,长度为
。长度设置为
=
=
,其中
为mixup比例参数。最后生成新文本
w为
,其中
和
为原始文本
中替换片段
的左右的两部分。
Sample span length:相等片段长度
本文将原始(
)的长度和替换(
)跨度设置为相同的,主要原因是使用不同长度的span(片段)将导致冗余和语义不明确的mixup 转换。另外,计算不同长度的span之间的mixup 比列也过于复杂。在以往研究中也采用了这种相同大小的替换策略。在替换span长度相同的情况下,论文的SSMix算法能够使显著性的效果最大化。由于SSMix不限制字符的位置,可以同时选择最显著的span和被替换的最不显著片段。如图片1中,in this
在文本
中是不显著的,transcedent love
在文本
中是最显著的,那么可以用transcedent love
替换in this
。
Mixing Text:标签合成
作者将mixup 比列设置为:
由于λ是通过计算片段内的字符数量来重新计算的,因此它可能与λ0不相等。然后
的标签为:
算法1展示了如何利用原始样本对来计算增广样本的混合损失。公式中计算了增强输出logit相对于每个样本的原始目标标签的交叉熵损失,并通过加权和进行组合,因此SSMix算法与数据集标签个数是不相关的,在任何数据集上,输出标签比例是通过两个原始标签的线性组合来计算。
Paired sentence tasks:句子对任务
对于需要一对文本作为输入的任务,如文本隐含推断和相似性分类,SSMix以成对的方式进行混合,并通过聚合每个mixup结果中的标记计数来计算mixup比例。给定样本
,
,合成的新样本为
,mixup比例记为
,其中
和
为每个mixup操作中的替换片段。
如下图所示:
为 "Fun for only children."
为 "Fun foradults and children."
为 "Problems in data synthesis."
为 "Issues in data synthesis."
实验设置
实验数据集
论文实验数据集有文本分类和句子对分类任务:
对比实验
论文将SSMix与三个基线进行了比较:(1) standard training without mixup,(2)EmbedMixMix(3)TMix。
与基线和消融研究的实验结果进行了比较。所有的准确率值都是使用不同种子的5次运行的平均精度(%)。MNLI表示MNLI-不匹配的开发集的准确性。论文报告了GLUE的验证精度,TREC的测试精度,以及ANLI的有效(上)/测试(较低)精度,可以看出SSMix在大部分数据集效果要优于其他混合增强算法。
论文总结
- 与隐层混合方法相比,SSMix在具有足够数据量的数据集上充分证明了其有效性。由于SSMix是一个离散的组合,而不是两个数据样本的线性组合,它在一个合成空间上创建数据样本的范围大于隐藏级别的混合。论文假设,大量的数据有助于更好地在合成空间中进行表示。
- SSMix对于多个类标签数据集(TREC、ANLI、MNLI、QNLI)尤其有效。因此,在没有混合的训练条件下,SSMix在TREC-fine(47个标签)上的精度增益远高于TRECcrare(6个标签), 分别为3.56和 为0.52。具有多个总类标签的数据集增加了在混合源的随机抽样中被选择交叉标签的可能性,所以可以认为在这些多标签分类数据集中的混合性能会显著提高
- 在成对句子任务上具有显著优势,如文本隐含或相似性分类。现有的方法(隐藏层混合)在隐藏层上应用混合,而不考虑特殊的标记,即[SEP]、[CLS]。这些方法可能会丢失关于句子开头的信息或句子对的适当分离。相比之下,SSMix在应用混合时可以考虑单个字符的特性。 -SSMix 及其变体的消融研究结果表明,随着对片段约束和显著性信息的增加,性能有所提高。在混合操作中添加片段约束受益于更好的可定位能力,并且大多数显著的片段与相应的标签有更多的关系,而丢弃最小显著的片段,这些片段相对于原始标签在语义上不重要。其中,引入显著性信息对精度的贡献相对高于片段约束。
代码实现
代码语言:javascript复制import copy
import random
import torch
import torch.nn.functional as F
from .saliency import get_saliency
class SSMix:
def __init__(self, args):
self.args = args
def __call__(self, input1, input2, target1, target2, length1, length2, max_len):
batch_size = len(length1)
if self.args.ss_no_saliency:
if self.args.ss_no_span:
inputs_aug, ratio = self.ssmix_nosal_nospan(input1, input2, length1, length2, max_len)
else:
inputs_aug, ratio = self.ssmix_nosal(input1, input2, length1, length2, max_len)
else:
assert not self.args.ss_no_span
input2_saliency, input2_emb, _ = get_saliency(self.args, input2, target2)
inputs_aug, ratio = self.ssmix(batch_size, input1, input2,
length1, length2, input2_saliency, target1, max_len)
return inputs_aug, ratio
def ssmix(self, batch_size, input1, input2, length1, length2, saliency2, target1, max_len):
inputs_aug = copy.deepcopy(input1)
for i in range(batch_size): # cut off length bigger than max_len ( nli task )
if length1[i].item() > max_len:
length1[i] = max_len
for key in inputs_aug.keys():
inputs_aug[key][i][max_len:] = 0
inputs_aug['input_ids'][i][max_len - 1] = 102
saliency1, _, _ = get_saliency(self.args, inputs_aug, target1)
ratio = torch.ones((batch_size,), device=self.args.device)
for i in range(batch_size):
l1, l2 = length1[i].item(), length2[i].item()
limit_len = min(l1, max_len) - 2 # mixup except [CLS] and [SEP]
mix_size = max(int(limit_len * (self.args.ss_winsize / 100.)), 1)
if l2 < mix_size:
ratio[i] = 1
continue
saliency1_nopad = saliency1[i, :l1].unsqueeze(0).unsqueeze(0)
saliency2_nopad = saliency2[i, :l2].unsqueeze(0).unsqueeze(0)
saliency1_pool = F.avg_pool1d(saliency1_nopad, mix_size, stride=1).squeeze(0).squeeze(0)
saliency2_pool = F.avg_pool1d(saliency2_nopad, mix_size, stride=1).squeeze(0).squeeze(0)
# should not select first and last
saliency1_pool[0], saliency1_pool[-1] = 100, 100
saliency2_pool[0], saliency2_pool[-1] = -100, -100
input1_idx = torch.argmin(saliency1_pool)
input2_idx = torch.argmax(saliency2_pool)
inputs_aug['input_ids'][i, input1_idx:input1_idx mix_size] =
input2['input_ids'][i, input2_idx:input2_idx mix_size]
ratio[i] = 1 - (mix_size / (l1 - 2))
return inputs_aug, ratio
def ssmix_nosal(self, input1, input2, length1, length2, max_len):
inputs_aug = copy.deepcopy(input1)
ratio = torch.ones((len(length1),), device=self.args.device)
for idx in range(len(length1)):
if length1[idx].item() > max_len:
for key in inputs_aug.keys():
inputs_aug[key][idx][max_len:] = 0
inputs_aug['input_ids'][idx][max_len - 1] = 102 # artificially add EOS token.
l1, l2 = min(length1[idx].item(), max_len), length2[idx].item()
if self.args.ss_winsize == -1:
window_size = random.randrange(0, l1) # random sampling of window_size
else:
# remove EOS & SOS when calculating ratio & window size.
window_size = int((l1 - 2) *
self.args.ss_winsize / 100.) or 1
if l2 <= window_size:
ratio[idx] = 1
continue
start_idx = random.randrange(0, l1 - window_size) # random sampling of starting point
if (l2 - window_size) < start_idx: # not enough text for reference.
ratio[idx] = 1
continue
else:
ref_start_idx = start_idx
mix_percent = float(window_size) / (l1 - 2)
for key in input1.keys():
inputs_aug[key][idx, start_idx:start_idx window_size] =
input2[key][idx, ref_start_idx:ref_start_idx window_size]
ratio[idx] = 1 - mix_percent
return inputs_aug, ratio
def ssmix_nosal_nospan(self, input1, input2, length1, length2, max_len):
batch_size, n_token = input1['input_ids'].shape
inputs_aug = copy.deepcopy(input1)
len1 = length1.clone().detach()
ratio = torch.ones((batch_size,), device=self.args.device)
for i in range(batch_size): # force augmented output length to be no more than max_len
if len1[i].item() > max_len:
len1[i] = max_len
for key in inputs_aug.keys():
inputs_aug[key][i][max_len:] = 0
inputs_aug['input_ids'][i][max_len - 1] = 102
mix_len = int((len1[i] - 2) * (self.args.ss_winsize / 100.)) or 1
if (length2[i] - 2) < mix_len:
mix_len = length2[i] - 2
flip_idx = random.sample(range(1, min(len1[i] - 1, length2[i] - 1)), mix_len)
inputs_aug['input_ids'][i][flip_idx] = input2['input_ids'][i][flip_idx]
ratio[i] = 1 - (mix_len / (len1[i].item() - 2))
return inputs_aug, ratio