EMNLP2023 | “魔改Transformer”,AWS提出:MASFormer,计算成本降低75%!

2023-11-03 12:21:07 浏览数 (2)

引言

降低Transformer的计算成本,提高Transformer的长序列扩展能力,一直是学术研究的重点。例如:伯克利提出的Ring Attention、Paged Attention、普渡提出的SRformer等,更有研究人员提出了替代Transformer方案,例如:斯坦福提出的Backpack、Monarch Mixer架构、清华提出的RetNet架构。

今天给大家分享的这篇文章出自EMNLP2023,作者并没有跳出Transfomer框架,而是"魔改"了Transformer,提出了MASFormer,它是一种易于实现的具有混合注意力跨度的Transformer变体。实验表明,仅包含 1.3B 参数的解码器 MASFormer 模型降低计算成本高达75%,并且性能与普通Transformer基本保持一致。

Paper:https://arxiv.org/pdf/2310.12442.pdf

背景介绍

Transformer在自然语言建模(NLM)、自然语言生成(NLG)和自然语言理解(NLU)等各种自然语言处理任务中都表现出了卓越的性能。它们都是利用Attention来计算输入Token之间的依赖关系。

在一些实际应用中,我们经常需要Transformer模型来处理长序列输入。例如,在机器人聊天场景下,机器人系统会根据与用户长期交流的上下文文本来生成回复;在学术论文、学术报告场景下,需要模型接受长序列的输入来生成全面的摘要,否则模型经常会错过重要信息。

传统的Transformer模型会充分考虑Token之间的依赖关系,这使得时间和空间复杂度为序列长度的二次方。所以,当面对长序列输入的时候会产生大量的时间消耗,尤其在反向传播过程中会产生大量的内存消耗。例如,当序列长度为8k时,具有250M参数的Transformer模型会消耗超过80G的GPU内存。

为了扩展Transformer可支持序列的长度,研究人员们提出了各种方法来降低计算复杂度。一种方法是稀疏注意,它根据预定义的稀疏性模式限制每个令牌只关注令牌的一个子集。例如,块稀疏注意将输入序列划分为几个块,只进行块内注意,如下图所示。

此外,滑动窗口注意允许每个令牌在滑动窗口内关注其相邻Token,如下图所示。这些方法虽然降低了Attention的复杂性,但不能充分捕获远程依赖关系。

为了弥补远程依赖关系的缺失,LongT5引入了全局Token,通过对每个令牌块进行平均池化来获得全局令牌。然而,块池化操作可能会削弱关键令牌的信号,并阻止检测到远程依赖关系。

除了这些方法之外,状态空间模型(SSM)还预先指定全局依赖模式,以便仅捕获远程依赖关系。这些模型可以看作是专门设计了固定权值的线性递归神经网络。然而,状态空间方法实现起来比较复杂,并且在反向传播期间经常遇到计算不稳定,尤其是在扩大模型尺寸时。

以上所有方法都有一个共同点,那就是每一层都使用相同的Attention机制。而本文打破这一观念,提出了一种新的Transformer架构:MASFormer(即混合Attention跨度Transformer)。

MASFormer

MASFormer仅在Transformer层的子集上利用全注意力,而在其余层上采用块稀疏注意力,MASFormer 的架构如下图所示:

其中,首先是选择全注意力来编码长序列,主要原因是:

 (1)与稀疏注意相比,全注意力在捕获长序列依赖上表现出了卓越的性能;

 (2) 全注意力不需要复杂的实现,因此与 SSM 相比在计算上是稳定的;

 (3) 全注意力与现有的预训练Transformer模型兼容,为此能够进行持续训练;

基于对块稀疏注意力和全注意力在语言建模和摘要任务上的表现,如下图所示。

可以发现,给定长序列输入,稀疏注意力通常不足以捕获超出其注意力范围的远程依赖关系,所以表现结果较差 为了解决这个问题,可以增加注意力跨度或切换到充分注意力,以提高模型捕获复杂依赖关系的能力。虽然提高了模型性能,但会带来较高的计算成本。

面对计算成本和模型性能之间的权衡。MASFormer 提供了另一种解决方案。MASFormer不是均匀地增加注意力跨度,而是通过为

l

层配备全注意力, 在其余层,MASFormer 利用小尺寸

m

的块注意力,从而产生

(L − l)mn ln^2

的受控注意力成本。

这种设计的灵感来自于NLP上下文数据大都表现出的局部引用现象,远程依赖性并不常见。因此,没有必要在每一层都增强注意力。相反,几层充分关注就足以捕获不常见的远程信号,大多数层可以保持较小的注意力跨度,以充分提取局部依赖性并控制注意力成本。

值得注意的是,MASFormer 可以实现与完全注意力相当的性能,同时大幅降低计算成本。因此,通过混合不同的注意力跨度,MASFormer 在计算成本和模型性能之间取得了更好的平衡。

此外,MASFormer 还提供额外的实施优势。由于使用相同的注意力函数,MASFormer 易于实现并与现有的预训练模型兼容。我们可以通过改变注意力模式来在预训练的 Transformer 上构建 MASFormer,这不涉及对模型架构和预训练权重的修改。

实验结果

下表比较了ArXiv和PubMed测试集上的perplexity。结果表明,在

l =4

层完全注意力的情况下,MASFormer 实现了与所有完全注意力相当的性能,同时降低了 72% 的注意力成本。

下表显示了给定不同序列长度,每种方法的perplexity变化。可以看出,MASFormer和Full Attention在较长的文档上表现出更好的性能,这表明增加上下文长度可以提高它们的预测性能。

下表展示了QMSUM、ArXiv和GovReport在不同注意力成本下的微调结果。在相似的注意力成本下,MASFormer显著优于稀疏注意力变体。

0 人点赞