# BERT encoder module
# consists of one embedding layer and NL TF (Transformer) layers
class BERT(nn.Module):
    """
    BERT model : Bidirectional Encoder Representations from Transformers.
    """

    def __init__(self, vocab_size, hidden=768, n_layers=12, attn_heads=12, dropout=0.1):
        """
        :param vocab_size: vocab_size of total words
        :param hidden: BERT model hidden size
        :param n_layers: number of Transformer blocks (layers)
        :param attn_heads: number of attention heads
        :param dropout: dropout rate
        """
        super().__init__()
        # embedding size ES
        self.hidden = hidden
        # number of TF layers NL
        self.n_layers = n_layers
        # number of attention heads HC
        self.attn_heads = attn_heads
        # number of hidden units in the FFN layer, denoted FF, usually 4 times ES
        self.feed_forward_hidden = hidden * 4
        # embedding layer, embedding matrix of size VS * ES
        self.embedding = BERTEmbedding(vocab_size=vocab_size, embed_size=hidden)
        # NL TF layers
        self.transformer_blocks = nn.ModuleList(
            [TransformerBlock(hidden, attn_heads, hidden * 4, dropout) for _ in range(n_layers)])

    def forward(self, x, segment_info):
        # build the mask for `<pad>` (ID = 0)
        # its size is BS * 1 * ML * ML so that it matches the similarity matrix S
        # in each ML * ML matrix, every row is a copy of (x > 0): columns of
        # non-`<pad>` tokens are 1 and columns of `<pad>` tokens are 0
        mask = (x > 0).unsqueeze(1).repeat(1, x.size(1), 1).unsqueeze(1)
        # pass the word IDs through the embedding layer to get word vectors
        x = self.embedding(x, segment_info)
        # pass through the TF layers one by one to get the encoder output
        for transformer in self.transformer_blocks:
            x = transformer.forward(x, mask)
        return x
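The mask construction above is the trickiest line in the encoder, so here is a minimal, self-contained sketch (the token IDs and sizes are made up for illustration) showing how a batch containing `<pad>` (ID = 0) turns into a BS * 1 * ML * ML boolean mask whose columns are zero at the padded positions:

import torch

# dummy batch: BS = 1, ML = 5, the last two positions are `<pad>` (ID = 0)
x = torch.tensor([[5, 6, 7, 0, 0]])

# the same expression used in BERT.forward
mask = (x > 0).unsqueeze(1).repeat(1, x.size(1), 1).unsqueeze(1)

print(mask.shape)  # torch.Size([1, 1, 5, 5])
print(mask[0, 0])
# tensor([[ True,  True,  True, False, False],
#         [ True,  True,  True, False, False],
#         [ True,  True,  True, False, False],
#         [ True,  True,  True, False, False],
#         [ True,  True,  True, False, False]])
# every row is a copy of (x > 0), so the two `<pad>` columns are False
# and the corresponding attention scores can be masked out inside each TF layer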
# The decoder structure depends on the specific task.
# Tasks generally fall into three types: (1) sequence classification,
# (2) token classification, (3) sequence generation,
# but the decoder is usually just fully connected layers.

# Decoder for next sentence prediction:
# a sequence classification task; the input is two sentences and the output is one label,
# 1 meaning they are adjacent sentences and 0 meaning they are not.
class NextSentencePrediction(nn.Module):
    """
    2-class classification model : is_next, is_not_next
    """

    def __init__(self, hidden):
        """
        :param hidden: BERT model output size
        """
        super().__init__()
        # compress the vector down to two dimensions, weight size ES * 2
        self.linear = nn.Linear(hidden, 2)
        self.softmax = nn.LogSoftmax(dim=-1)

    def forward(self, x):
        # input -> take the first vector of the sequence -> LL -> softmax -> output
        # outputs the (log) probabilities of "adjacent" and "not adjacent"
        return self.softmax(self.linear(x[:, 0]))
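As a quick shape check, the sketch below runs this head on a random tensor standing in for the encoder output (the sizes BS = 2 and ML = 8 are made up; it assumes the NextSentencePrediction class above and PyTorch are already in scope):

import torch

# fake encoder output: BS = 2, ML = 8, ES = 768
encoder_out = torch.randn(2, 8, 768)

nsp = NextSentencePrediction(hidden=768)
log_probs = nsp(encoder_out)

print(log_probs.shape)          # torch.Size([2, 2]): one pair of log probabilities per sequence
print(log_probs.exp().sum(-1))  # each row sums to 1 (up to floating point error)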
# Decoder for the cloze (masked language model) task:
# a sequence generation task; the input is a sentence containing `<mask>` tokens
# and the output is the complete sentence.
class MaskedLanguageModel(nn.Module):
    """
    predicting origin token from masked input sequence
    n-class classification problem, n-class = vocab_size
    """

    def __init__(self, hidden, vocab_size):
        """
        :param hidden: output size of BERT model
        :param vocab_size: total vocab size
        """
        super().__init__()
        # project the input to the vocabulary size, weight size ES * VS
        self.linear = nn.Linear(hidden, vocab_size)
        self.softmax = nn.LogSoftmax(dim=-1)

    def forward(self, x):
        # input -> LL -> softmax -> output
        # for every position in the sequence, outputs the (log) probability
        # of each word in the vocabulary
        return self.softmax(self.linear(x))
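A similar sketch for the masked-language-model head (again with a random tensor in place of the encoder output; the vocab size 30000 is made up). Feeding the LogSoftmax output to nn.NLLLoss with ignore_index=0 is one common way to train it, though the actual training code may differ:

import torch
import torch.nn as nn

# fake encoder output: BS = 2, ML = 8, ES = 768
encoder_out = torch.randn(2, 8, 768)

mlm = MaskedLanguageModel(hidden=768, vocab_size=30000)
log_probs = mlm(encoder_out)

print(log_probs.shape)  # torch.Size([2, 8, 30000]): BS * ML * VS

# one common training setup: NLLLoss on the log probabilities,
# ignoring positions whose label is `<pad>` (ID = 0)
criterion = nn.NLLLoss(ignore_index=0)
labels = torch.randint(0, 30000, (2, 8))             # fake target IDs
loss = criterion(log_probs.transpose(1, 2), labels)  # NLLLoss expects BS * VS * ML
print(loss.item())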