作者丨Austin
来源丨https://zhuanlan.zhihu.com/p/626079753
编辑丨GiantPandaCV
Motivation
当输入序列(sequence length)较长时,Transformer的计算过程缓慢且耗费内存,这是因为self-attention的time和memory complexity会随着sequence length的增加成二次增长。
标准Attention的中间结果S,P(见下文)通常需要通过高带宽内存(HBM)进行存取,两者所需内存空间复杂度为
。本文分析:
FlashAttention: 对HBM访问的次数为
Attention: 对HBM访问的次数为
往往
(例如GPT2中N=1024,d=64),因此FlashAttention会快很多。下图展示了两者在GPT-2上的Forward Backward的GFLOPs、HBM、Runtime对比(A100 GPU):
GPU中存储单元主要有HBM和SRAM:HBM容量大但是访问速度慢,SRAM容量小却有着较高的访问速度。例如:A100 GPU有40-80GB的HBM,带宽为1.5-2.0TB/s;每108个流式多核处理器各有192KB的片上SRAM,带宽估计约为19TB/s。可以看出,片上的SRAM比HBM快一个数量级,但尺寸要小许多数量级。
综上,FlashAttention目的不是节约FLOPs,而是减少对HBM的访问。重点是FlashAttention在训练和预测过程中的结果和标准Attention一样,对用户是无感的,而其他加速方法做不到这点。
阅读本文需要了解的符号定义:
Method
Attention
计算流程图如下:
FlashAttention
和
只是部分结果,如下图所示,外循环 j 是横向(特征维 d )移动的,内循环 i 是纵向(序列维 N )移动的。换句话说,外循环在顺序计算特征,内循环在顺序计算序列。
作用是将vector生成为一个对角矩阵,从而实现相同长度的两个vector进行element-wise相乘。
Theorem 1. FlashAttention的FLOPs为
,除了input和output,额外需要的内存为
。
- The End -