当前位置: 首页 > news >正文

你的第一个Transformer模型:从零实现并训练一个迷你ChatBot

点击AladdinEdu,同学们用得起的【H卡】算力平台”,注册即送-H卡级别算力80G大显存按量计费灵活弹性顶级配置学生更享专属优惠


引言:破除神秘感,拥抱核心思想

在人工智能的浪潮中,Transformer模型无疑是一颗最璀璨的明珠。从GPT系列到BERT,从翻译到对话,它的身影无处不在。然而,对于许多初学者而言,Transformer常常被冠以“复杂”、“晦涩难懂”的标签,厚厚的论文和错综复杂的结构图让人望而却步。

今天,我们的目标就是亲手撕掉这层神秘的面纱。我们将不使用任何高级的深度学习框架(如Hugging Face的Transformers库),而是仅借助PyTorch提供的基础张量操作和神经网络模块,从零开始,一行代码一行代码地构建一个完整的Transformer模型。最终,我们会在一个小型数据集上训练它,让它成为一个能进行简单对话的迷你ChatBot。

相信我,当你跟着本文完成整个流程,并看到你的模型开始生成回复时,你会对Self-Attention、位置编码等核心概念有一种“顿悟”般的感觉。这不仅是一次编程练习,更是一次深入理解现代AI核心架构的绝佳旅程。

让我们开始吧!


第一部分:Transformer架构总览

在深入代码之前,我们快速回顾一下Transformer的核心设计。其最初在论文《Attention Is All You Need》中提出,完全基于Attention机制,摒弃了传统的循环和卷积结构。

一个典型的Transformer包含一个编码器(Encoder)和一个解码器(Decoder)。对于我们的聊天机器人任务,编码器负责理解输入的问句,解码器则负责生成输出的答句。

  • 编码器:由N个(原文是6个)相同的层堆叠而成。每层包含两个子层:

    1. 多头自注意力机制(Multi-Head Self-Attention)
    2. 前馈神经网络(Position-wise Feed-Forward Network)
      每个子层周围都有一个残差连接(Residual Connection)层归一化(Layer Normalization)
  • 解码器:同样由N个相同的层堆叠而成。每层包含三个子层:

    1. 掩码多头自注意力机制(Masked Multi-Head Self-Attention):确保解码时只能看到当前位置之前的信息,防止“偷看”未来答案。
    2. 多头编码器-解码器注意力机制(Multi-Head Encoder-Decoder Attention):帮助解码器关注输入序列中的相关信息。
    3. 前馈神经网络
      同样,每个子层也都有残差连接和层归一化。

此外,模型最开始有输入嵌入层位置编码,最后有输出线性层Softmax

我们的代码实现将严格遵循这个结构。我们将自底向上地构建它。


第二部分:核心模块代码实现

我们首先实现最核心、最关键的几个模块。

1. Self-Attention 与 Scaled Dot-Product Attention

Self-Attention是Transformer的灵魂。它的目的是让序列中的任何一个字都能够与序列中的所有其他字进行交互,从而更好地捕捉上下文信息。

Scaled Dot-Product Attention的计算公式如下:
Attention(Q,K,V)=softmax(QKTdk)VAttention(Q, K, V) = softmax(\frac{QK^T}{\sqrt{d_k}})VAttention(Q,K,V)=softmax(dkQKT)V

其中:

  • Q (Query):查询矩阵,代表当前要关注的词。
  • K (Key):键矩阵,代表序列中所有待被查询的词。
  • V (Value):值矩阵,代表序列中所有词的实际信息。
  • d_k:Key向量的维度,缩放因子dk\sqrt{d_k}dk用于防止点积过大导致softmax梯度消失。
import torch
import torch.nn as nn
import torch.nn.functional as F
import mathclass ScaledDotProductAttention(nn.Module):"""Scaled Dot-Product Attention"""def __init__(self, dropout_rate=0.1):super(ScaledDotProductAttention, self).__init__()self.dropout = nn.Dropout(dropout_rate)def forward(self, Q, K, V, attn_mask=None):# Q, K, V 的形状: [batch_size, n_heads, seq_len, d_k or d_v]# d_k = d_model / n_headsd_k = K.size()[-1]# 计算注意力分数 QK^T / sqrt(d_k)scores = torch.matmul(Q, K.transpose(-1, -2)) / math.sqrt(d_k)# 如果提供了注意力掩码,应用它(将mask为1的位置置为一个极小的值,如-1e9)if attn_mask is not None:scores = scores.masked_fill(attn_mask == 1, -1e9)# 对最后一维(seq_len维)进行softmax,得到注意力权重attn_weights = F.softmax(scores, dim=-1)# 可选:应用dropoutattn_weights = self.dropout(attn_weights)# 将注意力权重乘以V,得到最终的输出output = torch.matmul(attn_weights, V) # [batch_size, n_heads, seq_len, d_v]return output, attn_weights

2. Multi-Head Attention

多头注意力机制将模型分为多个“头”,让每个头去关注序列中不同的方面(例如,有的头关注语法关系,有的头关注语义关系),最后将各头的输出合并起来。

class MultiHeadAttention(nn.Module):"""Multi-Head Attention mechanism"""def __init__(self, d_model, n_heads, dropout_rate=0.1):super(MultiHeadAttention, self).__init__()assert d_model % n_heads == 0, "d_model must be divisible by n_heads"self.d_model = d_modelself.n_heads = n_headsself.d_k = d_model // n_heads # 每个头的维度self.d_v = d_model // n_heads# 线性投影层,用于生成Q, K, Vself.W_Q = nn.Linear(d_model, d_model)self.W_K = nn.Linear(d_model, d_model)self.W_V = nn.Linear(d_model, d_model)self.W_O = nn.Linear(d_model, d_model) # 输出投影层self.attention = ScaledDotProductAttention(dropout_rate)self.dropout = nn.Dropout(dropout_rate)self.layer_norm = nn.LayerNorm(d_model)def forward(self, Q, K, V, attn_mask=None):# 残差连接residual = Qbatch_size = Q.size(0)# 线性投影并分头# (batch_size, seq_len, d_model) -> (batch_size, seq_len, n_heads, d_k) -> (batch_size, n_heads, seq_len, d_k)q_s = self.W_Q(Q).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)k_s = self.W_K(K).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)v_s = self.W_V(V).view(batch_size, -1, self.n_heads, self.d_v).transpose(1, 2)# 如果需要,扩展attn_mask以匹配多头形状if attn_mask is not None:attn_mask = attn_mask.unsqueeze(1) # [batch_size, 1, seq_len, seq_len] 广播到所有头# 应用ScaledDotProductAttentioncontext, attn_weights = self.attention(q_s, k_s, v_s, attn_mask=attn_mask)# 将各头的输出拼接起来# (batch_size, n_heads, seq_len, d_v) -> (batch_size, seq_len, n_heads * d_v) = (batch_size, seq_len, d_model)context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)# 输出投影output = self.W_O(context)output = self.dropout(output)# 残差连接和层归一化output = self.layer_norm(output + residual)return output, attn_weights

3. Position-wise Feed-Forward Network

这是一个简单的前馈神经网络,对每个位置(词)的特征进行独立变换。它通常包含两个线性层和一个ReLU激活函数。

class PositionWiseFFN(nn.Module):"""Position-wise Feed-Forward Network"""def __init__(self, d_model, d_ff, dropout_rate=0.1):super(PositionWiseFFN, self).__init__()self.w_1 = nn.Linear(d_model, d_ff)self.w_2 = nn.Linear(d_ff, d_model)self.dropout = nn.Dropout(dropout_rate)self.layer_norm = nn.LayerNorm(d_model)def forward(self, x):residual = xx = self.w_1(x)x = F.relu(x)x = self.dropout(x)x = self.w_2(x)x = self.dropout(x)# 残差连接和层归一化x = self.layer_norm(x + residual)return x

4. Positional Encoding(位置编码)

由于Transformer没有循环和卷积结构,它无法感知序列的顺序。因此,我们需要手动注入位置信息。这里我们使用论文中的正弦和余弦函数编码。

class PositionalEncoding(nn.Module):"""Implement the PE function."""def __init__(self, d_model, max_seq_len=5000):super(PositionalEncoding, self).__init__()# 创建一个足够长的位置编码矩阵pe = torch.zeros(max_seq_len, d_model)position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1) # [max_seq_len, 1]div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))# 对矩阵的偶数和奇数索引分别用正弦和余弦函数pe[:, 0::2] = torch.sin(position * div_term)pe[:, 1::2] = torch.cos(position * div_term)pe = pe.unsqueeze(0) # [1, max_seq_len, d_model]# 注册为一个缓冲区(buffer),它将是模型的一部分,但不被视为可训练参数self.register_buffer('pe', pe)def forward(self, x):# x 的形状: [batch_size, seq_len, d_model]x = x + self.pe[:, :x.size(1), :]return x

第三部分:组装编码器与解码器层

有了上面的积木,我们现在可以搭建编码器层和解码器层了。

1. 编码器层(Encoder Layer)

一个编码器层包含一个多头自注意力子层和一个前馈网络子层。

class EncoderLayer(nn.Module):"""A single layer of the encoder."""def __init__(self, d_model, n_heads, d_ff, dropout_rate=0.1):super(EncoderLayer, self).__init__()self.self_attn = MultiHeadAttention(d_model, n_heads, dropout_rate)self.ffn = PositionWiseFFN(d_model, d_ff, dropout_rate)def forward(self, enc_input, enc_self_attn_mask=None):# 自注意力子层enc_output, attn_weights = self.self_attn(enc_input, enc_input, enc_input, attn_mask=enc_self_attn_mask)# 前馈网络子层enc_output = self.ffn(enc_output)return enc_output, attn_weights

2. 解码器层(Decoder Layer)

一个解码器层包含三个子层:掩码自注意力、编码器-解码器注意力和前馈网络。

class DecoderLayer(nn.Module):"""A single layer of the decoder."""def __init__(self, d_model, n_heads, d_ff, dropout_rate=0.1):super(DecoderLayer, self).__init__()self.self_attn = MultiHeadAttention(d_model, n_heads, dropout_rate)self.enc_dec_attn = MultiHeadAttention(d_model, n_heads, dropout_rate)self.ffn = PositionWiseFFN(d_model, d_ff, dropout_rate)def forward(self, dec_input, enc_output, dec_self_attn_mask=None, dec_enc_attn_mask=None):# 掩码自注意力子层dec_output, self_attn_weights = self.self_attn(dec_input, dec_input, dec_input, attn_mask=dec_self_attn_mask)# 编码器-解码器注意力子层# Q 来自解码器,K, V 来自编码器输出dec_output, enc_dec_attn_weights = self.enc_dec_attn(dec_output, enc_output, enc_output, attn_mask=dec_enc_attn_mask)# 前馈网络子层dec_output = self.ffn(dec_output)return dec_output, self_attn_weights, enc_dec_attn_weights

第四部分:构建完整的Transformer模型

现在,我们将嵌入层、位置编码、编码器栈、解码器栈以及最终的输出层组合在一起。

class Transformer(nn.Module):"""The complete Transformer model."""def __init__(self, src_vocab_size, tgt_vocab_size, d_model, n_heads, n_layers, d_ff, max_seq_len, dropout_rate=0.1):super(Transformer, self).__init__()self.d_model = d_model# 输入和输出嵌入层,共享权重通常效果更好,但这里我们先分开self.enc_embedding = nn.Embedding(src_vocab_size, d_model)self.dec_embedding = nn.Embedding(tgt_vocab_size, d_model)self.pos_encoding = PositionalEncoding(d_model, max_seq_len)# 编码器和解码器堆叠self.encoder_layers = nn.ModuleList([EncoderLayer(d_model, n_heads, d_ff, dropout_rate) for _ in range(n_layers)])self.decoder_layers = nn.ModuleList([DecoderLayer(d_model, n_heads, d_ff, dropout_rate) for _ in range(n_layers)])# 最终的线性层和softmaxself.linear = nn.Linear(d_model, tgt_vocab_size)self.dropout = nn.Dropout(dropout_rate)def forward(self, src_input, tgt_input, src_mask=None, tgt_mask=None):# 编码器部分enc_output = self.enc_embedding(src_input) * math.sqrt(self.d_model)enc_output = self.pos_encoding(enc_output)enc_output = self.dropout(enc_output)for layer in self.encoder_layers:enc_output, _ = layer(enc_output, enc_self_attn_mask=src_mask)# 解码器部分dec_output = self.dec_embedding(tgt_input) * math.sqrt(self.d_model)dec_output = self.pos_encoding(dec_output)dec_output = self.dropout(dec_output)for layer in self.decoder_layers:dec_output, _, _ = layer(dec_output, enc_output, dec_self_attn_mask=tgt_mask)# 输出投影output = self.linear(dec_output)# Softmax在损失函数中计算,这里直接返回logitsreturn output

第五部分:数据准备与训练

1. 选择一个迷你数据集

为了快速实验,我们使用一个非常小的对话数据集。例如,我们可以手动创建一个:

Q: Hi
A: Hello!
Q: What's your name?
A: I'm ChatBot.
Q: How are you?
A: I'm fine, thank you.
... (再多几十组)

或者使用Cornell Movie Dialogs Corpus的一小部分。我们需要构建一个词汇表,并将句子转换为ID序列。

2. 构建词汇表和DataLoader

# 伪代码:构建词汇表
# sentences = [所有Q和A的句子]
# vocab = {'<pad>':0, '<sos>':1, '<eos>':2, ...} 构建词汇字典
# src_ids = [[vocab[word] for word in sentence.split()] for sentence in src_sentences]# 使用PyTorch的DataLoader和Dataset
from torch.utils.data import Dataset, DataLoaderclass ChatDataset(Dataset):def __init__(self, src_sentences, tgt_sentences, vocab, max_len):self.src_sentences = src_sentencesself.tgt_sentences = tgt_sentencesself.vocab = vocabself.max_len = max_lendef __len__(self):return len(self.src_sentences)def __getitem__(self, idx):src_seq = self.sentence_to_ids(self.src_sentences[idx])tgt_seq = self.sentence_to_ids(self.tgt_sentences[idx])# 添加起始符<sos>和结束符<eos>tgt_input = [self.vocab['<sos>']] + tgt_seqtgt_output = tgt_seq + [self.vocab['<eos>']]# 填充到最大长度src_seq = self.pad_seq(src_seq, self.max_len)tgt_input = self.pad_seq(tgt_input, self.max_len)tgt_output = self.pad_seq(tgt_output, self.max_len)return torch.LongTensor(src_seq), torch.LongTensor(tgt_input), torch.LongTensor(tgt_output)# ... (实现sentence_to_ids和pad_seq方法)

3. 创建注意力掩码和训练循环

我们需要创建两种掩码:

  1. 填充掩码(Padding Mask):遮盖掉<pad>符号,防止注意力机制关注这些无意义的位置。
  2. 序列掩码(Sequence Mask):用于解码器的自注意力,防止解码时看到未来的信息(一个下三角矩阵)。
def create_padding_mask(seq, pad_idx):# seq: [batch_size, seq_len]return (seq == pad_idx).unsqueeze(1).unsqueeze(2) # [batch_size, 1, 1, seq_len] 便于广播def create_look_ahead_mask(seq_len):# 创建一个下三角矩阵,对角线及其以上为0,以下为1mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()return mask.unsqueeze(0).unsqueeze(0) # [1, 1, seq_len, seq_len]

训练循环 的标准流程:准备数据、计算模型输出、计算损失(带忽略<pad>的CrossEntropyLoss)、反向传播、优化器步进。

# 初始化模型、优化器、损失函数
model = Transformer(src_vocab_size, tgt_vocab_size, d_model=512, n_heads=8, n_layers=6, d_ff=2048, max_seq_len=100)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, betas=(0.9, 0.98), eps=1e-9)
criterion = nn.CrossEntropyLoss(ignore_index=pad_idx) # 忽略padding位置的损失for epoch in range(num_epochs):model.train()for batch in dataloader:src, tgt_in, tgt_out = batchsrc_mask = create_padding_mask(src, pad_idx)# 解码器掩码:填充掩码 + 序列掩码tgt_padding_mask = create_padding_mask(tgt_in, pad_idx)tgt_look_ahead_mask = create_look_ahead_mask(tgt_in.size(1))tgt_mask = torch.logical_or(tgt_padding_mask, tgt_look_ahead_mask) # 组合掩码optimizer.zero_grad()output = model(src, tgt_in, src_mask, tgt_mask)# output: [batch_size, tgt_len, tgt_vocab_size]# tgt_out: [batch_size, tgt_len]loss = criterion(output.view(-1, output.size(-1)), tgt_out.view(-1))loss.backward()optimizer.step()# ... 每个epoch结束后可以打印损失或进行验证

第六部分:推理与对话生成

训练完成后,我们使用贪心搜索(Greedy Search)来生成回复。

def predict(model, src_sentence, vocab, inv_vocab, max_len, device):model.eval()with torch.no_grad():# 将源句子转换为IDsrc_ids = sentence_to_ids(src_sentence, vocab)src_tensor = torch.LongTensor(src_ids).unsqueeze(0).to(device) # [1, src_len]# 初始化目标输入,起始为<sos>tgt_ids = [vocab['<sos>']]for i in range(max_len):tgt_tensor = torch.LongTensor(tgt_ids).unsqueeze(0).to(device) # [1, current_tgt_len]# 创建掩码src_mask = create_padding_mask(src_tensor, pad_idx)tgt_mask = create_look_ahead_mask(len(tgt_ids))# 预测下一个词output = model(src_tensor, tgt_tensor, src_mask, tgt_mask)next_word_logits = output[0, -1, :] # 最后一个位置的输出next_word_id = torch.argmax(next_word_logits, dim=-1).item()tgt_ids.append(next_word_id)if next_word_id == vocab['<eos>']:break# 将ID序列转换回句子,忽略<sos>和<eos>predicted_sentence = ids_to_sentence(tgt_ids[1:-1], inv_vocab)return predicted_sentence# 示例用法
# vocab: 词汇表,inv_vocab: 反向词汇表(id到word)
# response = predict(model, "Hello there", vocab, inv_vocab, max_len=20, device='cpu')
# print(response)

总结与展望

恭喜你!你已经从零开始实现并训练了一个Transformer模型。这个过程无疑充满挑战,但它极大地深化了你对Self-Attention、位置编码、掩码等核心概念的理解。

我们的迷你ChatBot虽然简单,但它已经具备了Transformer架构的所有精髓。你可以通过以下方式进一步提升它:

  1. 使用更大更高质量的数据集
  2. 调整超参数d_model, n_heads, n_layers, d_ff等。
  3. 实现更高级的解码策略:如Beam Search。
  4. 尝试预训练:在海量文本上先进行无监督预训练,再在我们的对话数据上进行微调。
  5. 加入更多技巧:如标签平滑、学习率预热等。

希望这次“手撕”Transformer的经历让你不再觉得它神秘莫测。它只是一个精心设计的神经网络,其力量源于对序列数据中复杂依赖关系的强大建模能力。现在,你已经掌握了它的蓝图,可以自由地去探索、修改和创新了!

(注意:由于篇幅和可运行性限制,本文代码为示例性质,可能需要一些调试和修改才能完全运行。建议在Jupyter Notebook或Colab中分模块逐步测试。)


点击AladdinEdu,同学们用得起的【H卡】算力平台”,注册即送-H卡级别算力80G大显存按量计费灵活弹性顶级配置学生更享专属优惠


文章转载自:

http://6UCeTBGZ.hkgcx.cn
http://XQqXNB7R.hkgcx.cn
http://2IATnPbM.hkgcx.cn
http://xA2ny4Y5.hkgcx.cn
http://ZozAlkAI.hkgcx.cn
http://YHPJiwY7.hkgcx.cn
http://nMu1dHnM.hkgcx.cn
http://8HZjxG4n.hkgcx.cn
http://8lrIC3iQ.hkgcx.cn
http://eW47P6Dh.hkgcx.cn
http://rqV0WbiE.hkgcx.cn
http://VjCRME3B.hkgcx.cn
http://ePGQspUC.hkgcx.cn
http://ZZRRFRqn.hkgcx.cn
http://stQqAacp.hkgcx.cn
http://jqf2pHQM.hkgcx.cn
http://CwYHB3K9.hkgcx.cn
http://JBm3H37k.hkgcx.cn
http://gRV3YDNq.hkgcx.cn
http://ffPWejDJ.hkgcx.cn
http://U3jh7ElF.hkgcx.cn
http://q0Cbx4n6.hkgcx.cn
http://dCsMl0Ta.hkgcx.cn
http://SWR5jlFa.hkgcx.cn
http://BdH4LTr0.hkgcx.cn
http://AKRC98U1.hkgcx.cn
http://cBgtUQs9.hkgcx.cn
http://1ivysoiD.hkgcx.cn
http://JoLTiBbt.hkgcx.cn
http://NuG8wnKy.hkgcx.cn
http://www.dtcms.com/a/387167.html

相关文章:

  • JVM工具全景指南
  • 储能电站监控与能量管理系统(EMS)技术规范
  • 代码随想录刷题——栈和队列篇(三)
  • 尺寸最小32.768KHZ有源晶振SIT1572
  • Python文件写入安全指南:处理不存在文件的完整解决方案
  • 网络层认识——IP协议
  • 软考中级习题与解答——第七章_数据库系统(1)
  • 立创·庐山派K230CanMV开发板的进阶学习——特征检测
  • 使用 Nano-banana 的 API 方式
  • 【原理】为什么React框架的传统递归无法被“中断”从而选用链式fiber结构?
  • Redis网络模型分析:从单线程到多线程的网络架构演进
  • 刷题日记0916
  • 5.PFC闭环控制仿真
  • 三层网络结构接入、汇聚、核心交换层,应该怎么划分才对?
  • Std::Future大冒险:穿越C++并发宇宙的时空胶囊
  • 《LINUX系统编程》笔记p13
  • Spring Cloud-面试知识点(组件、注册中心)
  • 2.2 定点数的运算 (答案见原书 P93)
  • 使用数据断点调试唤醒任务时__state的变化
  • 力扣周赛困难-3681. 子序列最大 XOR 值 (线性基)
  • Spring IOC 与 Spring AOP
  • 【FreeRTOS】队列API全家桶
  • 【Docker项目实战】使用Docker部署Cup容器镜像更新工具
  • (笔记)内存文件映射mmap
  • springboot传输文件,下载文件
  • 基于51单片机的出租车计价器霍尔测速设计
  • 【笔记】Agent应用开发与落地全景
  • C++ STL底层原理系列学习路线规划
  • LAN口和WAN口
  • Dify + Bright Data MCP:从实时影音数据到可落地的智能体生产线