
NLP: Building the Transformer Model

Contents:

  • 1. Code implementation of the encoder and decoder
  • 2. Instantiating the encoder and decoder
  • 3. Running results

Preface: the previous articles walked through each component of the Transformer; this article shows how to assemble them into the complete model.

In short, the standard Transformer stacks 6 encoder layers and 6 decoder layers, together with an input part (word embedding + positional encoding) and an output (generator) layer.
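For reference only (this is not the implementation built in this series), PyTorch's built-in nn.Transformer defaults to the same layout: d_model=512, 8 heads, 6 encoder layers and 6 decoder layers. The short sketch below just cross-checks those dimensions; note that nn.Transformer expects already-embedded inputs and contains no embedding, positional-encoding or output layer.

import torch
import torch.nn as nn

# Cross-check of the standard dimensions: defaults are d_model=512, nhead=8,
# num_encoder_layers=6, num_decoder_layers=6.
model = nn.Transformer()
src = torch.rand(10, 2, 512)  # [source_seq_len, batch_size, d_model]
tgt = torch.rand(6, 2, 512)   # [target_seq_len, batch_size, d_model]
out = model(src, tgt)
print(out.shape)              # torch.Size([6, 2, 512])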

1. Code implementation of the encoder and decoder

# The building blocks used below (MutiHeadAttention, FeedForward, EncoderLayer, Encoder,
# DecoderLayer, Decoder, Embeddings, PositionEncoding, Generator) are the classes
# implemented in the earlier articles of this series.
import copy

import torch
import torch.nn as nn


# Define the EncoderDecoder class
class EncoderDecoder(nn.Module):
    def __init__(self, encoder, decoder, source_embed, target_embed, generator):
        super().__init__()
        # encoder: the encoder object
        self.encoder = encoder
        # decoder: the decoder object
        self.decoder = decoder
        # source_embed: source-language input part (word embedding + positional encoding)
        self.source_embed = source_embed
        # target_embed: target-language input part (word embedding + positional encoding)
        self.target_embed = target_embed
        # generator: the output-layer object
        self.generator = generator

    def forward(self, source, target, source_mask1, source_mask2, target_mask):
        # source: source-language input, shape [batch_size, seq_len] -> [2, 4]
        # target: target-language input, shape [batch_size, seq_len] -> [2, 6]
        # source_mask1: padding mask for the encoder's multi-head self-attention
        #               -> [head, source_seq_len, source_seq_len] -> [8, 4, 4]
        # source_mask2: padding mask for the decoder's multi-head cross-attention
        #               -> [head, target_seq_len, source_seq_len] -> [8, 6, 4]
        # target_mask:  sentence (look-ahead) mask for the decoder's multi-head self-attention
        #               -> [head, target_seq_len, target_seq_len] -> [8, 6, 6]

        # 1. Pass the raw source input [batch_size, seq_len] -> [2, 4] through the encoder
        #    input part (word embedding + positional encoding) to get [2, 4, 512]
        encode_word_embed = self.source_embed(source)
        # 2. Feed encode_word_embed and source_mask1 into the encoder: encoder_output -> [2, 4, 512]
        encoder_output = self.encoder(encode_word_embed, source_mask1)
        # 3. Pass the target input [batch_size, seq_len] -> [2, 6] through the decoder
        #    input part to get [2, 6, 512]
        decode_word_embed = self.target_embed(target)
        # 4. Feed decode_word_embed, encoder_output, source_mask2 and target_mask into the decoder
        decoder_output = self.decoder(decode_word_embed, encoder_output, source_mask2, target_mask)
        # 5. Feed decoder_output into the output layer
        output = self.generator(decoder_output)
        return output
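The forward pass above only consumes the masks; it does not build them, and the demo in the next section simply passes torch.zeros placeholders. The sketch below is one possible way to build real masks with the [head, seq_len, seq_len] shapes used in this series. The helper names and pad_idx = 0 are my own assumptions, and whether True means "attend" or "mask out" depends on how the attention function from the earlier article interprets the mask.

import torch

def make_padding_mask(query_len, key_seq, head=8, pad_idx=0):
    # pad_idx is a hypothetical padding token id; adapt it to your vocabulary.
    # Mark which key positions are real tokens: [key_seq_len] -> [1, 1, key_seq_len]
    valid = (key_seq != pad_idx).unsqueeze(0).unsqueeze(0)
    # Broadcast to [head, query_len, key_seq_len]
    return valid.expand(head, query_len, key_seq.size(0))

def make_subsequent_mask(target_len, head=8):
    # Lower-triangular matrix: position i may only attend to positions <= i
    subsequent = torch.tril(torch.ones(target_len, target_len, dtype=torch.bool))
    return subsequent.unsqueeze(0).expand(head, target_len, target_len)

# Example for a single source sentence of length 4 and target length 6,
# matching the mask shapes used in this article:
src_tokens = torch.tensor([1, 2, 3, 4])
source_mask1 = make_padding_mask(4, src_tokens)  # [8, 4, 4]
source_mask2 = make_padding_mask(6, src_tokens)  # [8, 6, 4]
target_mask = make_subsequent_mask(6)            # [8, 6, 6]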

2. Instantiating the encoder and decoder

def dm_transformer():
    # 1. Instantiate the encoder object
    # Instantiate the multi-head attention object
    mha = MutiHeadAttention(embed_dim=512, head=8, dropout_p=0.1)
    # Instantiate the feed-forward layer object
    ff = FeedForward(d_model=512, d_ff=1024)
    encoder_layer = EncoderLayer(size=512, self_atten=mha, ff=ff, dropout_p=0.1)
    encoder = Encoder(layer=encoder_layer, N=6)

    # 2. Instantiate the decoder object
    self_attn = copy.deepcopy(mha)
    src_attn = copy.deepcopy(mha)
    feed_forward = copy.deepcopy(ff)
    decoder_layer = DecoderLayer(size=512, self_attn=self_attn, src_attn=src_attn,
                                 feed_forward=feed_forward, dropout_p=0.1)
    decoder = Decoder(layer=decoder_layer, N=6)

    # 3. Source-language input part: word embedding + positional encoding
    # Embedding layer
    vocab_size = 1000
    d_model = 512
    encoder_embed = Embeddings(vocab_size=vocab_size, d_model=d_model)
    # Positional encoding layer (inside it, the embedded input embed_x has already been added in)
    dropout_p = 0.1
    encoder_pe = PositionEncoding(d_model=d_model, dropout_p=dropout_p)
    source_embed = nn.Sequential(encoder_embed, encoder_pe)

    # 4. Target-language input part: word embedding + positional encoding
    # Embedding layer
    decoder_embed = copy.deepcopy(encoder_embed)
    # Positional encoding layer (inside it, the embedded input embed_x has already been added in)
    decoder_pe = copy.deepcopy(encoder_pe)
    target_embed = nn.Sequential(decoder_embed, decoder_pe)

    # 5. Instantiate the output-layer object
    generator = Generator(d_model=512, vocab_size=2000)

    # 6. Instantiate the EncoderDecoder object
    transformer = EncoderDecoder(encoder, decoder, source_embed, target_embed, generator)
    print(transformer)

    # 7. Prepare data
    source = torch.tensor([[1, 2, 3, 4],
                           [2, 5, 6, 10]])
    target = torch.tensor([[1, 20, 3, 4, 19, 30],
                           [21, 5, 6, 10, 80, 38]])
    source_mask1 = torch.zeros(8, 4, 4)
    source_mask2 = torch.zeros(8, 6, 4)
    target_mask = torch.zeros(8, 6, 6)

    result = transformer(source, target, source_mask1, source_mask2, target_mask)
    print(f'Final output of the transformer model: {result}')
    print(f'Shape of the final output: {result.shape}')
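With batch size 2, target length 6 and vocab_size=2000 in the Generator, result has shape [2, 6, 2000]: one score per target position per vocabulary word. The sketch below shows how such an output would typically feed a training loss; it assumes the Generator returns log-probabilities (as in the standard annotated implementation), and gold is a hypothetical tensor of target token ids.

import torch
import torch.nn as nn

# Assumption: the Generator ends in log_softmax, so NLLLoss applies directly;
# if it returns raw logits instead, use CrossEntropyLoss.
criterion = nn.NLLLoss()

# Stand-in for the model output: [batch_size, target_seq_len, vocab_size] -> [2, 6, 2000]
result = torch.log_softmax(torch.randn(2, 6, 2000, requires_grad=True), dim=-1)
# gold: hypothetical gold target token ids, [batch_size, target_seq_len] -> [2, 6]
gold = torch.randint(0, 2000, (2, 6))

# Flatten to [batch_size * target_seq_len, vocab_size] and [batch_size * target_seq_len]
loss = criterion(result.view(-1, 2000), gold.view(-1))
loss.backward()
print(loss.item())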

3. Running results

# Final model structure built according to the Transformer architecture diagram
EncoderDecoder(
  (encoder): Encoder(
    (layers): ModuleList(
      (0): EncoderLayer(
        (self_attn): MultiHeadedAttention(
          (linears): ModuleList(
            (0): Linear(in_features=512, out_features=512)
            (1): Linear(in_features=512, out_features=512)
            (2): Linear(in_features=512, out_features=512)
            (3): Linear(in_features=512, out_features=512)
          )
          (dropout): Dropout(p=0.1)
        )
        (feed_forward): PositionwiseFeedForward(
          (w_1): Linear(in_features=512, out_features=2048)
          (w_2): Linear(in_features=2048, out_features=512)
          (dropout): Dropout(p=0.1)
        )
        (sublayer): ModuleList(
          (0): SublayerConnection(
            (norm): LayerNorm()
            (dropout): Dropout(p=0.1)
          )
          (1): SublayerConnection(
            (norm): LayerNorm()
            (dropout): Dropout(p=0.1)
          )
        )
      )
      (1): EncoderLayer(
        (self_attn): MultiHeadedAttention(
          (linears): ModuleList(
            (0): Linear(in_features=512, out_features=512)
            (1): Linear(in_features=512, out_features=512)
            (2): Linear(in_features=512, out_features=512)
            (3): Linear(in_features=512, out_features=512)
          )
          (dropout): Dropout(p=0.1)
        )
        (feed_forward): PositionwiseFeedForward(
          (w_1): Linear(in_features=512, out_features=2048)
          (w_2): Linear(in_features=2048, out_features=512)
          (dropout): Dropout(p=0.1)
        )
        (sublayer): ModuleList(
          (0): SublayerConnection(
            (norm): LayerNorm()
            (dropout): Dropout(p=0.1)
          )
          (1): SublayerConnection(
            (norm): LayerNorm()
            (dropout): Dropout(p=0.1)
          )
        )
      )
    )
    (norm): LayerNorm()
  )
  (decoder): Decoder(
    (layers): ModuleList(
      (0): DecoderLayer(
        (self_attn): MultiHeadedAttention(
          (linears): ModuleList(
            (0): Linear(in_features=512, out_features=512)
            (1): Linear(in_features=512, out_features=512)
            (2): Linear(in_features=512, out_features=512)
            (3): Linear(in_features=512, out_features=512)
          )
          (dropout): Dropout(p=0.1)
        )
        (src_attn): MultiHeadedAttention(
          (linears): ModuleList(
            (0): Linear(in_features=512, out_features=512)
            (1): Linear(in_features=512, out_features=512)
            (2): Linear(in_features=512, out_features=512)
            (3): Linear(in_features=512, out_features=512)
          )
          (dropout): Dropout(p=0.1)
        )
        (feed_forward): PositionwiseFeedForward(
          (w_1): Linear(in_features=512, out_features=2048)
          (w_2): Linear(in_features=2048, out_features=512)
          (dropout): Dropout(p=0.1)
        )
        (sublayer): ModuleList(
          (0): SublayerConnection(
            (norm): LayerNorm()
            (dropout): Dropout(p=0.1)
          )
          (1): SublayerConnection(
            (norm): LayerNorm()
            (dropout): Dropout(p=0.1)
          )
          (2): SublayerConnection(
            (norm): LayerNorm()
            (dropout): Dropout(p=0.1)
          )
        )
      )
      (1): DecoderLayer(
        (self_attn): MultiHeadedAttention(
          (linears): ModuleList(
            (0): Linear(in_features=512, out_features=512)
            (1): Linear(in_features=512, out_features=512)
            (2): Linear(in_features=512, out_features=512)
            (3): Linear(in_features=512, out_features=512)
          )
          (dropout): Dropout(p=0.1)
        )
        (src_attn): MultiHeadedAttention(
          (linears): ModuleList(
            (0): Linear(in_features=512, out_features=512)
            (1): Linear(in_features=512, out_features=512)
            (2): Linear(in_features=512, out_features=512)
            (3): Linear(in_features=512, out_features=512)
          )
          (dropout): Dropout(p=0.1)
        )
        (feed_forward): PositionwiseFeedForward(
          (w_1): Linear(in_features=512, out_features=2048)
          (w_2): Linear(in_features=2048, out_features=512)
          (dropout): Dropout(p=0.1)
        )
        (sublayer): ModuleList(
          (0): SublayerConnection(
            (norm): LayerNorm()
            (dropout): Dropout(p=0.1)
          )
          (1): SublayerConnection(
            (norm): LayerNorm()
            (dropout): Dropout(p=0.1)
          )
          (2): SublayerConnection(
            (norm): LayerNorm()
            (dropout): Dropout(p=0.1)
          )
        )
      )
    )
    (norm): LayerNorm()
  )
  (src_embed): Sequential(
    (0): Embeddings(
      (lut): Embedding(11, 512)
    )
    (1): PositionalEncoding(
      (dropout): Dropout(p=0.1)
    )
  )
  (tgt_embed): Sequential(
    (0): Embeddings(
      (lut): Embedding(11, 512)
    )
    (1): PositionalEncoding(
      (dropout): Dropout(p=0.1)
    )
  )
  (generator): Generator(
    (proj): Linear(in_features=512, out_features=11)
  )
)

If any part of the code is unclear, please refer to the earlier articles in this series. Thanks for reading; that concludes today's post.
