自《Attention is All You Need》一文发布后,基于 Multi-Head Attention 的 Transformer 模型开始流行起来,而 BERT 模型更是将 Transformer 模型的热度推上了又一个高峰。当然,技术的探索是无止境的,改进的工作也相继涌现:有改进预训练任务的,如 XLNET 的 PLM、ALBERT 的 SOP 等;有改进归一化的,如 Post-Norm 向 Pre-Norm 的改变,以及 T5 中去掉了 Layer Norm 里边的 beta 参数等;也有改进模型结构的,如 Transformer-XL 等;有改进训练方式的,如 ALBERT 的参数共享等;...
以上的这些改动,都是在 Attention 外部进行改动的,也就是说它们都默认了 Attention 的合理性,没有对 Attention 本身进行改动,而本文我们则介绍两个新的研究:它们针对 Multi-Head Attention 中可能存在的建模瓶颈,提出了不同的方案来改进 Multi-Heaed Attention。两篇论文都来自 Google,并且做了相当充分的实验,因此结果应该是相当具有说服力的
再小也不能小 key_size
第一个结果来自文章《Low-Rank Bottleneck in Multi-head Attention Models》,它明确地指出了 Multi-Head Attention 里边的表达能力瓶颈,并提出通过增大 key_size 的方法来缓解这个瓶颈
Single-Head Attention
对一个单头注意力机制来说,它的定义如下:
(1)Attention(X)=WvX⋅Softmax[(WkX)T(WqX)dk]=WvX⋅P
其中 Wq∈Rdq×d,Wk∈Rdk×d,Wv∈Rdv×d,因为这是单头注意力机制,所以 dq=dk=dv=d。将上述结果经过一个线性层和 LayerNorm 层得到最终的输出
(2)LN(X Wo⋅Attention(X))
但是论文中提到如果 dq=dk=d≥n,那么给定列满秩矩阵 X∈Rd×n 和 n×n 的正随机矩阵(每一列的和为 1,且矩阵所有元素都为正数)P,一定存在 d×d 维的 Wq,Wk,使得
(3)Softmax[(WkX)T(WqX)dk]=P
成立。但是如果 d<n,则此公式不一定成立
首先证明 d≥n 的情况。因为 X 是列满秩矩阵,所以一定存在其左逆矩阵 X†=(XTX)−1XT∈Rn×d,且 X†X=In,令 Wk=W~kX†,Wq=W~qX†,则
(4)(WkX)T(WqX)=XTWkTWqX=XT(X†)TW~kTW~qX†X=In⋅W~kTW~q⋅In=W~kTW~q≜W~kq
将式 (4) 的最终结果带入式 (1) 得
(5)Softmax[(WkX)T(WqX)dk]=Softmax[Wkq~dk]=exp(W~kqdk)⋅DW~kq−1
其中 DW~kq−1 是一个 n×n 的对角矩阵,并且
(6)(DW~kq−1)ii=∑j=1nexp((Wkq)ji~dk)=(1Texp((W~kq)dk))i
式 (6) 的 1T 是一个全 1 的行向量,其实很好理解,用一个全 1 的行向量右乘一个矩阵,本质就是求和操作
因此,我们现在转而需要证明下式成立
(7)exp(W~kqdk)=P⋅DW~kq
给定 P,为了构建矩阵 W~kq,我们随意选择一个正对角线矩阵(对角线元素大于 0)D0,并且令
(8)W~kq=dk⋅log(P⋅D0)
由于 P 是一个正矩阵(矩阵内的元素都大于 0),所以满足式 (8) 的 W~kq 矩阵总是存在的,接下来我们证明 DW~kq=D0
(9)DW~kq=Diag(1Texp((W~kq)dk))=Diag(1TP⋅D0)=D0
最后一个等式成立,是因为 P 的每一列和为 1。最终,我们结合式 (8) 和式 (9)
(10)exp(W~kqdk)=P⋅D0=P⋅DW~kq
接着我们证明 d<n 的情况。假设 d=1,n=2,则 X∈R1×2,Wq,Wk,Wv∈R1×1,于是
(11)Softmax[(WkX)T(WqX)dk]=Softmax[[1,0]TWkTWq[1,0]dk]=Softmax[[WkWq000]]
这个矩阵很明显不符合我们对矩阵 P 的要求,因为它的第二列元素无法做到不相等,或者说此时 P 的秩很低
Multi-Head Attention
接着我们简单回顾一下 Multi-Head Attenion,首先将 Single-Head Attention 中的一些变量重新进行定义
WkX≜KWqX≜QWvX≜V
于是则有
(12)Attention(Q,K,V)=Softmax(QKTdk)V
其中 Q∈Rn×dk,K∈Rm×dk,V∈Rm×dv。而 Multi-Head Attention 就是将 Q,K,V 分别用 h 个不同的投影矩阵投影 h 次,然后分别做 h 次 Single-Head Attention,最后把结果拼接起来,即
(13)Q(1)=QWQ(1),K(1)=KWK(1),V(1)=VWV(1),O(1)=Attention(Q(1),K(1),V(1))Q(2)=QWQ(2),K(2)=KWK(2),V(2)=VWV(2),O(2)=Attention(Q(2),K(2),V(2))⋮Q(h)=QWQ(h),K(h)=KWK(h),V(h)=VWV(h),O(h)=Attention(Q(h),K(h),V(h))O=[O(1),O(2),…,O(h)]
Attention 里有个瓶颈
在实际使用中,Q,K,V 一般具有相同的特征维度 dk=dv=d(hidden_size),比如 BERT-base 里边是 768;h 一般选择 12、16、24 等,比如 BERT-base 里边是 12;确定了 d,h 之后,通常的选择是让投影矩阵 W∈Rd×dh,也就是说,每个 Attention-Head 里边,是将原始的 d 维投影到 dh 维,然后再进行 Attention 运算,输出也是 dh 维,最后把 h 个 dh 维的结果拼接起来,得到一个 d 维的输出。这里的 dh我们通常称为 head_size
在 Attention 中,关键的一步是
(14)P=Softmax(QKTdk)
在前面我们已经证明了,如果单个头的维度小于句子长度 n,得到的 P 并不好。那么这里单个头的维度是否小于 n 呢?很明显是的,就以 BERT-base 为例,dh=64≪n
不妨试试增大 key_size?
那么,解决办法是什么呢?直接的想法是让 dh 增大,所以要不就是减少 head 的数目 h,要不就是增大 hidden_size 的大小 d。但是更多的 Attention Head 本身也能增强模型的表达能力,所以为了缓解低秩瓶颈而减少 h 的做法可能得不偿失;如果增加 d 的话,那自然是能够增强模型整体表达能力的,但整个模型的规模与计算量也会剧增,似乎也不是一个好选择
难道没有其他办法了吗?有!当我们用投影矩阵将 Q,K,V 都投影到低维时,前面都是将它们投影到 dh维,但其实它们的维度不一定要相等,而只需要保证 Q,K 的维度相等就行了(因为要做内积),为了区别,我们通常称 Q,K 的维度为 key_size,V 的维度才叫 head_size,改变 key_size 的大小而不改变 head_size 的话,也不影响模型的 hidden_size
所以,这篇论文提出来的解决方法就是增大模型的 key_size,它能增加 Attention 的表达能力,并且不改变模型整体的 hidden_size,计算量上也只是稍微增加了一点
事实上原论文考虑的是同时增大 key_size 和 head_size,Multi-Head Attention 的输出拼接之后再用一个线性变换降维,但实际上只增大 key_size 也是有效果的 此外,如果同时增大 key_size 和 head_size 会导致计算量和显存明显增加,而只增大 key_size 的话,增加的资源消耗就小很多了
实验结果
增加 key_size 这个想法很简单,也很容易实现,但是否真的有效呢?我们来看看原论文的实验结果,其实验都是以 BERT 为 baseline 的,实验结果图表很多,推荐大家直接看原论文,这里只分享比较有代表性的一个
保持一个较大的 key_size,能使得模型在同样参数规模的情况下表现更优异
其中 dp=dh。结果显示,如果固定一个比较大的 key_size(比如 128),那么我们可以调整模型的 hidden_size 和 head 数,使得参数量可以跟原始的 BERT 设计一致,但是效果更优!所以,增加 key_size 确实是有意义的,哪怕将总体参数量重新调整到原来的一样大,也能一定程度上提升模型的效果。这无疑对我们设计新的 Transformer 模型(尤其是小规模的模型)有重要的指导作用
再缺不能缺 Talking
对 Multi-Head Attention 改进的第二个结果来自论文《Talking-Heads Attention》,这篇论文虽然没有显式地指出它跟前一篇论文的联系,但笔者认为它们事实上在解决同一个问题,只不过思路不一样:它指出当前的 Multi-Head Attention 每个 head 的运算是相互孤立的,而通过将它们联系(Talking)起来,则可以得到更强的 Attention 设计,即标题的 "Talking-Heads Attention"
从单一分布到混合分布
在前一篇论文里边,我们提到了低秩瓶颈,也就是由于 key_size 太小,所以 (Q(i)K(i))T 表达能力不足。为了缓解这个问题,除了增大 key_size 之外,还有没有其他方法呢?有,比如这篇文论使用的混合分布思路
所谓混合分布,就是多个简单分布的叠加(比如加权平均),它能极大的增强原分布的表达能力。典型的例子是高斯混合模型:我们知道高斯分布只是一个常见的简单分布,但多个高斯分布叠加而成的高斯混合分布(也叫高斯混合模型,GMM)就是一个更强的分布,理论上来说,只要叠加的高斯分布足够多,高斯混合分布能逼近任意概率分布。这个例子告诉我们,想要增加 Attention 中分布的表达能力,又不想增加 key_size,那么可以考虑叠加多个低秩分布
那么 "多个" 低秩分布哪里来呢?不是有 Multi-Head 嘛,每个 head 都带有一个低秩分布,就直接用它们叠加就行了,这就是 Talking-Heads Attention。具体来说,它的形式是:
(15)J^(1)=Q(1)K(1)T,J^(2)=Q(2)K(2)T,⋯,J^(h)=Q(h)K(h)T(J(1)J(2)⋮J(h))=(λ11λ12⋯λ1hλ21λ22⋯λ2h⋮⋮⋱⋮λh1λh2⋯λhh)(J^(1)J^(2)⋮J^(h))P(1)=softmax(J(1)),P(2)=softmax(J(2)),…,P(h)=softmax(J(h))O(1)=P(1)V(1),O(2)=P(2)V(2),,⋯,O(h)=P(h)V(h)O=[O(1),O(2),…,O(h)]
写起来很复杂,事实上很简单,就是在 QKT 之后、Softmax 之前,用一个参数矩阵 λ 将各个 QKT 的结果叠加一下而已。这样就把原本是孤立的各个 Attention Head 联系了起来,即做了一个简单的 Talking
对上述公式做两点补充说明:
- 简单起见,上述公式中笔者省去了缩放因子 dk,如有需要,读者自行补充上去即可
- 更一般的 Talking-Heads Attention 允许在 J=λJ^ 这一步进行升维,即叠加出多于 h 个混合分布,然后再用另一个参数矩阵降维,但这并不是特别重要的改进,所以不做主要介绍
实验结果
是不是真的有效,当然还是得靠实验结果来说话。这篇论文的实验阵容可谓空前强大,它同时包含了 BERT、ALBERT、T5 为 baseline 的实验结果!众所周知,BERT、ALBERT、T5 均是某个时间段的 NLP 最优模型,尤其是 T5 还是处在 superglue 的榜首,并且远超出第二名很多,而这个 Talking-Heads Attention 则几乎是把它们的辉煌战绩又刷到了一个新高度!
还是那句话,具体的实验结果大家自己看论文,这里展示一个比较经典的结果:
结果显示,使用 Talking-Head Attention 情况下,保持 hidden_size 不变,head 数目越大(相应地 key_size 和 head_size 都越小),效果越好。这看起来跟前一篇增大 key_size 的结论矛盾,但是事实上这正说明了混合分布对分布拟合能力具有明显的提升作用,能将 key_size 缩小时本身变弱的单一分布,叠加成拟合能力更强大的分布。当然,这不能说明直接设 key_size=1 就好了,因为 key_size=1 时计算量会远远大于原始的 BERT-base,应用时需要根据实际情况平衡效果和计算量
上述表格只是原论文实验的冰山一角,这里再放出一个实验表格,让大家感受感受它的实验阵容:
几乎每个任务、每个超参数组合都做了实验,并给出实验结果。如此强大的实验阵容,基本上也就只有 Google 能搞出来了
References自《Attention is All You Need》一文发布后,基于 Multi-Head Attention 的 Transformer 模型开始流行起来,而 BERT 模型更是将 Transformer 模型的热度推上了又一个高峰。当然,技术的探索是无止境的,改进的工作也相继涌现:有改进预训练任务的,如 XLNET 的 PLM、ALBERT 的 SOP 等;有改进归一化的,如 Post-Norm 向 Pre-Norm 的改变,以及 T5 中去掉了 Layer Norm 里边的 beta 参数等;也有改进模型结构的,如 Transformer-XL 等;有改进训练方式的,如 ALBERT 的参数共享等;...
以上的这些改动,都是在 Attention 外部进行改动的,也就是说它们都默认了 Attention 的合理性,没有对 Attention 本身进行改动,而本文我们则介绍两个新的研究:它们针对 Multi-Head Attention 中可能存在的建模瓶颈,提出了不同的方案来改进 Multi-Heaed Attention。两篇论文都来自 Google,并且做了相当充分的实验,因此结果应该是相当具有说服力的
再小也不能小 key_size
第一个结果来自文章《Low-Rank Bottleneck in Multi-head Attention Models》,它明确地指出了 Multi-Head Attention 里边的表达能力瓶颈,并提出通过增大 key_size 的方法来缓解这个瓶颈
Single-Head Attention
对一个单头注意力机制来说,它的定义如下:
(1)Attention(X)=WvX⋅Softmax[(WkX)T(WqX)dk]=WvX⋅P
其中 Wq∈Rdq×d,Wk∈Rdk×d,Wv∈Rdv×d,因为这是单头注意力机制,所以 dq=dk=dv=d。将上述结果经过一个线性层和 LayerNorm 层得到最终的输出
(2)LN(X Wo⋅Attention(X))
但是论文中提到如果 dq=dk=d≥n,那么给定列满秩矩阵 X∈Rd×n 和 n×n 的正随机矩阵(每一列的和为 1,且矩阵所有元素都为正数)P,一定存在 d×d 维的 Wq,Wk,使得
(3)Softmax[(WkX)T(WqX)dk]=P
成立。但是如果 d<n,则此公式不一定成立
首先证明 d≥n 的情况。因为 X 是列满秩矩阵,所以一定存在其左逆矩阵 X†=(XTX)−1XT∈Rn×d,且 X†X=In,令 Wk=W~kX†,Wq=W~qX†,则
(4)(WkX)T(WqX)=XTWkTWqX=XT(X†)TW~kTW~qX†X=In⋅W~kTW~q⋅In=W~kTW~q≜W~kq
将式 (4) 的最终结果带入式 (1) 得
(5)Softmax[(WkX)T(WqX)dk]=Softmax[Wkq~dk]=exp(W~kqdk)⋅DW~kq−1
其中 DW~kq−1 是一个 n×n 的对角矩阵,并且
(6)(DW~kq−1)ii=∑j=1nexp((Wkq)ji~dk)=(1Texp((W~kq)dk))i
式 (6) 的 1T 是一个全 1 的行向量,其实很好理解,用一个全 1 的行向量右乘一个矩阵,本质就是求和操作
因此,我们现在转而需要证明下式成立
(7)exp(W~kqdk)=P⋅DW~kq
给定 P,为了构建矩阵 W~kq,我们随意选择一个正对角线矩阵(对角线元素大于 0)D0,并且令
(8)W~kq=dk⋅log(P⋅D0)
由于 P 是一个正矩阵(矩阵内的元素都大于 0),所以满足式 (8) 的 W~kq 矩阵总是存在的,接下来我们证明 DW~kq=D0
(9)DW~kq=Diag(1Texp((W~kq)dk))=Diag(1TP⋅D0)=D0
最后一个等式成立,是因为 P 的每一列和为 1。最终,我们结合式 (8) 和式 (9)
(10)exp(W~kqdk)=P⋅D0=P⋅DW~kq
接着我们证明 d<n 的情况。假设 d=1,n=2,则 X∈R1×2,Wq,Wk,Wv∈R1×1,于是
(11)Softmax[(WkX)T(WqX)dk]=Softmax[[1,0]TWkTWq[1,0]dk]=Softmax[[WkWq000]]
这个矩阵很明显不符合我们对矩阵 P 的要求,因为它的第二列元素无法做到不相等,或者说此时 P 的秩很低
Multi-Head Attention
接着我们简单回顾一下 Multi-Head Attenion,首先将 Single-Head Attention 中的一些变量重新进行定义
WkX≜KWqX≜QWvX≜V
于是则有
(12)Attention(Q,K,V)=Softmax(QKTdk)V
其中 Q∈Rn×dk,K∈Rm×dk,V∈Rm×dv。而 Multi-Head Attention 就是将 Q,K,V 分别用 h 个不同的投影矩阵投影 h 次,然后分别做 h 次 Single-Head Attention,最后把结果拼接起来,即
(13)Q(1)=QWQ(1),K(1)=KWK(1),V(1)=VWV(1),O(1)=Attention(Q(1),K(1),V(1))Q(2)=QWQ(2),K(2)=KWK(2),V(2)=VWV(2),O(2)=Attention(Q(2),K(2),V(2))⋮Q(h)=QWQ(h),K(h)=KWK(h),V(h)=VWV(h),O(h)=Attention(Q(h),K(h),V(h))O=[O(1),O(2),…,O(h)]
Attention 里有个瓶颈
在实际使用中,Q,K,V 一般具有相同的特征维度 dk=dv=d(hidden_size),比如 BERT-base 里边是 768;h 一般选择 12、16、24 等,比如 BERT-base 里边是 12;确定了 d,h 之后,通常的选择是让投影矩阵 W∈Rd×dh,也就是说,每个 Attention-Head 里边,是将原始的 d 维投影到 dh 维,然后再进行 Attention 运算,输出也是 dh 维,最后把 h 个 dh 维的结果拼接起来,得到一个 d 维的输出。这里的 dh我们通常称为 head_size
在 Attention 中,关键的一步是
(14)P=Softmax(QKTdk)
在前面我们已经证明了,如果单个头的维度小于句子长度 n,得到的 P 并不好。那么这里单个头的维度是否小于 n 呢?很明显是的,就以 BERT-base 为例,dh=64≪n
不妨试试增大 key_size?
那么,解决办法是什么呢?直接的想法是让 dh 增大,所以要不就是减少 head 的数目 h,要不就是增大 hidden_size 的大小 d。但是更多的 Attention Head 本身也能增强模型的表达能力,所以为了缓解低秩瓶颈而减少 h 的做法可能得不偿失;如果增加 d 的话,那自然是能够增强模型整体表达能力的,但整个模型的规模与计算量也会剧增,似乎也不是一个好选择
难道没有其他办法了吗?有!当我们用投影矩阵将 Q,K,V 都投影到低维时,前面都是将它们投影到 dh维,但其实它们的维度不一定要相等,而只需要保证 Q,K 的维度相等就行了(因为要做内积),为了区别,我们通常称 Q,K 的维度为 key_size,V 的维度才叫 head_size,改变 key_size 的大小而不改变 head_size 的话,也不影响模型的 hidden_size
所以,这篇论文提出来的解决方法就是增大模型的 key_size,它能增加 Attention 的表达能力,并且不改变模型整体的 hidden_size,计算量上也只是稍微增加了一点
事实上原论文考虑的是同时增大 key_size 和 head_size,Multi-Head Attention 的输出拼接之后再用一个线性变换降维,但实际上只增大 key_size 也是有效果的 此外,如果同时增大 key_size 和 head_size 会导致计算量和显存明显增加,而只增大 key_size 的话,增加的资源消耗就小很多了
实验结果
增加 key_size 这个想法很简单,也很容易实现,但是否真的有效呢?我们来看看原论文的实验结果,其实验都是以 BERT 为 baseline 的,实验结果图表很多,推荐大家直接看原论文,这里只分享比较有代表性的一个
保持一个较大的 key_size,能使得模型在同样参数规模的情况下表现更优异
其中 dp=dh。结果显示,如果固定一个比较大的 key_size(比如 128),那么我们可以调整模型的 hidden_size 和 head 数,使得参数量可以跟原始的 BERT 设计一致,但是效果更优!所以,增加 key_size 确实是有意义的,哪怕将总体参数量重新调整到原来的一样大,也能一定程度上提升模型的效果。这无疑对我们设计新的 Transformer 模型(尤其是小规模的模型)有重要的指导作用
再缺不能缺 Talking
对 Multi-Head Attention 改进的第二个结果来自论文《Talking-Heads Attention》,这篇论文虽然没有显式地指出它跟前一篇论文的联系,但笔者认为它们事实上在解决同一个问题,只不过思路不一样:它指出当前的 Multi-Head Attention 每个 head 的运算是相互孤立的,而通过将它们联系(Talking)起来,则可以得到更强的 Attention 设计,即标题的 "Talking-Heads Attention"
从单一分布到混合分布
在前一篇论文里边,我们提到了低秩瓶颈,也就是由于 key_size 太小,所以 (Q(i)K(i))T 表达能力不足。为了缓解这个问题,除了增大 key_size 之外,还有没有其他方法呢?有,比如这篇文论使用的混合分布思路
所谓混合分布,就是多个简单分布的叠加(比如加权平均),它能极大的增强原分布的表达能力。典型的例子是高斯混合模型:我们知道高斯分布只是一个常见的简单分布,但多个高斯分布叠加而成的高斯混合分布(也叫高斯混合模型,GMM)就是一个更强的分布,理论上来说,只要叠加的高斯分布足够多,高斯混合分布能逼近任意概率分布。这个例子告诉我们,想要增加 Attention 中分布的表达能力,又不想增加 key_size,那么可以考虑叠加多个低秩分布
那么 "多个" 低秩分布哪里来呢?不是有 Multi-Head 嘛,每个 head 都带有一个低秩分布,就直接用它们叠加就行了,这就是 Talking-Heads Attention。具体来说,它的形式是:
(15)J^(1)=Q(1)K(1)T,J^(2)=Q(2)K(2)T,⋯,J^(h)=Q(h)K(h)T(J(1)J(2)⋮J(h))=(λ11λ12⋯λ1hλ21λ22⋯λ2h⋮⋮⋱⋮λh1λh2⋯λhh)(J^(1)J^(2)⋮J^(h))P(1)=softmax(J(1)),P(2)=softmax(J(2)),…,P(h)=softmax(J(h))O(1)=P(1)V(1),O(2)=P(2)V(2),,⋯,O(h)=P(h)V(h)O=[O(1),O(2),…,O(h)]
写起来很复杂,事实上很简单,就是在 QKT 之后、Softmax 之前,用一个参数矩阵 λ 将各个 QKT 的结果叠加一下而已。这样就把原本是孤立的各个 Attention Head 联系了起来,即做了一个简单的 Talking
对上述公式做两点补充说明:
- 简单起见,上述公式中笔者省去了缩放因子 dk,如有需要,读者自行补充上去即可
- 更一般的 Talking-Heads Attention 允许在 J=λJ^ 这一步进行升维,即叠加出多于 h 个混合分布,然后再用另一个参数矩阵降维,但这并不是特别重要的改进,所以不做主要介绍
实验结果
是不是真的有效,当然还是得靠实验结果来说话。这篇论文的实验阵容可谓空前强大,它同时包含了 BERT、ALBERT、T5 为 baseline 的实验结果!众所周知,BERT、ALBERT、T5 均是某个时间段的 NLP 最优模型,尤其是 T5 还是处在 superglue 的榜首,并且远超出第二名很多,而这个 Talking-Heads Attention 则几乎是把它们的辉煌战绩又刷到了一个新高度!
还是那句话,具体的实验结果大家自己看论文,这里展示一个比较经典的结果:
结果显示,使用 Talking-Head Attention 情况下,保持 hidden_size 不变,head 数目越大(相应地 key_size 和 head_size 都越小),效果越好。这看起来跟前一篇增大 key_size 的结论矛盾,但是事实上这正说明了混合分布对分布拟合能力具有明显的提升作用,能将 key_size 缩小时本身变弱的单一分布,叠加成拟合能力更强大的分布。当然,这不能说明直接设 key_size=1 就好了,因为 key_size=1 时计算量会远远大于原始的 BERT-base,应用时需要根据实际情况平衡效果和计算量
上述表格只是原论文实验的冰山一角,这里再放出一个实验表格,让大家感受感受它的实验阵容:
几乎每个任务、每个超参数组合都做了实验,并给出实验结果。如此强大的实验阵容,基本上也就只有 Google 能搞出来了