当前位置: 首页 > news >正文

基于 PyTorch 从零实现 Transformer 模型:从核心组件到训练推理全流程

目录

一、Transformer 整体架构概览

二、核心组件实现详解

1. 位置编码(Positional Encoding)

计算公式

2. 多头注意力机制(Multi-Head Attention)

3. 前馈神经网络(Feed-Forward Neural Network,FFN)

4. 编码器(Encoder)、

核心功能

5. 解码器(Decoder)

三、Transformer 模型组装

四、模型训练流程

五、模型推理流程

 核心特点

六、总结


Transformer 模型作为 NLP 领域的里程碑成果,凭借自注意力机制实现了并行化计算,在机器翻译、文本生成等任务中表现卓越。本文将基于 PyTorch 代码,详细解析 Transformer 的核心组件实现,并完整展示模型训练与推理的全流程。

一、Transformer 整体架构概览

Transformer 采用 "编码器 - 解码器" 架构,整体结构如下:

  • 编码器(Encoder):接收输入序列,通过多层自注意力和前馈网络提取特征
  • 解码器(Decoder):结合编码器输出和自身输入,生成目标序列
  • 投影层(Projection):将解码器输出映射到目标词汇表空间

本文实现的 Transformer 代码结构清晰,主要包含以下模块:

├── transformer.py      # 模型整体架构
├── encoder.py          # 编码器实现
├── decoder.py          # 解码器实现
├── MHA.py              # 多头注意力机制
├── FFN.py              # 前馈神经网络
├── position.py         # 位置编码
├── mask.py             # 注意力掩码
├── 训练.py             # 模型训练流程
└── 预测.py             # 模型推理流程

二、核心组件实现详解

1. 位置编码(Positional Encoding)

Transformer模型使用的是自注意力机制,它本身并不能理解序列中元素的顺序,因此Transformer模型本身不具有处理序列顺序的能力,需通过位置编码注入序列位置信息。

位置编码通过为序列中每个单词的嵌入向量添加一个与其位置相关的向量来实现。这个位置向量与单词的嵌入向量具有相同的维度,使得两者可以通过加法或其他方式结合起来。

计算公式

位置编码可以有多种实现方式,Transformer原始论文中提出的位置编码是通过正弦和余弦函数来计算的,这样做的好处是能够让模型学习到相对位置信息,因为这些函数对位置的偏移是可预测的。对于序列中的每个位置pos,和每个维度 i ,位置编码 ( pos , i ) 是这样计算的:

PE_{pos,2i}=sin(pos/10000^{2i/d_{model}})\\ PE_{pos,2i+1}=cos(pos/10000^{2i/d_{model}}

其中,pos 是位置,i 是维度索引${d}_{model}/2$${d}_{model}$ 是嵌入向量的维度,PE代表整个词的数组

本文采用正弦余弦函数实现:

# position.py 核心代码
class PositionalEncoding(nn.Module):def __init__(self, d_model, dropout=0.1, max_len=5000):super().__init__()self.dropout = nn.Dropout(p=dropout)# 初始化位置编码矩阵pe = torch.zeros(max_len, d_model)position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)# 计算频率项div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))# 偶数维度用sin,奇数维度用cospe[:, 0::2] = torch.sin(position * div_term)pe[:, 1::2] = torch.cos(position * div_term)pe = pe.unsqueeze(0)  # 扩展为[1, max_len, d_model]self.register_buffer('pe', pe)  # 注册为非参数缓冲区def forward(self, x):# x: [batch_size, seq_len, d_model]x = x + self.pe[:, :x.size(1), :]  # 注入位置信息return self.dropout(x)

原理:通过不同频率的正弦余弦函数区分不同位置,位置越近的词,位置编码越相似。

2. 多头注意力机制(Multi-Head Attention)

多头注意力在自注意力的基础上,将Q、K、V拆分为多个子空间(“多头”),每个子空间独立计算自注意力,最后将所有子空间的结果拼接并线性转换,得到最终输出。

作用

  • 单一注意力机制可能的偏差,通过多个子空间捕捉不同类型的关联(如一个头部关注语法依赖,另一个关注语义关联)。
  • ​等价于在不同表示子空间中学习注意力,提升模型对复杂依赖的建模能力。
# MHA.py 核心代码
class MultiHeadAttention(nn.Module):def __init__(self, d_model, num_heads):super().__init__()self.d_model = d_modelself.num_heads = num_headsself.d_k = d_model // num_heads  # 每个头的维度# Q、K、V的线性变换矩阵self.w_q = nn.Linear(d_model, d_model)self.w_k = nn.Linear(d_model, d_model)self.w_v = nn.Linear(d_model, d_model)self.w_o = nn.Linear(d_model, d_model)  # 输出线性变换self.attention = Attention()  # 注意力计算self.dropout = nn.Dropout(0.1)self.layer_norm = nn.LayerNorm(d_model)  # 层归一化def forward(self, enc_inputs, dec_inputs, mask=None):res = dec_inputs  # 残差连接batch_size = enc_inputs.size(0)# 线性变换并分多头:[batch, seq_len, d_model] → [batch, heads, seq_len, d_k]Q = self.w_q(dec_inputs).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)K = self.w_k(enc_inputs).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)V = self.w_v(enc_inputs).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)# 计算注意力att_out = self.attention(Q, K, V, mask)# 多头结果拼接:[batch, heads, seq_len, d_k] → [batch, seq_len, d_model]att_out = att_out.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)att_out = self.w_o(att_out)# 残差连接+层归一化att_out = self.dropout(att_out)return self.layer_norm(att_out + res)

注意力计算细节

  • 注意力分数:scores = (Q·K^T) / √d_k(缩放点积)
  • 掩码处理:通过masked_fill_将无效位置(如填充符、未来词)设为极小值,避免其参与注意力计算

3. 前馈神经网络(Feed-Forward Neural Network,FFN)

前馈神经网络对每个位置进行独立的非线性变换,由两层线性网络和 ReLU 激活组成,数据单向流动(输入层→隐藏层→输出层),无反馈或循环连接,属于有向无环图(DAG)

核心组件

  • 输入层:接收原始数据(如图像像素、文本向量)。

  • 隐藏层:1层或多层非线性变换(常用ReLU、Sigmoid激活函数)。

  • 输出层:根据任务类型设计(如分类用Softmax,回归用线性输出)。

FFN可以表示为如下形式:

  1. 第一层变换:将自注意力层的输出通过一个线性层(全连接层)变换,增加模型的非线性表示能力。  
                                   \text{FFN}(x)=max(0, xW_1+b_1)W_2+b_2

    其中,W_1和W_2是权重矩阵,b_1和 b_2是偏置项,{max}(0, dot)表示ReLU激活函数。

  2. 激活函数:在两个线性变换之间,使用ReLU或类似的非线性激活函数增加模型处理复杂非线性关系的能力。

  3. 第二层变换:经过激活函数处理的数据再次通过一个线性层变换,以生成最终的输出。

第一层会将输入的向量升维,第二层将向量重新降维。这样子就可以学习到更加抽象的特征

# FFN.py 核心代码
class FFN(nn.Module):def __init__(self, d_model, d_ff):super().__init__()self.ffn = nn.Sequential(nn.Linear(d_model, d_ff),  # 升维nn.ReLU(),nn.Dropout(0.1),nn.Linear(d_ff, d_model)   # 降维)self.dropout = nn.Dropout(0.1)self.layer_norm = nn.LayerNorm(d_model)def forward(self, x):res = x  # 残差连接x = self.ffn(x)x = self.dropout(x)return self.layer_norm(x + res)  # 残差+层归一化

作用:对注意力输出进行非线性变换,增强模型拟合能力。

4. 编码器(Encoder)、

编码器的任务是将输入序列编码成一个固定长度的向量表示(即上下文向量)。

编码器由 N 个相同的编码器层堆叠而成,每个编码器层包含:

  • 自注意力机制子层 (输入同时作为 Q、K、V)
  • 前馈神经网络
核心功能
  • 特征提取:将原始输入(如图像、文本、语音)转换为更具语义的抽象特征。

  • 降维/压缩:减少数据维度(如PCA的神经网络实现)。

  • 信息编码:生成适合下游任务(如分类、生成)的中间表示。

# encoder.py 核心代码
class EncoderLayer(nn.Module):def __init__(self, d_model, d_ff, n_heads):super().__init__()self.multi_head_attention = MultiHeadAttention(d_model, n_heads)self.feed_forward = FFN(d_model, d_ff)def forward(self, enc_inputs, mask=None):# 自注意力计算enc_outputs = self.multi_head_attention(enc_inputs, enc_inputs, mask)# 前馈网络enc_outputs = self.feed_forward(enc_outputs)return enc_outputsclass Encoder(nn.Module):def __init__(self, vocab_size, d_model, d_ff, n_heads, n_layers):super().__init__()self.embedding = nn.Embedding(vocab_size, d_model)  # 词嵌入self.position_encoding = PositionalEncoding(d_model)  # 位置编码# 堆叠N个编码器层self.layers = nn.ModuleList([EncoderLayer(d_model, d_ff, n_heads)for _ in range(n_layers)])def forward(self, enc_inputs):# 词嵌入+位置编码enc_outputs = self.embedding(enc_inputs)enc_outputs = self.position_encoding(enc_outputs)# 生成填充掩码(忽略填充符[PAD])mask = att_pad_mask(enc_inputs, enc_inputs)# 经过所有编码器层for layer in self.layers:enc_outputs = layer(enc_outputs, mask)return enc_outputs

5. 解码器(Decoder)

解码器是神经网络中与编码器配对的组件,负责将编码后的隐式表示转换回目标输出形式。

解码器同样由 N 个解码器层堆叠而成,每个解码器层包含:

  • 掩蔽多头自注意力(防止关注未来词)
  • 交叉注意力(以编码器输出为 K、V,解码器输出为 Q)
  • 前馈神经网络

核心功能

  • 数据重建:将编码后的低维特征恢复为原始数据维度(如自编码器)。

  • 生成任务:从潜空间生成新数据(如图像生成、文本翻译)。

  • 序列预测:逐步输出序列结果(如机器翻译、语音合成)。

# decoder.py 核心代码
class DecoderLayer(nn.Module):def __init__(self, d_model, d_ff, n_heads):super().__init__()self.masked_multi_head = MultiHeadAttention(d_model, n_heads)  # 掩蔽自注意力self.cross_multi_head = MultiHeadAttention(d_model, n_heads)   # 交叉注意力self.feed_forward = FFN(d_model, d_ff)def forward(self, enc_outputs, dec_inputs, mask_self=None, mask_cross=None):# 掩蔽自注意力(仅关注已生成的词)dec_outputs = self.masked_multi_head(dec_inputs, dec_inputs, mask_self)# 交叉注意力(结合编码器输出)dec_outputs = self.cross_multi_head(enc_outputs, dec_outputs, mask_cross)# 前馈网络dec_outputs = self.feed_forward(dec_outputs)return dec_outputs

掩码处理

  • 填充掩码(att_pad_mask):忽略输入中的填充符
  • 序列掩码(att_sub_mask):解码器中防止关注未来位置的词,通过上三角矩阵实现

三、Transformer 模型组装

将编码器、解码器和投影层组装为完整 Transformer 模型的过程可简化为三个核心步骤:

  1. 组件初始化
    分别创建编码器(处理输入序列)、解码器(生成目标序列)和投影层(映射到词汇表),确保各组件维度匹配(如模型维度d_model保持一致)。

  2. 前向传播连接

    • 输入序列先经编码器处理,得到包含上下文信息的编码输出
    • 解码器接收编码器输出和目标序列(训练时),生成带解码信息的输出
    • 投影层将解码器输出从d_model维度映射到目标词汇表大小,得到每个位置的词汇概率分布
  3. 输出格式调整
    调整最终输出维度,便于计算交叉熵损失(通常展平为二维:[batch_size*seq_len, vocab_size]

# transformer.py 核心代码
class Transformer(nn.Module):def __init__(self, enc_vocab_size, dec_vocab_size, d_model, d_ff, n_heads, n_layers):super().__init__()self.encoder = Encoder(enc_vocab_size, d_model, d_ff, n_heads, n_layers)self.decoder = Decoder(dec_vocab_size, d_model, d_ff, n_heads, n_layers)self.projection = nn.Linear(d_model, dec_vocab_size)  # 映射到目标词汇表def forward(self, enc_inputs, dec_inputs):# 编码器输出enc_outputs = self.encoder(enc_inputs)# 解码器输出dec_outputs = self.decoder(enc_inputs, enc_outputs, dec_inputs)# 投影到词汇表outputs = self.projection(dec_outputs)return outputs.view(-1, outputs.size(2))  # 调整维度用于计算损失

四、模型训练流程

训练过程采用交叉熵损失和 SGD 优化器,原因如下:

  1. 交叉熵损失
    适用于分类任务(此处为词汇预测),能有效衡量模型输出的概率分布与真实标签(one-hot 形式)的差异,通过最大化正确词汇的对数概率引导模型学习。

  2. SGD 优化器
    作为基础优化器,通过随机梯度下降更新参数,结合动量(momentum)可加速收敛并减少震荡,适合处理大规模数据,在早期 Transformer 实现中是常用选择(后续多被 Adam 替代,但 SGD 实现简单且适合演示基础训练流程)。

核心代码如下:

# 训练.py 核心代码
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Transformer(src_vocab_size, tgt_vocab_size, d_model, d_ff, n_heads, n_layers).to(device)# 损失函数(忽略填充符)和优化器
criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = optim.SGD(model.parameters(), lr=1e-3, momentum=0.99)# 训练循环
for epoch in range(100):for idx, (enc_inputs, dec_inputs, dec_labels) in enumerate(loader):enc_inputs, dec_inputs, dec_labels = enc_inputs.to(device), dec_inputs.to(device), dec_labels.to(device)optimizer.zero_grad()  # 清空梯度outputs = model(enc_inputs, dec_inputs)  # 前向传播loss = criterion(outputs, dec_labels.view(-1))  # 计算损失loss.backward()  # 反向传播optimizer.step()  # 更新参数print(f'Epoch: {epoch+1:04d}, loss = {loss.item():.6f}')# 保存模型
torch.save(model.state_dict(), 'transformer.pth')

五、模型推理流程

推理阶段采用贪婪解码(Greedy Decoding)生成目标序列。

贪婪编码器是一种基于局部最优策略的编码方法,在每一步决策时选择当前最优解(如最大概率、最小误差),而不考虑全局最优性。常见于信号处理、数据压缩和深度学习中的序列生成任务(如机器翻译、语音识别)。

 核心特点
  • 局部最优性:每一步仅选择当前最佳选项,不回溯或全局规划。

  • 低计算成本:相比动态规划或束搜索(Beam Search),计算复杂度低。

  • 可能陷入次优解:因忽略长远依赖,可能导致最终结果非全局最优。

# 预测.py 核心代码
def greedy_decoder(model, enc_input, start_symbol):enc_outputs = model.encoder(enc_input)  # 编码器输出dec_input = torch.zeros(1, 0).type_as(enc_input.data)  # 初始化解码器输入next_symbol = start_symbol  # 起始符号(如'S')terminal = Falsewhile not terminal:# 追加下一个符号到解码器输入dec_input = torch.cat([dec_input.detach(), torch.tensor([[next_symbol]], dtype=enc_input.dtype).cuda()], -1)# 解码器前向传播dec_outputs = model.decoder(enc_input, enc_outputs, dec_input)projected = model.projection(dec_outputs)  # 投影到词汇表# 选择概率最大的词作为下一个符号next_symbol = projected.squeeze(0).max(dim=-1)[1][-1].item()# 若遇到终止符则停止if next_symbol == tgt_vocab["E"]:terminal = Truereturn dec_input

六、总结

本文围绕 Transformer 模型的实现与应用展开,从核心组件到完整流程进行了系统解析。首先介绍了 Transformer 的整体架构,包括编码器、解码器和投影层的基本构成;随后详细阐述了位置编码、多头注意力、前馈神经网络等核心组件的实现原理与代码细节,解释了各部分在模型中的具体作用;接着展示了如何将这些组件有机组合成完整的 Transformer 模型,并呈现了模型训练(采用交叉熵损失和 SGD 优化器)与推理(使用贪婪解码)的关键流程。

通过模块化的代码实现,清晰展现了 Transformer 从输入处理到输出生成的全过程。该实现保留了原始模型的核心思想,同时结构简洁易懂,便于理解 Transformer 的工作机制。

http://www.dtcms.com/a/319862.html

相关文章:

  • Java 大视界 -- Java 大数据在智能安防门禁系统中的人员行为分析与异常事件预警(385)
  • nvm安装,nvm管理node版本
  • Java设计模式总结
  • 【设计模式精解】什么是代理模式?彻底理解静态代理和动态代理
  • Vue自定义流程图式菜单解决方案
  • [激光原理与应用-171]:测量仪器 - 能量型 - 激光能量计(单脉冲能量测量)
  • DicomObjects COM 8.XX
  • VUE+SPRINGBOOT从0-1打造前后端-前后台系统-文章列表
  • [TIP 2025] 轻量级光谱注意力LSA,极致优化,减少99.8%参数,提升性能!
  • kafka安装与参数配置
  • MPC-in-the-Head 转换入门指南
  • 抖音、快手、视频号等多平台视频解析下载 + 磁力嗅探下载、视频加工(提取音频 / 压缩等)
  • 【性能测试】---测试工具篇(jmeter)
  • Java垃圾回收(GC)探析
  • 图像理解、计算机视觉相关名词解释
  • 最新教程 | CentOS 7 内网环境 Nginx + ECharts 页面离线部署手册(RPM 安装方式)
  • yolo目标检测技术:基础概念(一)
  • Vscode Data Wrangler 数据查看和处理工具
  • Docker容器技术详解
  • 施易德智慧门店管理系统:零售品牌出海的高效引擎
  • mysql 索引失效分析
  • Cesium粒子系统模拟风场动态效果
  • 国内使用 npm 时配置镜像源
  • 网络安全等级保护(等保)2.0 概述
  • 树莓派下载安装miniconda(linux版小anaconda)
  • 【奔跑吧!Linux 内核(第二版)】第6章:简单的字符设备驱动(一)
  • 解决 Nginx 反代中 proxy_ssl_name 环境变量失效问题:网页能打开但登录失败
  • 3深度学习Pytorch-神经网络--全连接神经网络、数据准备(构建数据类Dataset、TensorDataset 和数据加载器DataLoader)
  • TCP 如何保证可靠性
  • Linux openssl、openssh 升级 保留旧版本