当前位置: 首页 > news >正文

基于Transformer的机器翻译——模型篇

1.模型结构

本案例整体采用transformer论文中提出的结构,部分设置做了调整。transformer网络结构介绍可参考博客——入门级别的Transformer模型介绍,这里着重介绍其代码实现。
模型的整体结构,包括词嵌入层,位置编码,编码器,解码器、输出层部分。

2.词嵌入层

词嵌入层用于将token转化为词向量,该层可直接调用nn模块中的Embedding方法。该方法主要包括两个参数,分别表示词表的大小(vocab_size)和词嵌入的维度(emb_size),同时为了训练更稳定,加入了缩放因子dk\sqrt {d_k}dk,代码如下:

class TokenEmbedding(nn.Module):def __init__(self, vocab_size: int, emb_size):super(TokenEmbedding, self).__init__()# 词嵌入层:将词索引映射到emb_size维的向量self.embedding = nn.Embedding(vocab_size, emb_size)# 记录嵌入维度(用于缩放)self.emb_size = emb_sizedef forward(self, tokens: Tensor):# 将词索引转换为词向量,并乘以√emb_size(缩放,稳定梯度)return self.embedding(tokens.long()) * math.sqrt(self.emb_size)

3.位置编码

位置编码层用于给序列添加位置信息,解决自注意力机制无法感知序列顺序的问题。公式为:
PE(pos,2i)=sin(pos10002id)PE(pos,2i)=sin(\frac{pos}{1000\frac{2i}{d}})PE(pos,2i)=sin(1000d2ipos)
PE(pos,2i+1)=cos(pos10002id)PE(pos,2i+1)=cos(\frac{pos}{1000\frac{2i}{d}})PE(pos,2i+1)=cos(1000d2ipos)
代码表示如下:

class PositionalEncoding(nn.Module):def __init__(self, emb_size: int, dropout, maxlen: int = 5000):super(PositionalEncoding, self).__init__()# 计算位置编码的衰减因子(控制正弦/余弦函数的频率)den = torch.exp(- torch.arange(0, emb_size, 2) * math.log(10000) / emb_size)# 位置索引(0到maxlen-1)pos = torch.arange(0, maxlen).reshape(maxlen, 1)# 初始化位置编码矩阵(形状:[maxlen, emb_size])pos_embedding = torch.zeros((maxlen, emb_size))# 偶数列用正弦函数填充(pos * den)pos_embedding[:, 0::2] = torch.sin(pos * den)# 奇数列用余弦函数填充(pos * den)pos_embedding[:, 1::2] = torch.cos(pos * den)# 调整维度(添加批次维度,便于与词嵌入向量相加)pos_embedding = pos_embedding.unsqueeze(-2)# Dropout层(正则化,防止过拟合)self.dropout = nn.Dropout(dropout)# 注册为缓冲区(模型保存/加载时自动处理)self.register_buffer('pos_embedding', pos_embedding)def forward(self, token_embedding: Tensor):# 将词嵌入向量与位置编码相加,并应用Dropoutreturn self.dropout(token_embedding + self.pos_embedding[:token_embedding.size(0),:])

4.编码器

由于编码器部分是通过堆叠多个子编码器层所构成的,子编码器包括:多头自注意力层、残差连接与归一化、前馈网络三部分,该部分代码全部被封装成TransformerEncoderLayer函数中,使用时只需要传递相应超参数即可,如词嵌入维度、多头注意力的头数、前馈网络的隐含层维度,代码实现为:

# 定义编码器层(单头注意力→多头注意力→前馈网络)
encoder_layer = TransformerEncoderLayer(d_model=emb_size,       # 输入特征维度(与词嵌入维度一致)nhead=NHEAD,            # 多头注意力的头数dim_feedforward=dim_feedforward  # 前馈网络隐藏层维度
)
# 堆叠多层编码器层形成完整编码器
self.transformer_encoder = TransformerEncoder(encoder_layer, num_layers=num_encoder_layers)

5.解码器

解码器同编码器类似,代码可以表述为:

# 定义解码器层(掩码多头注意力→编码器-解码器多头注意力→前馈网络)
decoder_layer = TransformerDecoderLayer(d_model=emb_size,       # 输入特征维度(与词嵌入维度一致)nhead=NHEAD,            # 多头注意力头数(与编码器一致)dim_feedforward=dim_feedforward  # 前馈网络隐藏层维度
)
# 堆叠多层解码器层形成完整解码器
self.transformer_decoder = TransformerDecoder(decoder_layer, num_layers=num_decoder_layers)

6.输出层

输出通过线性层得到每个单词的得分,可直接通过Linear层直接实现。

7.大体代码

基于上述介绍,完整代码如下:

from torch.nn import (TransformerEncoder, TransformerDecoder,TransformerEncoderLayer, TransformerDecoderLayer)class Seq2SeqTransformer(nn.Module):"""基于Transformer的序列到序列翻译模型(日中机器翻译核心模块)包含编码器(处理源语言序列)和解码器(生成目标语言序列)"""def __init__(self, num_encoder_layers: int, num_decoder_layers: int,emb_size: int, src_vocab_size: int, tgt_vocab_size: int,dim_feedforward: int = 512, dropout: float = 0.1):"""初始化Transformer模型参数和组件:param num_encoder_layers: 编码器层数(论文中通常为6,此处根据计算资源调整):param num_decoder_layers: 解码器层数(与编码器层数一致):param emb_size: 词嵌入维度(对应Transformer的d_model,需与多头注意力维度匹配):param src_vocab_size: 源语言(日语)词表大小:param tgt_vocab_size: 目标语言(中文)词表大小:param dim_feedforward: 前馈网络隐藏层维度(通常为4*d_model):param dropout:  dropout概率(用于正则化,防止过拟合)"""super(Seq2SeqTransformer, self).__init__()# 定义编码器层(单头注意力→多头注意力→前馈网络)encoder_layer = TransformerEncoderLayer(d_model=emb_size,       # 输入特征维度(与词嵌入维度一致)nhead=NHEAD,            # 多头注意力的头数(需满足 emb_size % nhead == 0)dim_feedforward=dim_feedforward  # 前馈网络隐藏层维度)# 堆叠多层编码器层形成完整编码器self.transformer_encoder = TransformerEncoder(encoder_layer, num_layers=num_encoder_layers)# 定义解码器层(掩码多头注意力→编码器-解码器多头注意力→前馈网络)decoder_layer = TransformerDecoderLayer(d_model=emb_size,       # 输入特征维度(与词嵌入维度一致)nhead=NHEAD,            # 多头注意力头数(与编码器一致)dim_feedforward=dim_feedforward  # 前馈网络隐藏层维度)# 堆叠多层解码器层形成完整解码器self.transformer_decoder = TransformerDecoder(decoder_layer, num_layers=num_decoder_layers)# 生成器:将解码器输出映射到目标词表(预测每个位置的目标词)self.generator = nn.Linear(emb_size, tgt_vocab_size)# 源语言词嵌入层(将词索引转换为连续向量)self.src_tok_emb = TokenEmbedding(src_vocab_size, emb_size)# 目标语言词嵌入层(与源语言共享嵌入层可提升效果,此处未共享)self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size)# 位置编码层(注入序列位置信息,解决Transformer的位置无关性)self.positional_encoding = PositionalEncoding(emb_size, dropout=dropout)def forward(self, src: Tensor, trg: Tensor, src_mask: Tensor,tgt_mask: Tensor, src_padding_mask: Tensor,tgt_padding_mask: Tensor, memory_key_padding_mask: Tensor):"""前向传播(训练时使用教师强制,输入完整目标序列):param src: 源语言序列张量(形状:[seq_len, batch_size]):param trg: 目标语言序列张量(形状:[seq_len, batch_size]):param src_mask: 源序列注意力掩码(形状:[seq_len, seq_len],全0表示无掩码):param tgt_mask: 目标序列掩码(下三角掩码,防止关注未来词):param src_padding_mask: 源序列填充掩码(标记<pad>位置,形状:[batch_size, seq_len]):param tgt_padding_mask: 目标序列填充掩码(标记<pad>位置,形状:[batch_size, seq_len]):param memory_key_padding_mask: 编码器输出的填充掩码(与src_padding_mask一致):return: 目标序列的词表概率分布(形状:[seq_len, batch_size, tgt_vocab_size])"""# 源序列处理:词嵌入 + 位置编码src_emb = self.positional_encoding(self.src_tok_emb(src))# 目标序列处理:词嵌入 + 位置编码(训练时使用教师强制,输入完整目标序列)tgt_emb = self.positional_encoding(self.tgt_tok_emb(trg))# 编码器处理源序列,生成记忆向量(memory)memory = self.transformer_encoder(src_emb, src_mask, src_padding_mask)# 解码器利用记忆向量生成目标序列outs = self.transformer_decoder(tgt_emb,                # 目标序列嵌入(含位置信息)memory,                 # 编码器输出的记忆向量tgt_mask,               # 目标序列掩码(防止未来词)None,                   # 编码器-解码器注意力掩码(此处未使用)tgt_padding_mask,       # 目标序列填充掩码(忽略<pad>)memory_key_padding_mask # 记忆向量填充掩码(与源序列填充掩码一致))# 通过生成器输出目标词表的概率分布return self.generator(outs)def encode(self, src: Tensor, src_mask: Tensor):"""编码源序列(推理时单独调用,生成编码器记忆向量):param src: 源语言序列张量(形状:[seq_len, batch_size]):param src_mask: 源序列注意力掩码(形状:[seq_len, seq_len]):return: 编码器输出的记忆向量(形状:[seq_len, batch_size, emb_size])"""return self.transformer_encoder(self.positional_encoding(self.src_tok_emb(src)),  # 源序列嵌入+位置编码src_mask  # 源序列注意力掩码)def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor):"""解码目标序列(推理时逐步生成目标词):param tgt: 当前已生成的目标序列前缀(形状:[current_seq_len, batch_size]):param memory: 编码器输出的记忆向量(形状:[seq_len, batch_size, emb_size]):param tgt_mask: 目标序列掩码(下三角掩码,防止关注未来词):return: 解码器输出(形状:[current_seq_len, batch_size, emb_size])"""return self.transformer_decoder(self.positional_encoding(self.tgt_tok_emb(tgt)),  # 目标前缀嵌入+位置编码memory,  # 编码器记忆向量tgt_mask  # 目标前缀掩码(仅允许关注已生成部分))
class PositionalEncoding(nn.Module):def __init__(self, emb_size: int, dropout, maxlen: int = 5000):super(PositionalEncoding, self).__init__()# 计算位置编码的衰减因子(控制正弦/余弦函数的频率)den = torch.exp(- torch.arange(0, emb_size, 2) * math.log(10000) / emb_size)# 位置索引(0到maxlen-1)pos = torch.arange(0, maxlen).reshape(maxlen, 1)# 初始化位置编码矩阵(形状:[maxlen, emb_size])pos_embedding = torch.zeros((maxlen, emb_size))# 偶数列用正弦函数填充(pos * den)pos_embedding[:, 0::2] = torch.sin(pos * den)# 奇数列用余弦函数填充(pos * den)pos_embedding[:, 1::2] = torch.cos(pos * den)# 调整维度(添加批次维度,便于与词嵌入向量相加)pos_embedding = pos_embedding.unsqueeze(-2)# Dropout层(正则化,防止过拟合)self.dropout = nn.Dropout(dropout)# 注册为缓冲区(模型保存/加载时自动处理)self.register_buffer('pos_embedding', pos_embedding)def forward(self, token_embedding: Tensor):# 将词嵌入向量与位置编码相加,并应用Dropoutreturn self.dropout(token_embedding + self.pos_embedding[:token_embedding.size(0),:])class TokenEmbedding(nn.Module):def __init__(self, vocab_size: int, emb_size):super(TokenEmbedding, self).__init__()# 词嵌入层:将词索引映射到emb_size维的向量self.embedding = nn.Embedding(vocab_size, emb_size)# 记录嵌入维度(用于缩放)self.emb_size = emb_sizedef forward(self, tokens: Tensor):# 将词索引转换为词向量,并乘以√emb_size(缩放,稳定梯度)return self.embedding(tokens.long()) * math.sqrt(self.emb_size)

结语

至此,模型已完成搭建,后续博客将继续介绍模型训练部分的内容,希望本篇博客能够对你理解transformer有所帮助!

http://www.dtcms.com/a/334485.html

相关文章:

  • 力扣面试150(57/100)
  • 罗技MX Anywhere 2S鼠标修复记录
  • RocketMq面试集合
  • Redis--day6--黑马点评--商户查询缓存
  • 极简工具箱:安卓工具箱合集
  • redis的key过期删除策略和内存淘汰机制
  • Python爬虫实战:研究pygalmesh,构建Thingiverse平台三维网格数据处理系统
  • 记录Linux的指令学习
  • ktg-mes 改造成 Saas 系统
  • 后量子密码算法ML-DSA介绍及开源代码实现
  • 343整数拆分
  • 实例分割-动手学计算机视觉13
  • MQ积压如何处理
  • ABAP AMDP 是一项什么技术?
  • 深入理解Java虚拟机(JVM):架构、内存管理与性能调优
  • MongoDB 聚合提速 3 招:$lookup 管道、部分索引、时间序列集合(含可复现实验与 explain 统计)
  • 片料矫平机·第四篇
  • Element Plus 中 el-input 限制为数值输入的方法
  • 暴雨服务器:以定制化满足算力需求多样化
  • 深入剖析跳表:高效搜索的动态数据结构
  • 【测试工具】OnDo SIP Server--轻松搭建一个语音通话服务器
  • 社保、医保、个税、公积金纵向横向合并 python3
  • 深入理解 Vue Router
  • Centos7.9安装Dante
  • 04时间复杂度计算方法
  • Python 桌面应用形态后台管理系统的技术选型与方案报告
  • Linux系统之lslogins 命令详解
  • vector 手动实现 及遇到的各种细节问题
  • 深入剖析 TOTP 算法:基于时间的一次性密码生成机制
  • Golang分布式事务处理方案