Bert Pytorch 源码分析:二、注意力层

2023-10-13 09:23:45 浏览数 (1)

代码语言:javascript复制
# 注意力机制的具体模块
# 兼容单头和多头
class Attention(nn.Module):
    """
    Compute 'Scaled Dot Product Attention
    """

	# QKV 尺寸都是 BS * ML * ES
	# (或者多头情况下是 BS * HC * ML * HS,最后两维之外的维度不重要)
	# 从输入计算 QKV 的过程可以统一处理,不必放到每个头里面
    def forward(self, query, key, value, mask=None, dropout=None):
		# 将每个批量的 Q 和 K.T 做矩阵乘法,再除以√ES,
		# 得到相关性矩阵 S,尺寸为 BS * ML * ML
        scores = torch.matmul(query, key.transpose(-2, -1)) 
                 / math.sqrt(query.size(-1))

		# 如果存在掩码则使用它
		# 将 scores 的 mask == 0 的位置上的元素改为 -1e9
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        # 将 S 转换到概率空间,同时对其最后一维归一化
        p_attn = F.softmax(scores, dim=-1)

		# 如果存在 dropout 则使用
        if dropout is not None:
            p_attn = dropout(p_attn)

		# 最后将 S 与 V 相乘得到输出
        return torch.matmul(p_attn, value), p_attn
		
# 多头注意力就是包含很多(HC)个头,但是每个头的尺寸(HS)变为原来的 1/HC
# 把 qkv 切成小段分给每个头做运算,将结果拼起来作为整个层的输出
class MultiHeadedAttention(nn.Module):
    """
    Take in model size and number of heads.
    """

	# h 是头数(HC)
	# d_model 是嵌入向量大小(ES)
    def __init__(self, h, d_model, dropout=0.1):
        super().__init__()
		# 判断 ES 是否能被 HC 整除,以便结果能拼接回去
        assert d_model % h == 0

		# d_k 是每个头的大小 HS = ES // HC
        self.d_k = d_model // h
        self.h = h

		# 创建输入转换为QKV的权重矩阵,Wq, Wk, Wv,尺寸均为 ES * ES
        self.linear_layers = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(3)])
		# 输出应该还乘一个权重矩阵,Wo,尺寸也是 ES * ES
        self.output_linear = nn.Linear(d_model, d_model)
		# 创建执行注意力机制的具体模块
        self.attention = Attention()
		# 创建 droput 层
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, mask=None):
		# 获取批量大小(BS)
        batch_size = query.size(0)

       
		'''
        query, key, value = [
			l(x).view(batch_size, -1, self.h, self.d_k).transpose(1, 2)
		    for l, x in zip(self.linear_layers, (query, key, value))
		]
		'''
		# 将 QKV 的每个与其相应权重矩阵 Wq, Wk, Wv 相乘
		lq, lk, lv = self.linear_layers
		query, key, value = lq(query), lk(key), lv(value) 
		
		# 然后将他们转型为 BS * ML * HC * HS
		# 也就是将最后一个维度按头部数量分割成小的向量
		query, key, value = [
			x.view(batch_size, -1, self.h, self.d_k)
			for x in (query, key, value)
		]
		
		# 然后交换 1 和 2 维,变成 BS * HC * ML  * HS
		# 这样每个头的 QKV 是内存连续的,便于矩阵相乘
		query, key, value = [
			x.transpose(1, 2)
			for x in (query, key, value)
		]

        # 对每个头应用注意力机制,输出尺寸不变
        x, attn = self.attention(query, key, value, mask=mask, dropout=self.dropout)

        # 交换 1 和 2 维恢复原状,然后把每个头的输出相连接,尺寸变为 BS * ML * ES
        x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.h * self.d_k)

		# 执行最后的矩阵相乘
        return self.output_linear(x)

缩写表

  • BS:批量大小,即一批数据中样本大小,训练集和测试集可能不同,那就是TBS和VBS
  • ES:嵌入大小,嵌入向量空间的维数,也是注意力层的隐藏单元数量,GPT 中一般是 768
  • ML:输入序列最大长度,一般是512或者1024,不够需要用<pad>填充
  • HC:头部的数量,需要能够整除ES,因为每个头的输出拼接起来才是层的输出
  • HS:头部大小,等于ES // HC
  • VS:词汇表大小,也就是词的种类数量

尺寸备注

  • 嵌入层的矩阵尺寸应该是VS * ES
  • 注意力层的输入尺寸是BS * ML * ES
  • 输出以及 Q K V 和输入形状相同
  • 每个头的 QKV 尺寸为BS * ML * HS
  • 权重矩阵尺寸为ES * ES
  • 相关矩阵 S 尺寸为BS * ML * ML

0 人点赞