FlashAttention图解(如何加速Attention)

2023-10-30 19:24:05 浏览数 (2)

作者丨Austin

来源丨https://zhuanlan.zhihu.com/p/626079753

编辑丨GiantPandaCV


Motivation

当输入序列(sequence length)较长时,Transformer的计算过程缓慢且耗费内存,这是因为self-attention的time和memory complexity会随着sequence length的增加成二次增长。

标准Attention的中间结果S,P(见下文)通常需要通过高带宽内存(HBM)进行存取,两者所需内存空间复杂度为

。本文分析:

FlashAttention: 对HBM访问的次数为

Attention: 对HBM访问的次数为

往往

(例如GPT2中N=1024,d=64),因此FlashAttention会快很多。下图展示了两者在GPT-2上的Forward Backward的GFLOPs、HBM、Runtime对比(A100 GPU):

GPU中存储单元主要有HBM和SRAM:HBM容量大但是访问速度慢,SRAM容量小却有着较高的访问速度。例如:A100 GPU有40-80GB的HBM,带宽为1.5-2.0TB/s;每108个流式多核处理器各有192KB的片上SRAM,带宽估计约为19TB/s。可以看出,片上的SRAM比HBM快一个数量级,但尺寸要小许多数量级。

综上,FlashAttention目的不是节约FLOPs,而是减少对HBM的访问。重点是FlashAttention在训练和预测过程中的结果和标准Attention一样,对用户是无感的,而其他加速方法做不到这点。

阅读本文需要了解的符号定义:

Method

Attention

计算流程图如下:

FlashAttention

只是部分结果,如下图所示,外循环 j 是横向(特征维 d )移动的,内循环 i 是纵向(序列维 N )移动的。换句话说,外循环在顺序计算特征,内循环在顺序计算序列。

作用是将vector生成为一个对角矩阵,从而实现相同长度的两个vector进行element-wise相乘。

Theorem 1. FlashAttention的FLOPs为

,除了input和output,额外需要的内存为


- The End -

0 人点赞