Self Attention
自注意力机制(Self-attention Mechanism)是Google机器翻译团队2017年在《Attention is all you need》论文中提出的,它抛弃了传统的通过RNN来做Seq2Seq任务的做法,对神经网络训练的并行化更加友好。
本文通过实例一步一步的拆解Self Attention的每个步骤,帮助我们更好的理解Self Attention运行过程中发生了什么。阅读完本文后,你应该能够从头编写Self Attention模块了。
Self Attention运行过程
1、Prepare Inputs
Prepare inputs
我们假设输入是3个四维的向量。在实际应用中,这些向量都是通过Embeding的过程生成的。
代码语言:javascript复制Input 1: [1, 0, 1, 0]
Input 2: [0, 2, 0, 2]
Input 3: [1, 1, 1, 1]
2、Initialise Weights
Self Attention的每个输入必须有三种表达(Representations):Key(下图橙色所示)、Query(下图红色所示)、Value(如下图紫色所示)。
Derive key representations from each input
由于输入是三个四维(大小为3x4)矩阵,所以Key、Query、Value的权重是大小为4x3的矩阵。我们初始化Key、Value、Query的权重矩阵如下:
Key的Weights:
代码语言:javascript复制[[0, 0, 1],
[1, 1, 0],
[0, 1, 0],
[1, 1, 0]]
Query的Weights:
代码语言:javascript复制[[1, 0, 1],
[1, 0, 0],
[0, 0, 1],
[0, 1, 1]]
Value的Weights:
代码语言:javascript复制[[0, 2, 0],
[0, 3, 0],
[1, 0, 3],
[1, 1, 0]]
注意:
在实际的神经网络的训练中,我们往往需要根据经验选择适当的随机分布(如高斯分布、Xavie初始化和Kaiming分布等)对权重参数进行随机初始化。
深度学习中如何选择合适的初始化权重
3、Derive Key,Query And Value
有了Inputs和Weights之后,我们可以分别计算每个Input对应的Key、Value和Query。
Input 1的Key Representation:
代码语言:javascript复制
[0, 0, 1]
[1, 0, 1, 0] x [1, 1, 0] = [0, 1, 1]
[0, 1, 0]
[1, 1, 0]
Input 2的Key Representation:
代码语言:javascript复制
[0, 0, 1]
[0, 2, 0, 2] x [1, 1, 0] = [4, 4, 0]
[0, 1, 0]
[1, 1, 0]
Input 3的Key Representation:
代码语言:javascript复制
[0, 0, 1]
[1, 1, 1, 1] x [1, 1, 0] = [2, 3, 1]
[0, 1, 0]
[1, 1, 0]
采用矩阵乘法的方式来合并上述三个操作,得到Key Representation:
代码语言:javascript复制
[0, 0, 1]
[1, 0, 1, 0] [1, 1, 0] [0, 1, 1]
[0, 2, 0, 2] x [0, 1, 0] = [4, 4, 0]
[1, 1, 1, 1] [1, 1, 0] [2, 3, 1]
Derive key representations from each input
使用同样的方式计算Value Representations:
代码语言:javascript复制 [0, 2, 0]
[1, 0, 1, 0] [0, 3, 0] [1, 2, 3]
[0, 2, 0, 2] x [1, 0, 3] = [2, 8, 0]
[1, 1, 1, 1] [1, 1, 0] [2, 6, 3]
Derive value representations from each input
使用相同的方式计算Query Representation:
代码语言:javascript复制 [1, 0, 1]
[1, 0, 1, 0] [1, 0, 0] [1, 0, 2]
[0, 2, 0, 2] x [0, 0, 1] = [2, 2, 2]
[1, 1, 1, 1] [0, 1, 1] [2, 1, 3]
Derive query representations from each input
注意:
在实际的项目中,除了Weight之外,我们还需要考虑Bias的使用。
4、Calculate Attention Scores For Input 1
Calculating attention scores (Blue) from query
为了计算Attention Scores,我们对Input 1、Input2和Input 3的Key乘以Input 1的Query,得到三个Attention Score(如上图蓝色所示)。
代码语言:javascript复制
[0, 4, 2]
[1, 0, 2] x [1, 4, 3] = [2, 4, 4]
[1, 0, 1]
同样的,使用Input 2和Input 3的Query乘以Input 1、Input 2和Input 3的Key,得到Attention Score。
注意:
这里我们使用点乘(dot product)计算Attention Score,这只是计算Attention Score的方式之一,其它的计算方式(比如Additive、Concat等)也可以用来计算Attention Score。
5、Calculate Softmax
Softmax the attention scores (blue)
对计算出的Attention Score使用Softmax操作(如上图蓝色所示):
代码语言:javascript复制softmax([2, 4, 4]) = [0.0, 0.5, 0.5]
6、Multiply Scores With Values
Derive weighted value representation (yellow) from multiply value (purple) and score (blue)
将每个Input(如上图蓝色所示)乘以每个Input的Value(如上图紫色所示),生成三个对齐的向量。(如上图黄色所示)。
代码语言:javascript复制1: 0.0 * [1, 2, 3] = [0.0, 0.0, 0.0]
2: 0.5 * [2, 8, 0] = [1.0, 4.0, 0.0]
3: 0.5 * [2, 6, 3] = [1.0, 3.0, 1.5]
7、Sum Weighted Values To Get Output 1
Sum all weighted values (yellow) to get Output 1 (dark green)
将黄色部分的向量进行加和处理。
代码语言:javascript复制 [0.0, 0.0, 0.0]
[1.0, 4.0, 0.0]
[1.0, 3.0, 1.5]
-----------------
= [2.0, 7.0, 1.5]
8、Repeat Steps 4–7 For Input 2 & Input 3
计算完Output 1之后,我们重复步骤4到步骤7计算Output 2和Output 3。
Repeat previous steps for Input 2 & Input 3
代码实现
在PyTorch中实现上述代码如下。运行的环境为:Python≥3.6 and PyTorch 1.3.1。
Step 1: Prepare Inputs
代码语言:javascript复制import torch
x = [
[1, 0, 1, 0], # Input 1
[0, 2, 0, 2], # Input 2
[1, 1, 1, 1] # Input 3
]
x = torch.tensor(x, dtype=torch.float32)
Step 2: Initialise Weights
代码语言:javascript复制w_key = [
[0, 0, 1],
[1, 1, 0],
[0, 1, 0],
[1, 1, 0]
]
w_query = [
[1, 0, 1],
[1, 0, 0],
[0, 0, 1],
[0, 1, 1]
]
w_value = [
[0, 2, 0],
[0, 3, 0],
[1, 0, 3],
[1, 1, 0]
]
w_key = torch.tensor(w_key, dtype=torch.float32)
w_query = torch.tensor(w_query, dtype=torch.float32)
w_value = torch.tensor(w_value, dtype=torch.float32)
Step 3: Derive Key, Query And Value
代码语言:javascript复制keys = x @ w_key
querys = x @ w_query
values = x @ w_value
print(keys)
# tensor([[0., 1., 1.],
# [4., 4., 0.],
# [2., 3., 1.]])
print(querys)
# tensor([[1., 0., 2.],
# [2., 2., 2.],
# [2., 1., 3.]])
print(values)
# tensor([[1., 2., 3.],
# [2., 8., 0.],
# [2., 6., 3.]])
Step 4: Calculate Attention Scores
代码语言:javascript复制attn_scores = querys @ keys.T
# tensor([[ 2., 4., 4.], # attention scores from Query 1
# [ 4., 16., 12.], # attention scores from Query 2
# [ 4., 12., 10.]]) # attention scores from Query 3
Step 5: Calculate Softmax
代码语言:javascript复制from torch.nn.functional import softmax
attn_scores_softmax = softmax(attn_scores, dim=-1)
# tensor([[6.3379e-02, 4.6831e-01, 4.6831e-01],
# [6.0337e-06, 9.8201e-01, 1.7986e-02],
# [2.9539e-04, 8.8054e-01, 1.1917e-01]])
# For readability, approximate the above as follows
attn_scores_softmax = [
[0.0, 0.5, 0.5],
[0.0, 1.0, 0.0],
[0.0, 0.9, 0.1]
]
attn_scores_softmax = torch.tensor(attn_scores_softmax)
Step 6: Multiply Scores With Values
代码语言:javascript复制weighted_values = values[:,None] * attn_scores_softmax.T[:,:,None]
# tensor([[[0.0000, 0.0000, 0.0000],
# [0.0000, 0.0000, 0.0000],
# [0.0000, 0.0000, 0.0000]],
#
# [[1.0000, 4.0000, 0.0000],
# [2.0000, 8.0000, 0.0000],
# [1.8000, 7.2000, 0.0000]],
#
# [[1.0000, 3.0000, 1.5000],
# [0.0000, 0.0000, 0.0000],
# [0.2000, 0.6000, 0.3000]]])
Step 7: Sum Weighted Values
代码语言:javascript复制outputs = weighted_values.sum(dim=0)
# tensor([[2.0000, 7.0000, 1.5000], # Output 1
# [2.0000, 8.0000, 0.0000], # Output 2
# [2.0000, 7.8000, 0.3000]]) # Output 3