一步一步图解Self-Attention

2022-04-28 19:11:21 浏览数 (1)

Self Attention

自注意力机制(Self-attention Mechanism)是Google机器翻译团队2017年在《Attention is all you need》论文中提出的,它抛弃了传统的通过RNN来做Seq2Seq任务的做法,对神经网络训练的并行化更加友好。

本文通过实例一步一步的拆解Self Attention的每个步骤,帮助我们更好的理解Self Attention运行过程中发生了什么。阅读完本文后,你应该能够从头编写Self Attention模块了。

Self Attention运行过程

1、Prepare Inputs

Prepare inputs

我们假设输入是3个四维的向量。在实际应用中,这些向量都是通过Embeding的过程生成的。

代码语言:javascript复制
Input 1: [1, 0, 1, 0]
Input 2: [0, 2, 0, 2]
Input 3: [1, 1, 1, 1]

2、Initialise Weights

Self Attention的每个输入必须有三种表达(Representations):Key(下图橙色所示)、Query(下图红色所示)、Value(如下图紫色所示)。

Derive key representations from each input

由于输入是三个四维(大小为3x4)矩阵,所以Key、Query、Value的权重是大小为4x3的矩阵。我们初始化Key、Value、Query的权重矩阵如下:

Key的Weights:

代码语言:javascript复制
[[0, 0, 1],
[1, 1, 0],
[0, 1, 0],
[1, 1, 0]]

Query的Weights:

代码语言:javascript复制
[[1, 0, 1],
[1, 0, 0],
[0, 0, 1],
[0, 1, 1]]

Value的Weights:

代码语言:javascript复制
[[0, 2, 0],
[0, 3, 0],
[1, 0, 3],
[1, 1, 0]]

注意:

在实际的神经网络的训练中,我们往往需要根据经验选择适当的随机分布(如高斯分布、Xavie初始化和Kaiming分布等)对权重参数进行随机初始化。

深度学习中如何选择合适的初始化权重

3、Derive Key,Query And Value

有了Inputs和Weights之后,我们可以分别计算每个Input对应的Key、Value和Query。

Input 1的Key Representation:

代码语言:javascript复制

               [0, 0, 1]
[1, 0, 1, 0] x [1, 1, 0] = [0, 1, 1]
               [0, 1, 0]
               [1, 1, 0]

Input 2的Key Representation:

代码语言:javascript复制

               [0, 0, 1]
[0, 2, 0, 2] x [1, 1, 0] = [4, 4, 0]
               [0, 1, 0]
               [1, 1, 0]

Input 3的Key Representation:

代码语言:javascript复制

               [0, 0, 1]
[1, 1, 1, 1] x [1, 1, 0] = [2, 3, 1]
               [0, 1, 0]
               [1, 1, 0]

采用矩阵乘法的方式来合并上述三个操作,得到Key Representation:

代码语言:javascript复制

               [0, 0, 1]
[1, 0, 1, 0]   [1, 1, 0]   [0, 1, 1]
[0, 2, 0, 2] x [0, 1, 0] = [4, 4, 0]
[1, 1, 1, 1]   [1, 1, 0]   [2, 3, 1]

Derive key representations from each input

使用同样的方式计算Value Representations:

代码语言:javascript复制
               [0, 2, 0]
[1, 0, 1, 0]   [0, 3, 0]   [1, 2, 3]
[0, 2, 0, 2] x [1, 0, 3] = [2, 8, 0]
[1, 1, 1, 1]   [1, 1, 0]   [2, 6, 3]

Derive value representations from each input

使用相同的方式计算Query Representation:

代码语言:javascript复制
               [1, 0, 1]
[1, 0, 1, 0]   [1, 0, 0]   [1, 0, 2]
[0, 2, 0, 2] x [0, 0, 1] = [2, 2, 2]
[1, 1, 1, 1]   [0, 1, 1]   [2, 1, 3]

Derive query representations from each input

注意:

在实际的项目中,除了Weight之外,我们还需要考虑Bias的使用。

4、Calculate Attention Scores For Input 1

Calculating attention scores (Blue) from query

为了计算Attention Scores,我们对Input 1、Input2和Input 3的Key乘以Input 1的Query,得到三个Attention Score(如上图蓝色所示)。

代码语言:javascript复制

            [0, 4, 2]
[1, 0, 2] x [1, 4, 3] = [2, 4, 4]
            [1, 0, 1]

同样的,使用Input 2和Input 3的Query乘以Input 1、Input 2和Input 3的Key,得到Attention Score。

注意:

这里我们使用点乘(dot product)计算Attention Score,这只是计算Attention Score的方式之一,其它的计算方式(比如Additive、Concat等)也可以用来计算Attention Score。

5、Calculate Softmax

Softmax the attention scores (blue)

对计算出的Attention Score使用Softmax操作(如上图蓝色所示):

代码语言:javascript复制
softmax([2, 4, 4]) = [0.0, 0.5, 0.5]

6、Multiply Scores With Values

Derive weighted value representation (yellow) from multiply value (purple) and score (blue)

将每个Input(如上图蓝色所示)乘以每个Input的Value(如上图紫色所示),生成三个对齐的向量。(如上图黄色所示)。

代码语言:javascript复制
1: 0.0 * [1, 2, 3] = [0.0, 0.0, 0.0]
2: 0.5 * [2, 8, 0] = [1.0, 4.0, 0.0]
3: 0.5 * [2, 6, 3] = [1.0, 3.0, 1.5]

7、Sum Weighted Values To Get Output 1

Sum all weighted values (yellow) to get Output 1 (dark green)

将黄色部分的向量进行加和处理。

代码语言:javascript复制
  [0.0, 0.0, 0.0]
  [1.0, 4.0, 0.0]
  [1.0, 3.0, 1.5]
-----------------
= [2.0, 7.0, 1.5]

8、Repeat Steps 4–7 For Input 2 & Input 3

计算完Output 1之后,我们重复步骤4到步骤7计算Output 2和Output 3。

Repeat previous steps for Input 2 & Input 3

代码实现

在PyTorch中实现上述代码如下。运行的环境为:Python≥3.6 and PyTorch 1.3.1。

Step 1: Prepare Inputs

代码语言:javascript复制
import torch

x = [
  [1, 0, 1, 0], # Input 1
  [0, 2, 0, 2], # Input 2
  [1, 1, 1, 1]  # Input 3
 ]
x = torch.tensor(x, dtype=torch.float32)

Step 2: Initialise Weights

代码语言:javascript复制
w_key = [
  [0, 0, 1],
  [1, 1, 0],
  [0, 1, 0],
  [1, 1, 0]
]
w_query = [
  [1, 0, 1],
  [1, 0, 0],
  [0, 0, 1],
  [0, 1, 1]
]
w_value = [
  [0, 2, 0],
  [0, 3, 0],
  [1, 0, 3],
  [1, 1, 0]
]
w_key = torch.tensor(w_key, dtype=torch.float32)
w_query = torch.tensor(w_query, dtype=torch.float32)
w_value = torch.tensor(w_value, dtype=torch.float32)

Step 3: Derive Key, Query And Value

代码语言:javascript复制
keys = x @ w_key
querys = x @ w_query
values = x @ w_value

print(keys)
# tensor([[0., 1., 1.],
#         [4., 4., 0.],
#         [2., 3., 1.]])

print(querys)
# tensor([[1., 0., 2.],
#         [2., 2., 2.],
#         [2., 1., 3.]])

print(values)
# tensor([[1., 2., 3.],
#         [2., 8., 0.],
#         [2., 6., 3.]])

Step 4: Calculate Attention Scores

代码语言:javascript复制
attn_scores = querys @ keys.T

# tensor([[ 2.,  4.,  4.],  # attention scores from Query 1
#         [ 4., 16., 12.],  # attention scores from Query 2
#         [ 4., 12., 10.]]) # attention scores from Query 3

Step 5: Calculate Softmax

代码语言:javascript复制
from torch.nn.functional import softmax

attn_scores_softmax = softmax(attn_scores, dim=-1)
# tensor([[6.3379e-02, 4.6831e-01, 4.6831e-01],
#         [6.0337e-06, 9.8201e-01, 1.7986e-02],
#         [2.9539e-04, 8.8054e-01, 1.1917e-01]])

# For readability, approximate the above as follows
attn_scores_softmax = [
  [0.0, 0.5, 0.5],
  [0.0, 1.0, 0.0],
  [0.0, 0.9, 0.1]
]
attn_scores_softmax = torch.tensor(attn_scores_softmax)

Step 6: Multiply Scores With Values

代码语言:javascript复制
weighted_values = values[:,None] * attn_scores_softmax.T[:,:,None]

# tensor([[[0.0000, 0.0000, 0.0000],
#          [0.0000, 0.0000, 0.0000],
#          [0.0000, 0.0000, 0.0000]],
#
#         [[1.0000, 4.0000, 0.0000],
#          [2.0000, 8.0000, 0.0000],
#          [1.8000, 7.2000, 0.0000]],
#
#         [[1.0000, 3.0000, 1.5000],
#          [0.0000, 0.0000, 0.0000],
#          [0.2000, 0.6000, 0.3000]]])

Step 7: Sum Weighted Values

代码语言:javascript复制
outputs = weighted_values.sum(dim=0)

# tensor([[2.0000, 7.0000, 1.5000],  # Output 1
#         [2.0000, 8.0000, 0.0000],  # Output 2
#         [2.0000, 7.8000, 0.3000]]) # Output 3

0 人点赞