PyTorch 深度学习实战(38):注意力机制全面解析(从Seq2Seq到Transformer)
在上一篇文章中,我们探讨了分布式训练实战。本文将深入解析注意力机制的完整发展历程,从最初的Seq2Seq模型到革命性的Transformer架构。我们将使用PyTorch实现2个关键阶段的注意力机制变体,并在机器翻译任务上进行对比实验。
一、注意力机制演进路线
1. 关键模型对比
模型 | 发表年份 | 核心创新 | 计算复杂度 | 典型应用 |
---|---|---|---|---|
Seq2Seq | 2014 | 编码器-解码器架构 | O(n²) | 机器翻译 |
Bahdanau Attention | 2015 | 软注意力机制(动态上下文向量) | O(n²) | 文本生成、语音识别 |
Luong Attention | 2015 | 全局/局部注意力(改进对齐方式) | O(n²) | 语音识别、长文本翻译 |
Transformer | 2017 | 自注意力机制(并行化处理) | O(n²) | 所有序列任务(NLP/CV) |
Sparse Transformer | 2019 | 稀疏注意力(分块处理长序列) | O(n√n) | 长文本生成、基因序列分析 |
MQA | 2023 | 多查询注意力(共享KV减少内存) | O(n log n) | 大模型推理加速 |
GQA | 2024 | 分组查询注意力(平衡精度与效率) | O(n log n) | 工业级大模型部署 |
Flash Attention | 2024 | 分块计算优化KV缓存 | O(n√n) | 超长序列处理(>10k tokens) |
DeepSeek MLA | 2025 | 多头潜在注意力(潜在空间投影) | O(n log n) | 多模态融合、复杂推理任务 |
TPA | 2025 | 张量积分解注意力(动态秩优化) | O(n) | 边缘计算、低资源环境 |
MoBA | 2025 | 混合块注意力(Top-K门控选择) | O(n log n) | 百万级长文本处理 |
ECA | 2025 | 高效通道注意力(参数无关门控) | O(1) | 图像分类、目标检测 |
2. 注意力类型分类
class AttentionTypes:def __init__(self):self.soft_attention = ["加性注意力", "点积注意力"]self.hard_attention = ["随机硬注意力", "最大似然注意力"] self.self_attention = ["标准自注意力", "稀疏自注意力"]self.cross_attention = ["编码器-解码器注意力"]
二、基础注意力机制实现
1. 环境配置
pip install torch matplotlib
2. Luong注意力实现
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
from tokenizers import Tokenizer, models, trainers, pre_tokenizers
import random
from tqdm import tqdm# 设备配置 - 检查是否有可用的GPU,没有则使用CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 数据预处理部分
def build_tokenizer(text_iter, vocab_size=20000):"""构建分词器参数:text_iter: 文本迭代器vocab_size: 词汇表大小返回:训练好的分词器"""# 使用Unigram模型初始化分词器tokenizer = Tokenizer(models.Unigram())# 使用空格作为预分词器tokenizer.pre_tokenizer = pre_tokenizers.Whitespace()# 配置训练器trainer = trainers.UnigramTrainer(vocab_size=vocab_size,special_tokens=["[PAD]", "[UNK]", "[SOS]", "[EOS]"], # 特殊标记unk_token="[UNK]" # 显式设置UNK标记)# 从文本迭代器训练分词器tokenizer.train_from_iterator(text_iter, trainer)# 确保UNK标记在分词器中正确设置if tokenizer.token_to_id("[UNK]") is None:raise ValueError("UNK token not properly initialized in tokenizer")return tokenizerclass TranslationDataset(Dataset):"""翻译数据集类用于加载和处理翻译数据"""def __init__(self, data, src_tokenizer, trg_tokenizer, max_len=100):"""初始化参数:data: 原始数据src_tokenizer: 源语言分词器trg_tokenizer: 目标语言分词器max_len: 最大序列长度"""self.data = dataself.src_tokenizer = src_tokenizerself.trg_tokenizer = trg_tokenizerself.max_len = max_len# 获取UNK标记的IDself.src_unk_id = self.src_tokenizer.token_to_id("[UNK]")self.trg_unk_id = self.trg_tokenizer.token_to_id("[UNK]")if self.src_unk_id is None or self.trg_unk_id is None:raise ValueError("Tokenizers must have [UNK] token")def __len__(self):"""返回数据集大小"""return len(self.data)def __getitem__(self, idx):"""获取单个样本参数:idx: 索引返回:包含源语言和目标语言token ID的字典"""item = self.data[idx]["translation"]# 源语言(中文)处理src_encoded = self.src_tokenizer.encode(item["zh"])# 添加开始和结束标记,并截断到最大长度src_tokens = ["[SOS]"] + src_encoded.tokens[:self.max_len - 2] + ["[EOS]"]# 将token转换为ID,未知token使用UNK IDsrc_ids = [self.src_tokenizer.token_to_id(t) or self.src_unk_id for t in src_tokens]# 目标语言(英文)处理trg_encoded = self.trg_tokenizer.encode(item["en"])trg_tokens = ["[SOS]"] + trg_encoded.tokens[:self.max_len - 2] + ["[EOS]"]trg_ids = [self.trg_tokenizer.token_to_id(t) or self.trg_unk_id for t in trg_tokens]return {"src": torch.tensor(src_ids),"trg": torch.tensor(trg_ids)}def collate_fn(batch):"""批处理函数用于DataLoader中对批次数据进行填充"""src = [item["src"] for item in batch]trg = [item["trg"] for item in batch]return {"src": nn.utils.rnn.pad_sequence(src, batch_first=True, padding_value=0), # 用0填充"trg": nn.utils.rnn.pad_sequence(trg, batch_first=True, padding_value=0)}# 模型实现部分
class Encoder(nn.Module):"""编码器将输入序列编码为隐藏状态"""def __init__(self, input_dim, emb_dim, hid_dim, n_layers=1, dropout=0.1):"""初始化参数:input_dim: 输入维度(词汇表大小)emb_dim: 词嵌入维度hid_dim: 隐藏层维度n_layers: RNN层数dropout: dropout率"""super().__init__()self.embedding = nn.Embedding(input_dim, emb_dim, padding_idx=0) # 词嵌入层# 双向GRUself.rnn = nn.GRU(emb_dim, hid_dim, n_layers,dropout=dropout if n_layers > 1 else 0,bidirectional=True)self.fc = nn.Linear(hid_dim * 2, hid_dim) # 用于合并双向输出的全连接层self.dropout = nn.Dropout(dropout)self.n_layers = n_layersdef forward(self, src):"""前向传播参数:src: 输入序列返回:outputs: 编码器所有时间步的输出hidden: 最后一个时间步的隐藏状态"""# 词嵌入 + dropoutembedded = self.dropout(self.embedding(src)) # [batch_size, src_len, emb_dim]# GRU处理 (需要将batch维度放在第二位)outputs, hidden = self.rnn(embedded.transpose(0, 1)) # outputs: [src_len, batch_size, hid_dim*2]# 处理双向隐藏状态# hidden的形状是[num_layers * num_directions, batch_size, hid_dim]hidden = hidden.view(self.n_layers, 2, -1, self.rnn.hidden_size) # [n_layers, 2, batch_size, hid_dim]hidden = hidden[-1] # 取最后一层 [2, batch_size, hid_dim]hidden = torch.cat([hidden[0], hidden[1]], dim=1) # 合并双向输出 [batch_size, hid_dim*2]hidden = torch.tanh(self.fc(hidden)) # [batch_size, hid_dim]# 扩展以匹配解码器的层数hidden = hidden.unsqueeze(0).repeat(self.n_layers, 1, 1) # [n_layers, batch_size, hid_dim]return outputs, hiddenclass LuongAttention(nn.Module):def __init__(self, hid_dim, method="general"):super().__init__()self.method = methodif method == "general":self.W = nn.Linear(hid_dim, hid_dim, bias=False)elif method == "concat":self.W = nn.Linear(hid_dim * 2, hid_dim, bias=False)self.v = nn.Linear(hid_dim, 1, bias=False)def forward(self, decoder_hidden, encoder_outputs):""" decoder_hidden: [1, batch_size, hid_dim]encoder_outputs: [src_len, batch_size, hid_dim * 2] (bidirectional)"""if self.method == "dot":# 处理双向输出 - 取前向和后向的平均hid_dim = decoder_hidden.size(-1)encoder_outputs = encoder_outputs.view(encoder_outputs.size(0), encoder_outputs.size(1), 2, hid_dim)encoder_outputs = encoder_outputs.mean(dim=2) # [src_len, batch_size, hid_dim]# 计算点积分数scores = torch.matmul(encoder_outputs.transpose(0, 1), # [batch_size, src_len, hid_dim]decoder_hidden.transpose(0, 1).transpose(1, 2) # [batch_size, hid_dim, 1]).squeeze(2) # [batch_size, src_len]# 添加缩放因子scores = scores / (decoder_hidden.size(-1) ** 0.5)elif self.method == "general":# 对于通用注意力,我们需要对解码器隐藏状态进行投影decoder_hidden_proj = self.W(decoder_hidden) # [1, batch_size, hid_dim]# 处理双向输出hid_dim = decoder_hidden.size(-1)encoder_outputs = encoder_outputs.view(encoder_outputs.size(0), encoder_outputs.size(1), 2, hid_dim)encoder_outputs = encoder_outputs.mean(dim=2) # [src_len, batch_size, hid_dim]scores = torch.matmul(encoder_outputs.transpose(0, 1), # [batch_size, src_len, hid_dim]decoder_hidden_proj.transpose(0, 1).transpose(1, 2) # [batch_size, hid_dim, 1]).squeeze(2) # [batch_size, src_len]elif self.method == "concat":# 对于concat,我们可以使用完整的双向输出decoder_hidden = decoder_hidden.repeat(encoder_outputs.size(0), 1, 1) # [src_len, batch_size, hid_dim]energy = torch.cat((decoder_hidden, encoder_outputs), dim=2) # [src_len, batch_size, hid_dim*3]scores = self.v(torch.tanh(self.W(energy))).squeeze(2).t() # [batch_size, src_len]attn_weights = F.softmax(scores, dim=1)# 对于上下文向量,使用原始双向输出context = torch.bmm(attn_weights.unsqueeze(1), # [batch_size, 1, src_len]encoder_outputs.transpose(0, 1) # [batch_size, src_len, hid_dim*2]).squeeze(1) # [batch_size, hid_dim*2]return context, attn_weightsclass Decoder(nn.Module):"""解码器使用注意力机制生成目标序列"""def __init__(self, output_dim, emb_dim, hid_dim, n_layers=1, dropout=0.1, attn_method="general"):"""初始化参数:output_dim: 输出维度(目标词汇表大小)emb_dim: 词嵌入维度hid_dim: 隐藏层维度n_layers: RNN层数dropout: dropout率attn_method: 注意力计算方法"""super().__init__()self.output_dim = output_dimself.embedding = nn.Embedding(output_dim, emb_dim, padding_idx=0) # 词嵌入层# 单向GRUself.rnn = nn.GRU(emb_dim, hid_dim, n_layers,dropout=dropout if n_layers > 1 else 0)self.attention = LuongAttention(hid_dim, attn_method) # 注意力层# 全连接层(根据注意力方法调整输入维度)self.fc = nn.Linear(hid_dim * 3 if attn_method == "concat" else hid_dim * 2, output_dim)self.dropout = nn.Dropout(dropout)def forward(self, input, hidden, encoder_outputs):"""前向传播参数:input: 当前输入tokenhidden: 当前隐藏状态encoder_outputs: 编码器输出返回:prediction: 预测的下一个tokenhidden: 新的隐藏状态attn_weights: 注意力权重"""input = input.unsqueeze(0) # [1, batch_size]embedded = self.dropout(self.embedding(input)) # [1, batch_size, emb_dim]# GRU处理output, hidden = self.rnn(embedded, hidden) # output: [1, batch_size, hid_dim]# 计算注意力context, attn_weights = self.attention(output, encoder_outputs)# 预测下一个token(拼接RNN输出和上下文向量)prediction = self.fc(torch.cat((output.squeeze(0), context), dim=1))return prediction, hidden, attn_weightsclass Seq2Seq(nn.Module):"""序列到序列模型整合编码器和解码器"""def __init__(self, encoder, decoder, device):"""初始化参数:encoder: 编码器实例decoder: 解码器实例device: 计算设备"""super().__init__()self.encoder = encoderself.decoder = decoder# 确保解码器与编码器层数相同assert decoder.rnn.num_layers == encoder.n_layersself.device = devicedef forward(self, src, trg, teacher_forcing_ratio=0.5):"""前向传播参数:src: 源序列trg: 目标序列(训练时使用)teacher_forcing_ratio: 教师强制比例返回:所有时间步的输出"""batch_size = src.shape[0]trg_len = trg.shape[1]trg_vocab_size = self.decoder.output_dim# 初始化输出张量outputs = torch.zeros(trg_len, batch_size, trg_vocab_size).to(self.device)# 编码器处理encoder_outputs, hidden = self.encoder(src)# 第一个输入是<SOS>标记input = trg[:, 0]# 逐步生成输出序列for t in range(1, trg_len):# 解码器处理output, hidden, _ = self.decoder(input, hidden, encoder_outputs)outputs[t] = output# 决定是否使用教师强制teacher_force = random.random() < teacher_forcing_ratiotop1 = output.argmax(1)input = trg[:, t] if teacher_force else top1return outputs# 训练与评估函数
def train(model, loader, optimizer, criterion, clip):"""训练函数参数:model: 模型loader: 数据加载器optimizer: 优化器criterion: 损失函数clip: 梯度裁剪阈值返回:平均损失"""model.train()epoch_loss = 0for batch in tqdm(loader, desc="Training"):src = batch["src"].to(device)trg = batch["trg"].to(device)optimizer.zero_grad()output = model(src, trg)# 计算损失(忽略第一个token)output = output[1:].reshape(-1, output.shape[-1])trg = trg[:, 1:].reshape(-1)loss = criterion(output, trg)loss.backward()# 梯度裁剪torch.nn.utils.clip_grad_norm_(model.parameters(), clip)optimizer.step()epoch_loss += loss.item()return epoch_loss / len(loader)def evaluate(model, loader, criterion):"""评估函数参数:model: 模型loader: 数据加载器criterion: 损失函数返回:平均损失"""model.eval()epoch_loss = 0with torch.no_grad():for batch in tqdm(loader, desc="Evaluating"):src = batch["src"].to(device)trg = batch["trg"].to(device)# 评估时不使用教师强制output = model(src, trg, teacher_forcing_ratio=0)output = output[1:].reshape(-1, output.shape[-1])trg = trg[:, 1:].reshape(-1)loss = criterion(output, trg)epoch_loss += loss.item()return epoch_loss / len(loader)# 加载数据(使用opus100数据集的中英翻译部分,只取前10000条作为示例)
dataset = load_dataset("./opus100", "en-zh", split="train[:10000]")
# 划分训练集和验证集
train_val = dataset.train_test_split(test_size=0.2)# 构建分词器
def get_text_iter(data, lang="zh"):"""获取文本迭代器用于分词器训练"""for item in data["translation"]:yield item[lang]# 训练中文和英文分词器
zh_tokenizer = build_tokenizer(get_text_iter(train_val["train"]))
en_tokenizer = build_tokenizer(get_text_iter(train_val["train"], "en"))# 创建DataLoader
train_dataset = TranslationDataset(train_val["train"], zh_tokenizer, en_tokenizer)
val_dataset = TranslationDataset(train_val["test"], zh_tokenizer, en_tokenizer)train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=32, collate_fn=collate_fn)# 初始化模型
INPUT_DIM = len(zh_tokenizer.get_vocab()) # 中文词汇表大小
OUTPUT_DIM = len(en_tokenizer.get_vocab()) # 英文词汇表大小
ENC_EMB_DIM = 512 # 编码器词嵌入维度
DEC_EMB_DIM = 512 # 解码器词嵌入维度
HID_DIM = 1024 # 隐藏层维度
N_LAYERS = 3 # RNN层数
DROP_RATE = 0.3 # dropout率# 创建编码器、解码器和seq2seq模型
encoder = Encoder(INPUT_DIM, ENC_EMB_DIM, HID_DIM, N_LAYERS, DROP_RATE)
decoder = Decoder(OUTPUT_DIM, DEC_EMB_DIM, HID_DIM, N_LAYERS, DROP_RATE, "dot")
model = Seq2Seq(encoder, decoder, device).to(device)# 训练配置
optimizer = optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.98), eps=1e-9)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5) # 学习率调度器
criterion = nn.CrossEntropyLoss(ignore_index=0) # 忽略填充标记的损失
CLIP = 5.0 # 梯度裁剪阈值
N_EPOCHS = 20 # 训练轮数# 训练循环
best_valid_loss = float('inf')
for epoch in range(N_EPOCHS):train_loss = train(model, train_loader, optimizer, criterion, CLIP)valid_loss = evaluate(model, val_loader, criterion)# 保存最佳模型if valid_loss < best_valid_loss:best_valid_loss = valid_losstorch.save(model.state_dict(), 'best_model.pt')print(f'Epoch: {epoch + 1:02}')print(f'\tTrain Loss: {train_loss:.3f} | Val. Loss: {valid_loss:.3f}')def translate(model, sentence, src_tokenizer, trg_tokenizer, max_len=50):"""翻译函数参数:model: 训练好的模型sentence: 待翻译的句子src_tokenizer: 源语言分词器trg_tokenizer: 目标语言分词器max_len: 最大生成长度返回:翻译结果"""model.eval()# 中文分词并编码tokens = ["[SOS]"] + src_tokenizer.encode(sentence).tokens + ["[EOS]"]src = torch.tensor([src_tokenizer.token_to_id(t) for t in tokens]).unsqueeze(0).to(device)# 初始化目标序列(以<SOS>开始)trg_indexes = [trg_tokenizer.token_to_id("[SOS]")]# 逐步生成目标序列for i in range(max_len):trg_tensor = torch.tensor(trg_indexes).unsqueeze(0).to(device)with torch.no_grad():output = model(src, trg_tensor)# 获取预测的下一个tokenpred_token = output.argmax(2)[-1].item()trg_indexes.append(pred_token)# 如果遇到<EOS>则停止if pred_token == trg_tokenizer.token_to_id("[EOS]"):break# 将ID转换为tokentrg_tokens = [trg_tokenizer.id_to_token(i) for i in trg_indexes]# 去掉<EOS>和<SOS>并返回return ' '.join(trg_tokens[1:-1])# 测试翻译
test_sentences = ["你好世界","深度学习很有趣","今天天气真好"
]print("\n测试翻译结果:")
for sent in test_sentences:translation = translate(model, sent, zh_tokenizer, en_tokenizer)print(f"中文: {sent} -> 英文: {translation}")
输出为:
Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [03:29<00:00, 1.19it/s]
Evaluating: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 63/63 [00:13<00:00, 4.54it/s]
Epoch: 01Train Loss: 6.443 | Val. Loss: 6.351
Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [03:28<00:00, 1.20it/s]
Evaluating: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 63/63 [00:13<00:00, 4.56it/s]
Epoch: 02Train Loss: 6.315 | Val. Loss: 6.400
Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [03:31<00:00, 1.18it/s]
Evaluating: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 63/63 [00:13<00:00, 4.53it/s]
Epoch: 03Train Loss: 6.307 | Val. Loss: 6.406
Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [03:31<00:00, 1.18it/s]
Evaluating: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 63/63 [00:13<00:00, 4.52it/s]
Epoch: 04Train Loss: 6.303 | Val. Loss: 6.469
Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [03:26<00:00, 1.21it/s]
Evaluating: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 63/63 [00:13<00:00, 4.54it/s]
Epoch: 05Train Loss: 6.304 | Val. Loss: 6.398
Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [03:31<00:00, 1.18it/s]
Evaluating: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 63/63 [00:13<00:00, 4.54it/s]
Epoch: 06Train Loss: 6.298 | Val. Loss: 6.421
Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [03:27<00:00, 1.20it/s]
Evaluating: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 63/63 [00:13<00:00, 4.53it/s]
Epoch: 07Train Loss: 6.298 | Val. Loss: 6.459
Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [03:28<00:00, 1.20it/s]
Evaluating: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 63/63 [00:13<00:00, 4.52it/s]
Epoch: 08Train Loss: 6.291 | Val. Loss: 6.425
Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [03:29<00:00, 1.20it/s]
Evaluating: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 63/63 [00:13<00:00, 4.54it/s]
Epoch: 09Train Loss: 6.293 | Val. Loss: 6.425
Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [03:29<00:00, 1.19it/s]
Evaluating: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 63/63 [00:13<00:00, 4.54it/s]
Epoch: 10Train Loss: 6.293 | Val. Loss: 6.491
Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [03:26<00:00, 1.21it/s]
Evaluating: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 63/63 [00:13<00:00, 4.55it/s]
Epoch: 11Train Loss: 6.294 | Val. Loss: 6.467
Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [03:27<00:00, 1.20it/s]
Evaluating: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 63/63 [00:13<00:00, 4.54it/s]
Epoch: 12Train Loss: 6.295 | Val. Loss: 6.439
Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [03:32<00:00, 1.18it/s]
Evaluating: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 63/63 [00:13<00:00, 4.56it/s]
Epoch: 13Train Loss: 6.293 | Val. Loss: 6.495
Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [03:30<00:00, 1.19it/s]
Evaluating: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 63/63 [00:13<00:00, 4.55it/s]
Epoch: 14Train Loss: 6.296 | Val. Loss: 6.471
Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [03:28<00:00, 1.20it/s]
Evaluating: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 63/63 [00:13<00:00, 4.56it/s]
Epoch: 15Train Loss: 6.300 | Val. Loss: 6.423
Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [03:30<00:00, 1.19it/s]
Evaluating: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 63/63 [00:13<00:00, 4.55it/s]
Epoch: 16Train Loss: 6.298 | Val. Loss: 6.458
Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [03:29<00:00, 1.20it/s]
Evaluating: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 63/63 [00:13<00:00, 4.56it/s]
Epoch: 17Train Loss: 6.303 | Val. Loss: 6.510
Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [03:28<00:00, 1.20it/s]
Evaluating: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 63/63 [00:13<00:00, 4.55it/s]
Epoch: 18Train Loss: 6.305 | Val. Loss: 6.479
Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [03:29<00:00, 1.20it/s]
Evaluating: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 63/63 [00:13<00:00, 4.56it/s]
Epoch: 19Train Loss: 6.309 | Val. Loss: 6.585
Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [03:29<00:00, 1.19it/s]
Evaluating: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 63/63 [00:13<00:00, 4.57it/s]
Epoch: 20Train Loss: 6.310 | Val. Loss: 6.515测试翻译结果:
中文: 你好世界 -> 英文: [PAD]
中文: 深度学习很有趣 -> 英文: [PAD]
中文: 今天天气真好 -> 英文: [PAD]
三、Transformer实现
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
from tokenizers import Tokenizer, models, trainers, pre_tokenizers
import math# 设备配置
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 数据预处理
def build_tokenizer(text_iter, vocab_size=20000):tokenizer = Tokenizer(models.Unigram())tokenizer.pre_tokenizer = pre_tokenizers.Whitespace()trainer = trainers.UnigramTrainer(vocab_size=vocab_size,special_tokens=["[PAD]", "[UNK]", "[SOS]", "[EOS]"], # 确保包含UNKunk_token="[UNK]" # 显式指定UNK token)tokenizer.train_from_iterator(text_iter, trainer)return tokenizerdef get_text_iter(dataset, language="zh"):for item in dataset["translation"]:yield item[language]# 加载数据集
train_data = dataset = load_dataset("opus100", "en-zh")
train_val = train_data["train"].train_test_split(test_size=0.2)# 构建中英文分词器
zh_tokenizer = build_tokenizer(get_text_iter(train_val["train"], "zh"))
en_tokenizer = build_tokenizer(get_text_iter(train_val["train"], "en"))# 词汇表
zh_vocab = zh_tokenizer.get_vocab()
en_vocab = en_tokenizer.get_vocab()class TranslationDataset(Dataset):def __init__(self, data, src_tokenizer, trg_tokenizer, max_len=50):self.data = dataself.src_tokenizer = src_tokenizerself.trg_tokenizer = trg_tokenizerself.max_len = max_len# 获取所有必须的ID(确保不为None)self.src_unk_id = src_tokenizer.token_to_id("[UNK]")self.src_sos_id = src_tokenizer.token_to_id("[SOS]")self.src_eos_id = src_tokenizer.token_to_id("[EOS]")self.trg_unk_id = trg_tokenizer.token_to_id("[UNK]")self.trg_sos_id = trg_tokenizer.token_to_id("[SOS]")self.trg_eos_id = trg_tokenizer.token_to_id("[EOS]")# 验证关键ID是否存在self._validate_token_ids()def _validate_token_ids(self):for name, token_id in [("SRC_UNK", self.src_unk_id),("SRC_SOS", self.src_sos_id),("SRC_EOS", self.src_eos_id),("TRG_UNK", self.trg_unk_id),("TRG_SOS", self.trg_sos_id),("TRG_EOS", self.trg_eos_id)]:if token_id is None:raise ValueError(f"{name} token不存在于词汇表中")def __len__(self):return len(self.data)def __getitem__(self, idx):item = self.data[idx]["translation"]# 中文编码(源语言)src_tokens = self._process_sequence(item["zh"], self.src_tokenizer,self.src_sos_id, self.src_eos_id, self.src_unk_id)# 英文编码(目标语言)trg_tokens = self._process_sequence(item["en"], self.trg_tokenizer,self.trg_sos_id, self.trg_eos_id, self.trg_unk_id)return {"src": torch.tensor(src_tokens),"trg": torch.tensor(trg_tokens)}def _process_sequence(self, text, tokenizer, sos_id, eos_id, unk_id):"""处理单个序列的编码"""encoded = tokenizer.encode(text)tokens = encoded.tokens[:self.max_len - 2] # 保留空间给SOS/EOS# 转换为ID,确保没有None值token_ids = []for t in tokens:token_id = tokenizer.token_to_id(t)token_ids.append(token_id if token_id is not None else unk_id)return [sos_id] + token_ids + [eos_id]def collate_fn(batch):src = [item["src"] for item in batch]trg = [item["trg"] for item in batch]return {"src": nn.utils.rnn.pad_sequence(src, batch_first=True, padding_value=0),"trg": nn.utils.rnn.pad_sequence(trg, batch_first=True, padding_value=0)}# 创建DataLoader
train_dataset = TranslationDataset(train_val["train"], zh_tokenizer, en_tokenizer)
val_dataset = TranslationDataset(train_val["test"], zh_tokenizer, en_tokenizer)BATCH_SIZE = 32
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, collate_fn=collate_fn)# Transformer模型实现
class PositionalEncoding(nn.Module):def __init__(self, d_model, max_len=100):super().__init__()pe = torch.zeros(max_len, d_model)position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))pe[:, 0::2] = torch.sin(position * div_term)pe[:, 1::2] = torch.cos(position * div_term)self.register_buffer('pe', pe.unsqueeze(0))def forward(self, x):return x + self.pe[:, :x.size(1)]class MultiHeadAttention(nn.Module):def __init__(self, d_model, n_head, dropout=0.1):super().__init__()self.n_head = n_headself.d_k = d_model // n_headself.w_q = nn.Linear(d_model, d_model)self.w_k = nn.Linear(d_model, d_model)self.w_v = nn.Linear(d_model, d_model)self.w_o = nn.Linear(d_model, d_model)self.dropout = nn.Dropout(dropout)def forward(self, q, k, v, mask=None):batch_size = q.size(0)# 线性变换并分头q = self.w_q(q).view(batch_size, -1, self.n_head, self.d_k).transpose(1, 2)k = self.w_k(k).view(batch_size, -1, self.n_head, self.d_k).transpose(1, 2)v = self.w_v(v).view(batch_size, -1, self.n_head, self.d_k).transpose(1, 2)# 计算注意力分数scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)if mask is not None:scores = scores.masked_fill(mask == 0, -1e9)attn = F.softmax(scores, dim=-1)attn = self.dropout(attn)# 计算输出output = torch.matmul(attn, v)output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.n_head * self.d_k)return self.w_o(output), attnclass PositionwiseFeedForward(nn.Module):def __init__(self, d_model, d_ff, dropout=0.1):super().__init__()self.w_1 = nn.Linear(d_model, d_ff)self.w_2 = nn.Linear(d_ff, d_model)self.dropout = nn.Dropout(dropout)def forward(self, x):return self.w_2(self.dropout(F.relu(self.w_1(x))))class EncoderLayer(nn.Module):def __init__(self, d_model, n_head, d_ff, dropout=0.1):super().__init__()self.self_attn = MultiHeadAttention(d_model, n_head, dropout)self.ffn = PositionwiseFeedForward(d_model, d_ff, dropout)self.norm1 = nn.LayerNorm(d_model)self.norm2 = nn.LayerNorm(d_model)self.dropout1 = nn.Dropout(dropout)self.dropout2 = nn.Dropout(dropout)def forward(self, x, mask):attn_output, _ = self.self_attn(x, x, x, mask)x = self.norm1(x + self.dropout1(attn_output))ffn_output = self.ffn(x)x = self.norm2(x + self.dropout2(ffn_output))return xclass DecoderLayer(nn.Module):def __init__(self, d_model, n_head, d_ff, dropout=0.1):super().__init__()self.self_attn = MultiHeadAttention(d_model, n_head, dropout)self.cross_attn = MultiHeadAttention(d_model, n_head, dropout)self.ffn = PositionwiseFeedForward(d_model, d_ff, dropout)self.norm1 = nn.LayerNorm(d_model)self.norm2 = nn.LayerNorm(d_model)self.norm3 = nn.LayerNorm(d_model)self.dropout1 = nn.Dropout(dropout)self.dropout2 = nn.Dropout(dropout)self.dropout3 = nn.Dropout(dropout)def forward(self, x, encoder_output, src_mask, tgt_mask):attn_output, _ = self.self_attn(x, x, x, tgt_mask)x = self.norm1(x + self.dropout1(attn_output))attn_output, _ = self.cross_attn(x, encoder_output, encoder_output, src_mask)x = self.norm2(x + self.dropout2(attn_output))ffn_output = self.ffn(x)x = self.norm3(x + self.dropout3(ffn_output))return xclass Encoder(nn.Module):def __init__(self, src_vocab_size, d_model, n_layers, n_head, d_ff, dropout, max_len):super().__init__()self.token_embed = nn.Embedding(src_vocab_size, d_model, padding_idx=0)self.pos_embed = PositionalEncoding(d_model, max_len)self.layers = nn.ModuleList([EncoderLayer(d_model, n_head, d_ff, dropout) for _ in range(n_layers)])self.dropout = nn.Dropout(dropout)def forward(self, src, src_mask):x = self.dropout(self.pos_embed(self.token_embed(src)))for layer in self.layers:x = layer(x, src_mask)return xclass Decoder(nn.Module):def __init__(self, trg_vocab_size, d_model, n_layers, n_head, d_ff, dropout, max_len):super().__init__()self.token_embed = nn.Embedding(trg_vocab_size, d_model, padding_idx=0)self.pos_embed = PositionalEncoding(d_model, max_len)self.layers = nn.ModuleList([DecoderLayer(d_model, n_head, d_ff, dropout) for _ in range(n_layers)])self.fc_out = nn.Linear(d_model, trg_vocab_size)self.dropout = nn.Dropout(dropout)def forward(self, trg, encoder_output, src_mask, tgt_mask):x = self.dropout(self.pos_embed(self.token_embed(trg)))for layer in self.layers:x = layer(x, encoder_output, src_mask, tgt_mask)return self.fc_out(x)class Transformer(nn.Module):def __init__(self, src_vocab_size, trg_vocab_size, d_model=512, n_layers=6,n_head=8, d_ff=2048, dropout=0.1, max_len=100):super().__init__()self.encoder = Encoder(src_vocab_size, d_model, n_layers, n_head, d_ff, dropout, max_len)self.decoder = Decoder(trg_vocab_size, d_model, n_layers, n_head, d_ff, dropout, max_len)self.src_pad_idx = 0self.trg_pad_idx = 0def make_src_mask(self, src):return (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)def make_trg_mask(self, trg):trg_pad_mask = (trg != self.trg_pad_idx).unsqueeze(1).unsqueeze(2)trg_len = trg.shape[1]trg_sub_mask = torch.tril(torch.ones((trg_len, trg_len), device=trg.device)).bool()return trg_pad_mask & trg_sub_maskdef forward(self, src, trg):src_mask = self.make_src_mask(src)tgt_mask = self.make_trg_mask(trg[:, :-1])encoder_output = self.encoder(src, src_mask)output = self.decoder(trg[:, :-1], encoder_output, src_mask, tgt_mask)return output# 训练与评估
def train(model, iterator, optimizer, criterion, clip):model.train()epoch_loss = 0for batch in train_loader:src = batch["src"].to(device)trg = batch["trg"].to(device)optimizer.zero_grad()output = model(src, trg)output_dim = output.shape[-1]output = output.reshape(-1, output_dim)trg = trg[:, 1:].reshape(-1)loss = criterion(output, trg)loss.backward()torch.nn.utils.clip_grad_norm_(model.parameters(), clip)optimizer.step()epoch_loss += loss.item()return epoch_loss / len(iterator)def evaluate(model, iterator, criterion):model.eval()epoch_loss = 0with torch.no_grad():for batch in val_loader:src = batch["src"].to(device)trg = batch["trg"].to(device)output = model(src, trg)output_dim = output.shape[-1]output = output.reshape(-1, output_dim)trg = trg[:, 1:].reshape(-1)loss = criterion(output, trg)epoch_loss += loss.item()return epoch_loss / len(iterator)# 初始化模型
model = Transformer(src_vocab_size=len(zh_vocab),trg_vocab_size=len(en_vocab),d_model=256, # 减小模型尺寸便于快速训练n_layers=3,n_head=4
).to(device)# 训练配置
optimizer = optim.Adam(model.parameters(), lr=0.0001)
criterion = nn.CrossEntropyLoss(ignore_index=0)
CLIP = 1.0
N_EPOCHS = 20# 训练循环
for epoch in range(N_EPOCHS):train_loss = train(model, train_loader, optimizer, criterion, CLIP)valid_loss = evaluate(model, val_loader, criterion)print(f'Epoch: {epoch + 1:02}')print(f'\tTrain Loss: {train_loss:.3f} | Val. Loss: {valid_loss:.3f}')# 翻译测试函数
def translate(model, sentence, src_tokenizer, trg_tokenizer, max_len=50):model.eval()# 编码源语言src_tokens = ["[SOS]"] + src_tokenizer.encode(sentence).tokens + ["[EOS]"]src_ids = [src_tokenizer.token_to_id(t) for t in src_tokens]src = torch.tensor(src_ids).unsqueeze(0).to(device) # [1, src_len]# 初始化目标序列(始终以SOS开头)trg_indexes = [trg_tokenizer.token_to_id("[SOS]")]# 逐步解码for _ in range(max_len):trg_tensor = torch.tensor(trg_indexes).unsqueeze(0).to(device) # [1, trg_len]with torch.no_grad():output = model(src, trg_tensor) # 形状应为 [1, trg_len, vocab_size]# 关键修正:安全获取最后一个预测tokenif output.size(1) == 0: # 处理空输出情况pred_token = trg_tokenizer.token_to_id("[UNK]")else:pred_token = output.argmax(-1)[0, -1].item() # 获取序列最后一个预测trg_indexes.append(pred_token)if pred_token == trg_tokenizer.token_to_id("[EOS]"):break# 转换为文本(跳过SOS和EOS)trg_tokens = []for i in trg_indexes[1:]: # 跳过初始的SOSif i == trg_tokenizer.token_to_id("[EOS]"):breaktrg_tokens.append(trg_tokenizer.id_to_token(i))return ' '.join(trg_tokens)# 测试翻译
test_sentences = ["你好世界","深度学习很有趣","今天天气真好"
]print("\n测试翻译结果:")
for sent in test_sentences:translation = translate(model, sent, zh_tokenizer, en_tokenizer)print(f"中文: {sent} -> 英文: {translation}")
输出为:
Epoch: 01Train Loss: 4.038 | Val. Loss: 3.263
Epoch: 02Train Loss: 3.184 | Val. Loss: 2.786
Epoch: 03Train Loss: 2.833 | Val. Loss: 2.497
Epoch: 04Train Loss: 2.612 | Val. Loss: 2.323
Epoch: 05Train Loss: 2.460 | Val. Loss: 2.205
Epoch: 06Train Loss: 2.352 | Val. Loss: 2.123
Epoch: 07Train Loss: 2.269 | Val. Loss: 2.055
Epoch: 08Train Loss: 2.206 | Val. Loss: 2.009
Epoch: 09Train Loss: 2.154 | Val. Loss: 1.971
Epoch: 10Train Loss: 2.110 | Val. Loss: 1.936
Epoch: 11Train Loss: 2.073 | Val. Loss: 1.910
Epoch: 12Train Loss: 2.041 | Val. Loss: 1.886
Epoch: 13Train Loss: 2.013 | Val. Loss: 1.866
Epoch: 14Train Loss: 1.988 | Val. Loss: 1.847
Epoch: 15Train Loss: 1.966 | Val. Loss: 1.833
Epoch: 16Train Loss: 1.946 | Val. Loss: 1.820
Epoch: 17Train Loss: 1.927 | Val. Loss: 1.805
Epoch: 18Train Loss: 1.910 | Val. Loss: 1.793
Epoch: 19Train Loss: 1.894 | Val. Loss: 1.786
Epoch: 20Train Loss: 1.880 | Val. Loss: 1.773测试翻译结果:
中文: 你好世界 -> 英文: [UNK] Hello THE , FUCK world . .
中文: 深度学习很有趣 -> 英文: [UNK] I of t f . unny
中文: 今天天气真好 -> 英文: [UNK] I weather t today o . n
四、注意力机制变体
1. 稀疏注意力实现
class SparseAttention(nn.Module):def __init__(self, block_size=32):super().__init__()self.block_size = block_sizedef forward(self, q, k, v):batch_size, seq_len, d_model = q.shape# 分块q = q.view(batch_size, -1, self.block_size, d_model)k = k.view(batch_size, -1, self.block_size, d_model)v = v.view(batch_size, -1, self.block_size, d_model)# 计算块内注意力scores = torch.einsum('bind,bjnd->bnij', q, k) / math.sqrt(d_model)attn = F.softmax(scores, dim=-1)output = torch.einsum('bnij,bjnd->bind', attn, v)# 恢复形状output = output.view(batch_size, seq_len, d_model)return output, attn
2. 相对位置编码
class RelativePositionEmbedding(nn.Module):def __init__(self, max_len=512, d_model=512):super().__init__()self.emb = nn.Embedding(2 * max_len - 1, d_model)self.max_len = max_lendef forward(self, q):"""q: [batch_size, seq_len, d_model]"""seq_len = q.size(1)range_vec = torch.arange(seq_len)distance_mat = range_vec[None, :] - range_vec[:, None] # [seq_len, seq_len]distance_mat_clipped = torch.clamp(distance_mat + self.max_len - 1, 0, 2 * self.max_len - 2)position_emb = self.emb(distance_mat_clipped) # [seq_len, seq_len, d_model]return position_emb
五、性能对比与总结
1.注意力模式可视化
def plot_attention(attention, source, target):fig = plt.figure(figsize=(10, 10))ax = fig.add_subplot(111)cax = ax.matshow(attention, cmap='bone')ax.set_xticklabels([''] + source, rotation=90)ax.set_yticklabels([''] + target)plt.show()
2. 关键演进规律
-
信息瓶颈突破:从固定长度上下文到动态注意力分配
-
计算效率提升:从RNN的O(n)序列计算到Transformer的并行化
-
建模能力增强:从局部依赖到全局关系建模
在下一篇文章中,我们将深入探讨归一化技术对比(BN/LN/IN/GN),分析不同归一化方法的特点和适用场景。