斯坦福大学CS博士新作:新型Attention提速2-4倍,BERT单节点训练最快

2022-06-17 18:11:09 浏览数 (1)

机器之心报道

编辑:陈萍

FlashAttention 是一种具有 IO 感知,且兼具快速、内存高效的新型注意力算法。

一种快速、内存高效的注意力算法来了,被命名为 FlashAttention。通过减少 GPU 内存读取 / 写入,FlashAttention 的运行速度比 PyTorch 标准注意力快 2-4 倍,所需内存减少 5-20 倍。

这项研究由斯坦福大学、纽约州立大学布法罗分校的研究者共同完成。共同一作是两位斯坦福计算机博士生 Tri Dao 和 Dan Fu。

下面我们介绍一下论文具体内容。

FlashAttention

Transformer 已然成为自然语言处理和图像分类等应用中最广泛使用的架构。随着研究的不断前进,Transformer 尺寸变得越来越大、层数也越来越深,但是给 Transformer 配备更长的上下文仍然很困难,因为 Transformer 核心自注意力模块的时间复杂度以及内存复杂度在序列长度上是二次方的。

有研究者提出一些近似注意力的方法,旨在减少注意力计算和内存需求。这些方法包括稀疏近似、低秩近似以及它们的组合。从序列长度来看,尽管这些方法可以将计算降低到线性或接近线性,但它们并没有显示出针对标准注意力的 wall-clock 加速,因而没有被广泛使用。这其中一个主要原因是这些研究专注于减少 FLOP(这可能与 wall-clock 速度无关)并且倾向于忽略来自内存访问 (IO) 的开销。

在本文中,该研究认为应该让注意力算法具有 IO 感知——即考虑显存级间的读写。现代 GPU 计算速度超过了内存速度,transformer 中的大多数操作都被内存访问所阻塞。IO 感知算法对于类似的内存绑定操作至关重要,这种重要性体现在当读写数据占据很大运行时——例如数据库连接、图像处理、数值线性代数等。然而,用于深度学习的常见 Python 接口,如 PyTorch 和 Tensorflow,不允许对内存访问进行细粒度控制。

论文地址:https://arxiv.org/pdf/2205.14135.pdf

GitHub 地址:https://github.com/HazyResearch/flash-attention

该研究提出了一种新的注意力算法 FlashAttention,它可以使用更少的内存访问来计算精确的注意力。FlashAttention 旨在避免从 HBM(High Bandwidth Memory)中读取和写入注意力矩阵。这需要做到:(i) 在不访问整个输入的情况下计算 softmax reduction;(ii) 在后向传播中不能存储中间注意力矩阵。

该研究采用两种成熟的技术来应对这些挑战:

(i) 该研究重组注意力计算,将输入分成块,并在输入块上进行多次传递,从而逐步执行 softmax reduction(也称为 tiling);

(ii) 该研究存储前向传递的 softmax 归一化因子,在后向传播中快速重新计算片上注意力,这比从 HBM 中读取中间注意力矩阵的标准方法更快。

该研究在 CUDA 中实现 FlashAttention ,以达到对内存访问的细粒度控制,并将所有注意力操作融合到一个 GPU 内核中。即使由于重新计算导致 FLOPs 增加,但其运行速度更快(在 GPT-2 上高达 7.6 倍,图 1 右图)并且使用更少的内存(序列长度线性),主要是因为大大减少了 HBM 访问量。

该研究分析了 FlashAttention 的 IO 复杂度,证明它需要

0 人点赞