FlashAttention:快速且内存高效的准确注意力机制

2024-07-04 11:17:13 浏览数 (1)

在深度学习领域,注意力机制是提高模型性能的关键组件。然而,传统的注意力机制在长序列处理时会消耗大量内存和计算资源。为了解决这个问题,Tri Dao等人提出了FlashAttention,这是一种快速且内存高效的注意力机制。本文将介绍FlashAttention及其改进版FlashAttention-2的核心概念、安装方法和使用示例。

论文介绍

FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness

  • 作者: Tri Dao, Daniel Y. Fu, Stefano Ermon, Atri Rudra, Christopher Ré
  • 论文链接: arxiv.org/abs/2205.14135

FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning

  • 作者: Tri Dao
  • 论文链接: flash2.pdf

安装和特性

环境要求

  • CUDA: 11.6及以上
  • PyTorch: 1.12及以上
  • 操作系统: Linux(从v2.3.2开始有部分Windows的正面反馈,但Windows编译仍需更多测试)

我们推荐使用Nvidia的PyTorch容器,其中包含安装FlashAttention所需的所有工具。

安装步骤

  1. 确保已安装PyTorch
  2. 安装packagingpip install packaging
  3. 安装ninja并确保其正常工作:ninja --version && echo $?应返回退出码0。如果未返回0,重新安装ninja:pip uninstall -y ninja && pip install ninja
使用pip安装
代码语言:javascript复制
pip install flash-attn --no-build-isolation
从源码编译
代码语言:javascript复制
python setup.py install
控制并行编译任务数(适用于RAM少于96GB且有多个CPU核心的机器)
代码语言:javascript复制
MAX_JOBS=4 pip install flash-attn --no-build-isolation

使用示例

FlashAttention主要实现了缩放点积注意力(softmax(Q @ K^T * softmax_scale) @ V)。以下是使用FlashAttention的核心函数:

代码语言:javascript复制
from flash_attn import flash_attn_qkvpacked_func, flash_attn_func

# 当Q, K, V已堆叠为一个张量时,使用flash_attn_qkvpacked_func
out = flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False,
                                window_size=(-1, -1), alibi_slopes=None, deterministic=False)

# 直接使用Q, K, V时,使用flash_attn_func
out = flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False,
                      window_size=(-1, -1), alibi_slopes=None, deterministic=False)

参数说明

  • qkv: (batch_size, seqlen, 3, nheads, headdim)格式的张量,包含Q, K, V
  • dropout_p: float,Dropout概率
  • softmax_scale: float,softmax前QK^T的缩放比例,默认为1 / sqrt(headdim)
  • causal: bool,是否应用因果注意力掩码(如用于自回归建模)
  • window_size: (left, right),如果不为(-1, -1),则实现滑动窗口局部注意力
  • alibi_slopes: (nheads,)或(batch_size, nheads),fp32。对查询i和键j的注意力分数加上一个偏置(-alibi_slope * |i - j|)
  • deterministic: bool,是否使用确定性实现的反向传播(略慢且使用更多内存)

性能表现

加速效果

FlashAttention在A100 80GB SXM5 GPU上使用FP16/BF16格式时的加速效果如下:

  • Head Dimension: 64或128
  • Hidden Dimension: 2048(即32或16个heads)
  • Sequence Length: 512, 1k, 2k, 4k, 8k, 16k
  • Batch Size: 16k / seqlen

内存节省

FlashAttention在处理较长序列时能显著节省内存。与标准注意力机制内存使用随序列长度二次增长不同,FlashAttention的内存使用线性增长。在序列长度为2K时可节省10倍内存,4K时可节省20倍内存。

完整模型代码和训练脚本

已发布了完整的GPT模型实现,并提供了其他层(如MLP、LayerNorm、交叉熵损失、旋转嵌入)的优化实现。整体上,训练速度较基线实现(如Huggingface实现)提高3-5倍,达到每A100 225 TFLOPs/sec,相当于72%的模型FLOPs利用率。

FlashAttention 更新日志

2.0:完全重写,速度提升2倍

FlashAttention在2.0版本中进行了完全重写,速度提升了两倍。本次更新引入了多个更改和改进,包括一些函数名称的更改以及在输入具有相同序列长度的情况下简化了使用方式。 FlashAttention-2是对原始FlashAttention算法的一系列改进,旨在优化在GPU上的计算性能。本文详细讨论了FlashAttention-2的算法、并行性以及工作分区策略。

算法

FlashAttention-2的关键优化点在于减少非矩阵乘法(matmul)的浮点运算,以充分利用GPU上的专用计算单元(如Nvidia GPU上的Tensor Cores),这些单元在处理matmul操作(尤其是在FP16/BF16格式下)时性能显著优化。该优化的目标是通过尽可能多地执行matmul操作来最大化GPU的吞吐量。

前向传播
  1. 在线Softmax技巧:FlashAttention-2对在线Softmax计算进行了修改,以最小化非matmul浮点操作:
    • 避免通过 diag(ℓ(2))^-1 重新缩放输出更新的两个项。
    • 维持一个“未缩放”的O(2)版本,并保留统计信息 ℓ(2)。
    • 仅在循环结束时,通过 diag(ℓ(last))^-1 缩放最终的O(last)以获得正确的输出。
  2. 最大化matmul FLOPs:为了最大化GPU的性能,FlashAttention-2重点优化了matmul操作,因为现代GPU上的专用单元(如Tensor Cores)在这些操作上表现出色。以Nvidia A100 GPU为例,其FP16/BF16 matmul的理论吞吐量可以达到312 TFLOPs/s,而非matmul FP32的吞吐量仅为19.5 TFLOPs/s。因此,FlashAttention-2通过优化算法,尽可能地减少非matmul操作,从而保持高吞吐量的执行效率。
  3. 算法细节:FlashAttention-2的前向传播通过以下步骤实现: