NLP: Building the Transformer Model
Contents:
- 1. Implementing the EncoderDecoder class
- 2. Instantiating the encoder-decoder model
- 3. Output of running the code
Preface: The previous posts covered the individual components of the Transformer; this post puts them together and builds the complete model.
In short, the standard Transformer consists of a stack of 6 encoder layers and a stack of 6 decoder layers, plus an input part (word embedding and positional encoding) and an output layer.
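Stacking 6 identical encoder layers and 6 identical decoder layers is handled by a small cloning helper in the earlier posts. As a reminder, a minimal sketch of what such a helper is assumed to look like (names are illustrative, not necessarily the originals):

import copy
import torch.nn as nn

def clones(module, N):
    # Return a ModuleList holding N independent deep copies of `module`
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])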
1. Implementing the EncoderDecoder class
# Define the EncoderDecoder class
# (the component classes -- MutiHeadAttention, FeedForward, EncoderLayer, Encoder,
#  DecoderLayer, Decoder, Embeddings, PositionEncoding, Generator -- are defined in
#  the earlier posts of this series)
import copy

import torch
import torch.nn as nn


class EncoderDecoder(nn.Module):
    def __init__(self, encoder, decoder, source_embed, target_embed, generator):
        super().__init__()
        # encoder: the encoder object
        self.encoder = encoder
        # decoder: the decoder object
        self.decoder = decoder
        # source_embed: source-language input part: word embedding + positional encoding
        self.source_embed = source_embed
        # target_embed: target-language input part: word embedding + positional encoding
        self.target_embed = target_embed
        # generator: the output-layer object
        self.generator = generator

    def forward(self, source, target, source_mask1, source_mask2, target_mask):
        # source: source-language input, shape [batch_size, seq_len] --> [2, 4]
        # target: target-language input, shape [batch_size, seq_len] --> [2, 6]
        # source_mask1: padding mask for the encoder's multi-head self-attention --> [head, source_seq_len, source_seq_len] --> [8, 4, 4]
        # source_mask2: padding mask for the decoder's multi-head cross-attention --> [head, target_seq_len, source_seq_len] --> [8, 6, 4]
        # target_mask: subsequent (look-ahead) mask for the decoder's multi-head self-attention --> [head, target_seq_len, target_seq_len] --> [8, 6, 6]

        # 1. Pass the raw source input [batch_size, seq_len] --> [2, 4] through the encoder input part --> [2, 4, 512]
        #    encode_word_embed = word embedding + positional encoding
        encode_word_embed = self.source_embed(source)
        # 2. Feed encode_word_embed and source_mask1 into the encoder: encoder_output --> [2, 4, 512]
        encoder_output = self.encoder(encode_word_embed, source_mask1)
        # 3. Pass the target input [batch_size, seq_len] --> [2, 6] through the decoder input part --> [2, 6, 512]
        decode_word_embed = self.target_embed(target)
        # 4. Feed decode_word_embed, encoder_output, source_mask2 and target_mask into the decoder
        decoder_output = self.decoder(decode_word_embed, encoder_output, source_mask2, target_mask)
        # 5. Pass decoder_output through the output layer
        output = self.generator(decoder_output)
        return output
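The Generator output layer itself is not shown in this post. As a reminder, a minimal sketch of what it is assumed to look like based on the earlier posts: a linear projection from d_model to the target vocabulary size followed by log-softmax.

import torch.nn as nn
import torch.nn.functional as F

class Generator(nn.Module):
    def __init__(self, d_model, vocab_size):
        super().__init__()
        # Project the d_model-dimensional decoder output onto the target vocabulary
        self.proj = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        # x: [batch_size, seq_len, d_model] --> log-probabilities [batch_size, seq_len, vocab_size]
        return F.log_softmax(self.proj(x), dim=-1)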
2. Instantiating the encoder-decoder model
def dm_transformer():
    # 1. Instantiate the encoder object
    # Instantiate the multi-head attention object
    mha = MutiHeadAttention(embed_dim=512, head=8, dropout_p=0.1)
    # Instantiate the position-wise feed-forward object
    ff = FeedForward(d_model=512, d_ff=1024)
    encoder_layer = EncoderLayer(size=512, self_atten=mha, ff=ff, dropout_p=0.1)
    encoder = Encoder(layer=encoder_layer, N=6)

    # 2. Instantiate the decoder object
    self_attn = copy.deepcopy(mha)
    src_attn = copy.deepcopy(mha)
    feed_forward = copy.deepcopy(ff)
    decoder_layer = DecoderLayer(size=512, self_attn=self_attn, src_attn=src_attn,
                                 feed_forward=feed_forward, dropout_p=0.1)
    decoder = Decoder(layer=decoder_layer, N=6)

    # 3. Source-language input part: word embedding + positional encoding
    # Embedding layer
    vocab_size = 1000
    d_model = 512
    encoder_embed = Embeddings(vocab_size=vocab_size, d_model=d_model)
    # Positional-encoding layer (inside it, the embedded input embed_x is already added to the positional encodings)
    dropout_p = 0.1
    encoder_pe = PositionEncoding(d_model=d_model, dropout_p=dropout_p)
    source_embed = nn.Sequential(encoder_embed, encoder_pe)

    # 4. Target-language input part: word embedding + positional encoding
    # Embedding layer
    decoder_embed = copy.deepcopy(encoder_embed)
    # Positional-encoding layer (inside it, the embedded input embed_x is already added to the positional encodings)
    decoder_pe = copy.deepcopy(encoder_pe)
    target_embed = nn.Sequential(decoder_embed, decoder_pe)

    # 5. Instantiate the output-layer object
    generator = Generator(d_model=512, vocab_size=2000)

    # 6. Instantiate the EncoderDecoder object
    transformer = EncoderDecoder(encoder, decoder, source_embed, target_embed, generator)
    print(transformer)

    # 7. Prepare the data
    source = torch.tensor([[1, 2, 3, 4],
                           [2, 5, 6, 10]])
    target = torch.tensor([[1, 20, 3, 4, 19, 30],
                           [21, 5, 6, 10, 80, 38]])
    source_mask1 = torch.zeros(8, 4, 4)
    source_mask2 = torch.zeros(8, 6, 4)
    target_mask = torch.zeros(8, 6, 6)

    result = transformer(source, target, source_mask1, source_mask2, target_mask)
    print(f'Final output of the transformer model --> {result}')
    print(f'Shape of the final output --> {result.shape}')


if __name__ == '__main__':
    dm_transformer()
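Note that the three masks in the demo are only all-zero placeholders with the right shapes. In a real run, the padding masks would be derived from the pad positions of the source sentence, and the decoder self-attention mask would additionally block future positions. A minimal sketch of how such masks could be built, assuming a hypothetical pad token id of 0 and the convention that a 0 in the mask marks a position attention must ignore (as in masked_fill(mask == 0, -1e9)):

import torch

def build_masks(source, target, head=8, pad_idx=0):
    # Hypothetical helper (not from the original posts); it reproduces the
    # [head, query_len, key_len] mask layout used in dm_transformer above.
    source_seq_len = source.size(1)   # 4 in the demo
    target_seq_len = target.size(1)   # 6 in the demo

    # 1 where the source key position holds a real token, 0 where it is padding
    # (only the first sample is used, to keep the demo's batch-free mask shape)
    key_is_real = (source[0] != pad_idx).float()                          # [source_seq_len]

    # Padding mask for the encoder's self-attention --> [head, source_seq_len, source_seq_len]
    source_mask1 = key_is_real.expand(head, source_seq_len, source_seq_len)
    # Padding mask for the decoder's cross-attention --> [head, target_seq_len, source_seq_len]
    source_mask2 = key_is_real.expand(head, target_seq_len, source_seq_len)

    # Subsequent (look-ahead) mask: lower triangular, so position i cannot attend to positions > i
    subsequent = torch.tril(torch.ones(target_seq_len, target_seq_len))   # [target_seq_len, target_seq_len]
    target_mask = subsequent.expand(head, target_seq_len, target_seq_len)

    return source_mask1, source_mask2, target_mask

# Usage inside dm_transformer, replacing the torch.zeros placeholders:
# source_mask1, source_mask2, target_mask = build_masks(source, target)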
3. Output of running the code
# The final model structure, built according to the Transformer architecture diagram
EncoderDecoder(
  (encoder): Encoder(
    (layers): ModuleList(
      (0): EncoderLayer(
        (self_attn): MultiHeadedAttention(
          (linears): ModuleList(
            (0): Linear(in_features=512, out_features=512)
            (1): Linear(in_features=512, out_features=512)
            (2): Linear(in_features=512, out_features=512)
            (3): Linear(in_features=512, out_features=512)
          )
          (dropout): Dropout(p=0.1)
        )
        (feed_forward): PositionwiseFeedForward(
          (w_1): Linear(in_features=512, out_features=2048)
          (w_2): Linear(in_features=2048, out_features=512)
          (dropout): Dropout(p=0.1)
        )
        (sublayer): ModuleList(
          (0): SublayerConnection(
            (norm): LayerNorm()
            (dropout): Dropout(p=0.1)
          )
          (1): SublayerConnection(
            (norm): LayerNorm()
            (dropout): Dropout(p=0.1)
          )
        )
      )
      (1): EncoderLayer(
        (self_attn): MultiHeadedAttention(
          (linears): ModuleList(
            (0): Linear(in_features=512, out_features=512)
            (1): Linear(in_features=512, out_features=512)
            (2): Linear(in_features=512, out_features=512)
            (3): Linear(in_features=512, out_features=512)
          )
          (dropout): Dropout(p=0.1)
        )
        (feed_forward): PositionwiseFeedForward(
          (w_1): Linear(in_features=512, out_features=2048)
          (w_2): Linear(in_features=2048, out_features=512)
          (dropout): Dropout(p=0.1)
        )
        (sublayer): ModuleList(
          (0): SublayerConnection(
            (norm): LayerNorm()
            (dropout): Dropout(p=0.1)
          )
          (1): SublayerConnection(
            (norm): LayerNorm()
            (dropout): Dropout(p=0.1)
          )
        )
      )
    )
    (norm): LayerNorm()
  )
  (decoder): Decoder(
    (layers): ModuleList(
      (0): DecoderLayer(
        (self_attn): MultiHeadedAttention(
          (linears): ModuleList(
            (0): Linear(in_features=512, out_features=512)
            (1): Linear(in_features=512, out_features=512)
            (2): Linear(in_features=512, out_features=512)
            (3): Linear(in_features=512, out_features=512)
          )
          (dropout): Dropout(p=0.1)
        )
        (src_attn): MultiHeadedAttention(
          (linears): ModuleList(
            (0): Linear(in_features=512, out_features=512)
            (1): Linear(in_features=512, out_features=512)
            (2): Linear(in_features=512, out_features=512)
            (3): Linear(in_features=512, out_features=512)
          )
          (dropout): Dropout(p=0.1)
        )
        (feed_forward): PositionwiseFeedForward(
          (w_1): Linear(in_features=512, out_features=2048)
          (w_2): Linear(in_features=2048, out_features=512)
          (dropout): Dropout(p=0.1)
        )
        (sublayer): ModuleList(
          (0): SublayerConnection(
            (norm): LayerNorm()
            (dropout): Dropout(p=0.1)
          )
          (1): SublayerConnection(
            (norm): LayerNorm()
            (dropout): Dropout(p=0.1)
          )
          (2): SublayerConnection(
            (norm): LayerNorm()
            (dropout): Dropout(p=0.1)
          )
        )
      )
      (1): DecoderLayer(
        (self_attn): MultiHeadedAttention(
          (linears): ModuleList(
            (0): Linear(in_features=512, out_features=512)
            (1): Linear(in_features=512, out_features=512)
            (2): Linear(in_features=512, out_features=512)
            (3): Linear(in_features=512, out_features=512)
          )
          (dropout): Dropout(p=0.1)
        )
        (src_attn): MultiHeadedAttention(
          (linears): ModuleList(
            (0): Linear(in_features=512, out_features=512)
            (1): Linear(in_features=512, out_features=512)
            (2): Linear(in_features=512, out_features=512)
            (3): Linear(in_features=512, out_features=512)
          )
          (dropout): Dropout(p=0.1)
        )
        (feed_forward): PositionwiseFeedForward(
          (w_1): Linear(in_features=512, out_features=2048)
          (w_2): Linear(in_features=2048, out_features=512)
          (dropout): Dropout(p=0.1)
        )
        (sublayer): ModuleList(
          (0): SublayerConnection(
            (norm): LayerNorm()
            (dropout): Dropout(p=0.1)
          )
          (1): SublayerConnection(
            (norm): LayerNorm()
            (dropout): Dropout(p=0.1)
          )
          (2): SublayerConnection(
            (norm): LayerNorm()
            (dropout): Dropout(p=0.1)
          )
        )
      )
    )
    (norm): LayerNorm()
  )
  (src_embed): Sequential(
    (0): Embeddings(
      (lut): Embedding(11, 512)
    )
    (1): PositionalEncoding(
      (dropout): Dropout(p=0.1)
    )
  )
  (tgt_embed): Sequential(
    (0): Embeddings(
      (lut): Embedding(11, 512)
    )
    (1): PositionalEncoding(
      (dropout): Dropout(p=0.1)
    )
  )
  (generator): Generator(
    (proj): Linear(in_features=512, out_features=11)
  )
)
If any of the code is unclear, please refer to the earlier posts in this series. Thanks for reading, and that's all for today.