Transformer(四)---解码器部分实现、输出部分实现及模型搭建
目录
一、解码器介绍
二、解码器层
三、解码器
四、输出部分实现
五、模型构建
5.1 编码器-解码器的代码实现
5.2 Transformer模型构建过程代码分析
一、解码器介绍
由N个解码器层堆叠而成
每个解码器层由三个子层连接结构组成
第一个子层连接结构包括一个(掩码)多头自注意力子层和规范化层以及一个残差连接
第二个子层连接结构包括一个多头注意力子层和规范化层以及一个残差连接
第三个子层连接结构包括一个前馈全连接子层和规范化层以及一个残差连接
二、解码器层
解码器层的作用:
作为解码器的组成单元, 每个解码器层根据给定的输入向目标方向进行特征提取操作,即解码过程.
source_mask
(源数据掩码)的主要作用是防止解码器关注到编码器输出中无效的 padding 区域,确保模型只关注源序列(如输入文本)中有意义的内容。
target_mask
(目标掩码):用于解码器的自注意力(self_attn
),防止解码器关注未来的 tokens(实现因果语言模型的特性)。
class DecoderLayer(nn.Module):def __init__(self,size,self_attn,src_attn,feed_forward,dropout):super(DecoderLayer,self).__init__()# 词嵌入维度尺寸大小self.size = size# 自注意力机制层对象 q=k=vself.self_attn = self_attn# 一般注意力机制对象 q!=k=vself.src_attn = src_attn# 前馈全连接层对象self.feed_forward = feed_forward# clones 3子层连接结构self.sublayers = clones(SublayerConnection(size,dropout),3)def forward(self,x,memory,source_mask,target_mask):# forward函数的参数有4个,分别是来自上一层的输入x,来自编码器层的语义存储变量mermory, 以及源数据掩码张量和目标数据掩码张量.m = memory# 数据经过子层连接结构1x = self.sublayers[0](x,lambda x:self.self_attn(x,x,x,target_mask))# 数据经过子层连接结构2x = self.sublayers[1](x,lambda x:self.src_attn(x,m,m,source_mask))# 数据经过子层连接结构3x = self.sublayers[2](x,self.feed_forward)return x
三、解码器
根据编码器的结果以及上一次预测的结果, 对下一次可能出现的'值'进行特征表示.
class Decoder(nn.Module):def __init__(self,layer,N):super(Decoder,self).__init__()# clones N个解码器层self.layers = clones(layer,N)# 定义规范化层self.norm = LayerNorm(layer.size)def forward(self,x,memory,source_mask,target_mask):for layer in self.layers:x = layer(x,memory,source_mask,target_mask)return self.norm(x)
四、输出部分实现
输出部分包含:
线性层
softmax层
线性层的作用:
通过对上一步的线性变化得到指定维度的输出, 也就是转换维度的作用.
softmax层的作用:
使最后一维的向量中的数字缩放到0-1的概率值域内, 并满足他们的和为1.
class Generator(nn.Module):def __init__(self,vocab,d_model):# vocab 线性输出尺寸大小# d_model 线性层输入特征尺寸大小super(Generator,self).__init__()self.linear = nn.Linear(d_model,vocab)def forward(self,x):x = F.log_softmax(self.linear(x),dim=-1)return x
五、模型构建
通过上面的小节, 我们已经完成了所有组成部分的实现, 接下来就来实现完整的编码器-解码器结构.
Transformer总体架构图:
5.1 编码器-解码器的代码实现
class EncoderDecoder(nn.Module):def __init__(self,encoder,decoder,source_embed,target_embed,generator):"""初始化函数中有5个参数, 分别是编码器对象, 解码器对象, 源数据嵌入函数, 目标数据嵌入函数, 以及输出部分的类别生成器对象"""super(EncoderDecoder,self).__init__()self.encoder = encoderself.decoder = decoderself.src_embed = source_embedself.tgt_embed =target_embedself.generator = generatordef forward(self,source,target,source_mask,target_mask):"""在forward函数中,有四个参数, source代表源数据, target代表目标数据, source_mask和target_mask代表对应的掩码张量"""# 在函数中,将source,source_mask传入编码函数,得到结果后,# 与source_mask,target,target_mask一同传给解码函数return self.generator(self.decode(self.encode(source,source_mask),source_mask,target,target_mask))def encode(self,source,source_mask):return self.encoder(self.src_embed(source),source_mask)def decode(self,memory,source_mask,target,target_mask):return self.decoder(self.tgt_embed(target),memory,source_mask,target_mask)
5.2 Transformer模型构建过程代码分析
make_model函数初始化一个一个组件对象(轮子对象),调用EncoderDecoder()函数。
def make_model(source_vocab,target_vocab,N=6,d_model=512,d_ff=2048,head=8,dropout=0.1):c = copy.deepcopy# 实例化多头注意力层对象attn = MultiHeadAttention(head=8,embedding_dim=d_model,dropout=dropout)# 实例化前馈全连接层ff = PositionwiseFeedForward(d_model=d_model,d_ff=d_ff,dropout=dropout)# 实例化 位置编码器对象positionposition = PositionalEncoding(d_model=d_model,dropout=dropout)# 构建EncoderDecoder对象model = EncoderDecoder(# 编码器对象Encoder(EncoderLayer(d_model,c(attn),c(ff),dropout),N),# 解码器对象Decoder(DecoderLayer(d_model,c(attn),c(attn),c(ff),dropout),N),# 词嵌入层 位置编码器层容器nn.Sequential(Embeddings(source_vocab,d_model),c(position)),# 词嵌入层 位置编码器层容器nn.Sequential(Embeddings(target_vocab,d_model),c(position)),# 输出层对象Generator(target_vocab,d_model))for p in model.parameters():if p.dim()>1:nn.init.xavier_normal_(p)return model
def dm_test_make_model():source_vocab = 500target_vocab = 1000N=6my_transformer_model = make_model(source_vocab,target_vocab,N=6,d_model=512,d_ff=2048,head=8,dropout=0.1)print(my_transformer_model)# 假设源数据与目标数据相同,实际中并不相同source = target = torch.LongTensor([[1,2,3,8],[3,4,1,8]])# 假设src_mask与tgt——mask相同,实际中并不相同source_mask = target_mask = torch.zeros(8,4,4)my_data = my_transformer_model(source,target,source_mask,target_mask)print('mydata.shape--->',my_data.shape)print('mydata--->',my_data)
dm_test_make_model()