PyTorch中Transformer模型的搭建

2020-10-26 14:34:06 浏览数 (1)

PyTorch最近版本更新很快,1.2/1.3/1.4几乎是连着出,其中: 1.3/1.4版本主要是新增并完善了PyTorchMobile移动端部署模块和模型量化模块。 而1.2版中一个重要的更新就是把加入了NLP领域中炙手可热的Transformer模型,这里记录一下PyTorchTransformer模型的用法(代码写于1.2版本,没有在1.3/1.4版本测试)。

1. 简介


也许是为了更方便地搭建BertGPT-2之类的NLP模型,PyTorchTransformer相关的模型分为nn.TransformerEncoderLayernn.TransformerDecoderLayernn.LayerNorm等几个部分。搭建模型的时候不一定都会用到, 比如fastai中的Transformer模型就只用到了encoder部分,没有用到decoder

至于WordEmbeddingPositionEncoding两个部分需要自己另外实现。

WordEmbedding可以直接使用PyTorch自带的nn.Embedding层。

PositionEncoding层的花样就多了,不同的模型下面有不同的PositionEncoding,比如Transformer的原始论文Attention is all you need中使用的是无参数的PositionEncodingBert中使用的是带有学习参数的PositionEncoding

在本文中介绍的是参考Transformer原始论文实现的Sequence2sequence形式的Transformer模型。

2. Sequence2sequence形式的Transformer模型搭建:


2.1 无可学习参数的PositionEncoding层

无参数的PositionEncoding计算速度快,还可以减小整个模型的尺寸,据说在有些任务中,效果与有参数的接近。

代码语言:javascript复制
class PositionalEncoding(nn.Module):
    def __init__(self, d_model,dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x   self.pe[:x.size(0), :]
        return self.dropout(x)

2.2 有可学习参数的PositionEncoding层

我曾在一个序列预测任务(非NLP)里面对比过两种PositionEncoding层,发现带有参数的PositionEncoding层效果明显比没有参数的PositionEncoding要好。

带参数的PositionEncoding层的定义更为简单,直接继承一个nn.Embedding,再续上一个dropout就可以了。因为nn.Embedding中包含了一个可以按索引取向量的权重矩阵weight

代码语言:javascript复制
class LearnedPositionEncoding(nn.Embedding):
    def __init__(self,d_model, dropout = 0.1,max_len = 5000):
        super().__init__(max_len, d_model)
        self.dropout = nn.Dropout(p = dropout)

    def forward(self, x):
        weight = self.weight.data.unsqueeze(1)
        x = x   weight[:x.size(0),:]
        return self.dropout(x)

2.3 Sequence2sequence模型

将embedding、position_encoding、encoder和decoder拼接起来,就可以构成一个完整的sequence2sequence形式的Transformer模型了。

代码语言:javascript复制
class S2sTransformer(nn.Module):

    def __init__(self,vocab_size,position_enc,d_model = 512,nhead = 8,num_encoder_layers=6,
                 num_decoder_layers=6,dim_feedforward=2048,dropout=0.1):
        super(S2sTransformer,self).__init__()

        # Preprocess
        self.embedding = nn.Embedding(vocab_size,d_model)
        self.pos_encoder_src = position_enc(d_model=512)
        # tgt
        self.pos_encoder_tgt = position_enc(d_model=512)

        # Encoder
        encoder_layer = nn.TransformerEncoderLayer(d_model,nhead,dim_feedforward,dropout)
        encoder_norm = nn.LayerNorm(d_model)
        self.encoder = nn.TransformerEncoder(encoder_layer,num_encoder_layers,encoder_norm)

        # Decoder
        decoder_layer = nn.TransformerDecoderLayer(d_model,nhead,dim_feedforward,dropout)
        decoder_norm = nn.LayerNorm(d_model)
        self.decoder = nn.TransformerDecoder(decoder_layer,num_decoder_layers,decoder_norm)
        self.output_layer = nn.Linear(d_model,vocab_size)

        self._reset_parameters()
        self.d_model = d_model
        self.nhead = nhead


    def forward(self, src,tgt,src_mask = None,tgt_mask = None,
                memory_mask = None,src_key_padding_mask = None,
                tgt_key_padding_mask = None,memory_key_padding_mask = None):

        # word embedding
        src = self.embedding(src)
        tgt = self.embedding(tgt)

        # shape check
        if src.size(1) != tgt.size(1):
            raise RuntimeError("the batch number of src and tgt must be equal")
        if src.size(2) != self.d_model or tgt.size(2) != self.d_model:
            raise RuntimeError("the feature number of src and tgt must be equal to d_model")

        # position encoding
        src = self.pos_encoder_src(src)
        tgt = self.pos_encoder_tgt(tgt)

        memory = self.encoder(src, mask=src_mask, src_key_padding_mask=src_key_padding_mask)
        output = self.decoder(tgt, memory, tgt_mask=tgt_mask, memory_mask=memory_mask,
                              tgt_key_padding_mask=tgt_key_padding_mask,
                              memory_key_padding_mask=memory_key_padding_mask)
        output = self.output_layer(output)
        # return output
        return softmax(output,dim = 2)


    def generate_square_subsequent_mask(self, sz):
        r"""Generate a square mask for the sequence. The masked positions are filled with float('-inf').
            Unmasked positions are filled with float(0.0).
        """
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def _reset_parameters(self):
        r"""Initiate parameters in the transformer model."""

        for p in self.parameters():
            if p.dim() > 1:
                xavier_uniform_(p)

模型搭建好了之后,就可以按照Sequence2sequence的训练方式进行训练了, 唯一需要注意的就是Transformerforward过程是并行的,与基于RNNSequence2sequence模型稍有不同。

训练过程可以参考PyTorch官网提供的chatbot的教程

0 人点赞