1. 背景
上一篇博客讲了Transformers里面的self-attention,在NLP领域中其实attentionseq2seq的时候就有广泛应用了。这篇文章主要总结一下从从RNN LSTM GRU seq2seq 到attention的种类及应用,方便大家理解整体发展与attention机制。
2. RNN
RNN 基本的模型如上图所示,每个神经元接受的输入包括:前一个神经元的隐藏层状态 h (用于记忆) 和当前的输入 x (当前信息)。神经元得到输入之后,会计算出新的隐藏状态 h 和输出 y,然后再传递到下一个神经元。因为隐藏状态 h 的存在,使得 RNN 具有一定的记忆功能。
针对不同任务,通常要对 RNN 模型结构进行少量的调整,根据输入和输出的数量,分为三种比较常见的结构:
N vs N、1 vs N、N vs 1。
1.1 RNN 结构1: N vs N
上图是RNN 模型的一种 N vs N 结构,包含 N 个输入 x1, x2, ..., xN,和 N 个输出 y1, y2, ..., yN。N vs N 的结构中,输入和输出序列的长度是相等的,通常适合用于以下任务:
- 词性标注
- 训练语言模型,使用之前的词预测下一个词等
1.2 RNN 结构2: 1 vs N
在 1 vs N 结构中,我们只有一个输入 x,和 N 个输出 y1, y2, ..., yN。可以有两种方式使用 1 vs N:
第一种只将输入 x 传入第一个 RNN 神经元
第二种是将输入 x 传入所有的 RNN 神经元
1 vs N 结构适合用于以下任务:
- 图像生成文字,输入 x 就是一张图片,输出就是一段图片的描述文字。
- 根据音乐类别,生成对应的音乐。
- 根据小说类别,生成相应的小说。
1.3 RNN 结构3: N vs 1
在 N vs 1 结构中,我们有 N 个输入 x1, x2, ..., xN,和一个输出 y。N vs 1 结构适合用于以下任务:
- 序列分类任务,一段语音、一段文字的类别,句子的情感分析。
1.4 LSTM
我们可以通过 LSTM 比较好地缓解 RNN 梯度消失的问题。
而 LSTM 的神经元在此基础上还输入了一个 cell 状态 ct-1, cell 状态 c 和 RNN 中的隐藏状态 h 相似,都保存了历史的信息,从 ct-2 ~ ct-1 ~ ct。在 LSTM 中 c 与 RNN 中的 h 扮演的角色很像,都是保存历史状态信息,而在 LSTM 中的 h 更多地是保存上一时刻的输出信息。
1.4.1 遗忘门
上图中红色框中的是 LSTM 遗忘门部分,用来判断 cell 状态 ct-1 中哪些信息应该删除。其中 σ 表示激活函数 sigmoid。输入的 ht-1 和 xt 经过 sigmoid 激活函数之后得到 ft,ft 中每一个值的范围都是 [0, 1]。ft 中的值越接近 1,表示 cell 状态 ct-1 中对应位置的值更应该记住;ft 中的值越接近 0,表示 cell 状态 ct-1 中对应位置的值更应该忘记。将 ft 与 ct-1 按位相乘 (ElementWise 相乘),即可以得到遗忘无用信息之后的 c’t-1。
1.4.2 输入门
上图中红色框中的是 LSTM 输入门部分,用来判断哪些新的信息应该加入到 cell 状态 c‘t-1 中。其中 σ 表示激活函数 sigmoid。输入的 ht-1 和 xt 经过 tanh 激活函数可以得到新的输入信息 (图中带波浪线的 Ct),但是这些新信息并不全是有用的,因此需要使用 ht-1 和 xt 经过 sigmoid 函数得到 it, it 表示哪些新信息是有用的。两向量相乘后的结果加到 c’t-1 中,即得到 t 时刻的 cell 状态 ct。
1.4.3 输出门
上图中红色框中的是 LSTM 输出门部分,用来判断应该输出哪些信息到 ht 中。cell 状态 ct 经过 tanh 函数得到可以输出的信息,然后 ht-1 和 xt 经过 sigmoid 函数得到一个向量 ot,ot 的每一维的范围都是 [0, 1],表示哪些位置的输出应该去掉,哪些应该保留。两向量相乘后的结果就是最终的 ht。
1.5 GRU
GRU 是 LSTM 的一种变种,结构比 LSTM 简单一点。LSTM有三个门 (遗忘门 forget,输入门 input,输出门output),而 GRU 只有两个门 (更新门 update,重置门 reset)。另外,GRU 没有 LSTM 中的 cell 状态 c。
图中的 zt 和 rt 分别表示更新门 (红色) 和重置门 (蓝色)。重置门 rt 控制着前一状态的信息 ht-1 传入候选状态 (图中带波浪线的ht) 的比例,重置门 rt 的值越小,则与 ht-1 的乘积越小,ht-1 的信息添加到候选状态越少。更新门用于控制前一状态的信息 ht-1 有多少保留到新状态 ht 中,当 (1-zt) 越大,保留的信息越多。
2. seq2seq
2.1 seq2seq结构
RNN 的输入和输出个数都有一定的限制,但实际中很多任务的序列的长度是不固定的,例如机器翻译中,源语言、目标语言的句子长度不一样;对话系统中,问句和答案的句子长度不一样。
eq2Seq 是一种重要的 RNN 模型,也称为 Encoder-Decoder 模型,可以理解为一种 N×M 的模型。模型包含两个部分:Encoder 用于编码序列的信息,将任意长度的序列信息编码到一个向量 c 里。而 Decoder 是解码器,解码器得到上下文信息向量 c 之后可以将信息解码,并输出为序列。Seq2Seq 模型结构有很多种,下面是几种比较常见的:
2.1.1 seq2seq结构1
2.1.2 seq2seq结构2
2.1.3 seq2seq结构3
2.2 编码器Encoder
这三种 Seq2Seq 模型的主要区别在于 Decoder,他们的 Encoder 都是一样的。下图是 Encoder 部分,Encoder 的 RNN 接受输入 x,最终输出一个编码所有信息的上下文向量 c,中间的神经元没有输出。Decoder 主要传入的是上下文向量 c,然后解码出需要的信息。
从公式可以看到,c 可以直接使用最后一个神经元的隐藏状态 hN 表示;也可以在最后一个神经元的隐藏状态上进行某种变换 hN 而得到,q 函数表示某种变换;也可以使用所有神经元的隐藏状态 h1, h2, ..., hN 计算得到。得到上下文向量 c 之后,需要传递到 Decoder。
2.3 解码器Decoder
2.3.1 Decoder #1
第一种 Decoder 结构比较简单,将上下文向量 c 当成是 RNN 的初始隐藏状态,输入到 RNN 中,后续只接受上一个神经元的隐藏层状态 h' 而不接收其他的输入 x。第一种 Decoder 结构的隐藏层及输出的计算公式:
2.3.2 Decoder #2
第二种 Decoder 结构有了自己的初始隐藏层状态 h'0,不再把上下文向量 c 当成是 RNN 的初始隐藏状态,而是当成 RNN 每一个神经元的输入。可以看到在 Decoder 的每一个神经元都拥有相同的输入 c,这种 Decoder 的隐藏层及输出计算公式:
2.3.3 Decoder #3
第三种 Decoder 结构和第二种类似,但是在输入的部分多了上一个神经元的输出 y'。即每一个神经元的输入包括:上一个神经元的隐藏层向量 h',上一个神经元的输出 y',当前的输入 c (Encoder 编码的上下文向量)。对于第一个神经元的输入 y'0,通常是句子其实标志位的 embedding 向量。第三种 Decoder 的隐藏层及输出计算公式:
2.4 beam search
beam search 方法不用于训练的过程,而是用在测试的。在每一个神经元中,我们都选取当前输出概率值最大的 top k 个输出传递到下一个神经元。下一个神经元分别用这 k 个输出,计算出 L 个单词的概率 (L 为词汇表大小),然后在 kL 个结果中得到 top k 个最大的输出,重复这一步骤。
2.5 teacher forcing
Teacher Forcing 用于训练阶段,主要针对上面第三种 Decoder 模型来说的,第三种 Decoder 模型神经元的输入包括了上一个神经元的输出 y'。如果上一个神经元的输出是错误的,则下一个神经元的输出也很容易错误,导致错误会一直传递下去。
而 Teacher Forcing 可以在一定程度上缓解上面的问题,在训练 Seq2Seq 模型时,Decoder 的每一个神经元并非一定使用上一个神经元的输出,而是有一定的比例采用正确的序列作为输入。
3. Attention 注意力机制
Attention的思想如同它的名字一样,就是“注意力”,在预测结果时把注意力放在不同的特征上。例如翻译 "I have a cat",翻译到 "我" 时,要将注意力放在源句子的 "I" 上,翻译到 "猫" 时要将注意力放在源句子的 "cat" 上。
3.1 attention的计算
通常我们会将输入分为query(Q), key(K), value(V)三种:
- 先用Q和K计算权重 a ,会用softmax对权重归一化: a=softmax(f(QK))
- QK的具体运算f有多种方法,常见的有加性attention和乘性attention等:
- 加性attention: f(Q,K)=tanh(W1Q W2K)
- 乘性attention: f(Q,K)=QKT
- 缩放点积attention: f(Q,K)=QKT√d
- 双线性点积attention: f(Q,K)=QWKT
- QK的具体运算f有多种方法,常见的有加性attention和乘性attention等:
- 再用权重对结果加权: out=sum a_i*v_i
这种机制其实做的是寻址(addressing),也就是模仿中央处理器与存储交互的方式将存储的内容读出来。如上图所示:给定一个和任务相关的查询Query向量 q,通过计算与Key的注意力分布并附加在Value上,从而计算Attention Value,这个过程实际上是Attention机制缓解神经网络模型复杂度的体现:不需要将所有的N个输入信息都输入到神经网络进行计算,只需要从X中选择一些和任务相关的信息输入给神经网络。
3.2 attention在seq2seq
其中h^i 是编码器Encoder每个step的输出, z^j 是解码器Decoder每个step的输出,计算步骤是这样的:
- 先对输入进行编码,得到 [h^1,h^2,h^3,h^4]
- 开始解码了,先用固定的start token也就是z^0 最为Q,去和每个 h^i (同时作为K和V)去计算attention,得到加权的 c^0
- 用 c^0 作为解码的RNN输入(同时还有上一步的 z^0 ),得到 z^1 并预测出第一个词是machine
- 再继续预测的话,就是用z^1作为Q去求attention:
换一个图可能会好理解一点,可以两个一起看看:
使用了 Attention 后,Decoder 的输入就不是固定的上下文向量c了,而是会根据当前翻译的信息,计算当前的c。
Attention 需要保留 Encoder 每一个神经元的隐藏层向量 h,然后 Decoder 的第 t 个神经元要根据上一个神经元的隐藏层向量 h't-1 计算出当前状态与 Encoder 每一个神经元的相关性 et。et 是一个 N 维的向量 (Encoder 神经元个数为 N),若 et 的第 i 维越大,则说明当前节点与 Encoder 第 i 个神经元的相关性越大。
et 的计算方法有很多种,即相关性系数的计算函数 a 有很多种,如上文所提到。
上面得到相关性向量et 后,需要进行归一化,使用 softmax 归一化。然后用归一化后的系数融合 Encoder 的多个隐藏层向量得到 Decoder 当前神经元的上下文向量ct:
3.3 attention分类
- Soft/Hard Attention
- soft attention:传统attention,可被嵌入到模型中去进行训练并传播梯度
- hard attention:不计算所有输出,依据概率对encoder的输出采样,在反向传播时需采用蒙特卡洛进行梯度估计
- Global/Local Attention
- global attention:传统attention,对所有encoder输出进行计算
- local attention:介于soft和hard之间,会预测一个位置并选取一个窗口进行计算
- Self Attention
- 传统attention是计算Q和K之间的依赖关系,而self attention则分别计算Q和K自身的依赖关系。具体可见上篇博客。
3.4 self attention 与 attention区别于关系
attention和self attention 其具体计算过程是一样的,只是计算对象发生了变化而已。
- attention是source对target的attention
- 比如对于英-中机器翻译来说,Source是英文句子,Target是对应的翻译出的中文句子,Attention机制发生在Target的元素Query和Source中的所有元素之间。简单的讲就是Attention机制中的权重的计算需要Target来参与的,即在Encoder-Decoder model中Attention权值的计算不仅需要Encoder中的隐状态而且还需要Decoder 中的隐状态。
- self attention 是source 对source的attention。
- 例如在Transformer中在计算权重参数时将文字向量转成对应的KQV,只需要在Source处进行对应的矩阵操作,用不到Target中的信息。
3.4.1 self-attention
self attention会给你一个矩阵,告诉你 entity1 和entity2、entity3 ….的关联程度、entity2和entity1、entity3…的关联程度。
它指的不是Target和Source之间的Attention机制,而是Source内部元素之间或者Target内部元素之间发生的Attention机制,也可以理解为Target=Source这种特殊情况下的注意力计算机制。Q=K=V。
3.4.2 attention
Attention机制发生在Target的元素Query和Source中的所有元素之间。
比如entity1,entity2,entity3….,attn会输出[0.1,0.2,0.5,….]这种,告诉你entity3重要些。
Ref
- https://zhuanlan.zhihu.com/p/43493999
- https://www.jianshu.com/p/80436483b13b
- https://www.jianshu.com/p/247a72812aff
- https://www.zhihu.com/question/68482809/answer/597944559
- http://www.sniper97.cn/index.php/note/deep-learning/base/3606/