当前位置: 首页 > news >正文

Transformer(四)---解码器部分实现、输出部分实现及模型搭建

目录

一、解码器介绍

二、解码器层

三、解码器

四、输出部分实现

五、模型构建

5.1 编码器-解码器的代码实现

5.2 Transformer模型构建过程代码分析


一、解码器介绍

由N个解码器层堆叠而成
每个解码器层由三个子层连接结构组成
第一个子层连接结构包括一个(掩码)多头自注意力子层和规范化层以及一个残差连接
第二个子层连接结构包括一个多头注意力子层和规范化层以及一个残差连接
第三个子层连接结构包括一个前馈全连接子层和规范化层以及一个残差连接

二、解码器层

解码器层的作用:
作为解码器的组成单元, 每个解码器层根据给定的输入向目标方向进行特征提取操作,即解码过程.

source_mask(源数据掩码)的主要作用是防止解码器关注到编码器输出中无效的 padding 区域,确保模型只关注源序列(如输入文本)中有意义的内容。

target_mask(目标掩码):用于解码器的自注意力(self_attn),防止解码器关注未来的 tokens(实现因果语言模型的特性)。

class DecoderLayer(nn.Module):def __init__(self,size,self_attn,src_attn,feed_forward,dropout):super(DecoderLayer,self).__init__()# 词嵌入维度尺寸大小self.size = size# 自注意力机制层对象 q=k=vself.self_attn = self_attn# 一般注意力机制对象 q!=k=vself.src_attn = src_attn# 前馈全连接层对象self.feed_forward = feed_forward# clones 3子层连接结构self.sublayers = clones(SublayerConnection(size,dropout),3)def forward(self,x,memory,source_mask,target_mask):# forward函数的参数有4个,分别是来自上一层的输入x,来自编码器层的语义存储变量mermory, 以及源数据掩码张量和目标数据掩码张量.m = memory# 数据经过子层连接结构1x = self.sublayers[0](x,lambda x:self.self_attn(x,x,x,target_mask))# 数据经过子层连接结构2x = self.sublayers[1](x,lambda x:self.src_attn(x,m,m,source_mask))# 数据经过子层连接结构3x = self.sublayers[2](x,self.feed_forward)return x

三、解码器

根据编码器的结果以及上一次预测的结果, 对下一次可能出现的'值'进行特征表示.

class Decoder(nn.Module):def __init__(self,layer,N):super(Decoder,self).__init__()# clones N个解码器层self.layers = clones(layer,N)# 定义规范化层self.norm = LayerNorm(layer.size)def forward(self,x,memory,source_mask,target_mask):for layer in self.layers:x = layer(x,memory,source_mask,target_mask)return self.norm(x)

四、输出部分实现

输出部分包含:
线性层
softmax层

线性层的作用:
通过对上一步的线性变化得到指定维度的输出, 也就是转换维度的作用.

softmax层的作用:
使最后一维的向量中的数字缩放到0-1的概率值域内, 并满足他们的和为1.

class Generator(nn.Module):def __init__(self,vocab,d_model):# vocab 线性输出尺寸大小# d_model 线性层输入特征尺寸大小super(Generator,self).__init__()self.linear = nn.Linear(d_model,vocab)def forward(self,x):x = F.log_softmax(self.linear(x),dim=-1)return x

五、模型构建

通过上面的小节, 我们已经完成了所有组成部分的实现, 接下来就来实现完整的编码器-解码器结构.

Transformer总体架构图:

5.1 编码器-解码器的代码实现

class EncoderDecoder(nn.Module):def __init__(self,encoder,decoder,source_embed,target_embed,generator):"""初始化函数中有5个参数, 分别是编码器对象, 解码器对象, 源数据嵌入函数, 目标数据嵌入函数,  以及输出部分的类别生成器对象"""super(EncoderDecoder,self).__init__()self.encoder = encoderself.decoder = decoderself.src_embed = source_embedself.tgt_embed =target_embedself.generator = generatordef forward(self,source,target,source_mask,target_mask):"""在forward函数中,有四个参数, source代表源数据, target代表目标数据, source_mask和target_mask代表对应的掩码张量"""# 在函数中,将source,source_mask传入编码函数,得到结果后,# 与source_mask,target,target_mask一同传给解码函数return self.generator(self.decode(self.encode(source,source_mask),source_mask,target,target_mask))def encode(self,source,source_mask):return self.encoder(self.src_embed(source),source_mask)def decode(self,memory,source_mask,target,target_mask):return self.decoder(self.tgt_embed(target),memory,source_mask,target_mask)

5.2 Transformer模型构建过程代码分析

make_model函数初始化一个一个组件对象(轮子对象),调用EncoderDecoder()函数。

def make_model(source_vocab,target_vocab,N=6,d_model=512,d_ff=2048,head=8,dropout=0.1):c = copy.deepcopy# 实例化多头注意力层对象attn = MultiHeadAttention(head=8,embedding_dim=d_model,dropout=dropout)# 实例化前馈全连接层ff = PositionwiseFeedForward(d_model=d_model,d_ff=d_ff,dropout=dropout)# 实例化 位置编码器对象positionposition = PositionalEncoding(d_model=d_model,dropout=dropout)# 构建EncoderDecoder对象model = EncoderDecoder(# 编码器对象Encoder(EncoderLayer(d_model,c(attn),c(ff),dropout),N),# 解码器对象Decoder(DecoderLayer(d_model,c(attn),c(attn),c(ff),dropout),N),# 词嵌入层 位置编码器层容器nn.Sequential(Embeddings(source_vocab,d_model),c(position)),# 词嵌入层 位置编码器层容器nn.Sequential(Embeddings(target_vocab,d_model),c(position)),# 输出层对象Generator(target_vocab,d_model))for p in model.parameters():if p.dim()>1:nn.init.xavier_normal_(p)return model
def dm_test_make_model():source_vocab = 500target_vocab = 1000N=6my_transformer_model = make_model(source_vocab,target_vocab,N=6,d_model=512,d_ff=2048,head=8,dropout=0.1)print(my_transformer_model)# 假设源数据与目标数据相同,实际中并不相同source = target = torch.LongTensor([[1,2,3,8],[3,4,1,8]])# 假设src_mask与tgt——mask相同,实际中并不相同source_mask = target_mask = torch.zeros(8,4,4)my_data = my_transformer_model(source,target,source_mask,target_mask)print('mydata.shape--->',my_data.shape)print('mydata--->',my_data)
dm_test_make_model()

http://www.dtcms.com/a/450027.html

相关文章:

  • 网站开发毕业设计任务书注册推广赚钱一个30元
  • 邯郸网络广播电视台北京网站seo技术厂家
  • leetcode 695 岛屿的最大面积
  • LLaVA-NeXT-Interleave论文阅读
  • 邢台企业网站制作公司中建国际建设有限公司网站
  • 长春火车站防疫要求做网站都要用到框架吗
  • 集合进阶 - HashMap 篇
  • 从原图到线图再到反推:网络图几何与拓扑的结合分析
  • Lua下载和安装教程(附安装包)
  • JAVA实验课程第五次作业分析与代码示例
  • 龙口网站制作公司深圳知名设计公司有哪些
  • 网站数据修改网页界面设计的起源
  • 东莞建设网站官网住房和城乡wordpress 如何修改like和dislike
  • Gopher二次编码原因解析
  • 【ARM汇编语言基础】-数据处理指令(七)
  • 汇编与反汇编
  • 福州建设网站shopee怎么注册开店
  • 建立网站站点的目的贵州二级站seo整站优化排名
  • 阳江做网站多少钱企业网站推广方法有哪些
  • sm2025 模拟赛11 (2025.10.5)
  • python镜像源配置
  • 4.寻找两个正序数组的中位数-二分查找
  • 理解CC++异步IO编程:Epoll入门
  • wordpress房屋网站模板微信小程序
  • 阿里网站建设视频教程WordPress云媒体库
  • SpringCloud 入门 - Nacos 配置中心
  • Windows 下使用 Claude Code CLI 启动 Kimi
  • 网站推广的基本方式抖音特效开放平台官网
  • 湖南网站排名wordpress插件seo
  • WindowsKyLin:nginx安装与配置