当前位置: 首页 > news >正文

T5(Text-to-Text Transfer Transformer)模型

模型介绍

T5(Text-to-Text Transfer Transformer)模型是基于 Transformer 架构的扩展,但在设计理念、任务适配和实现细节上与原始 Transformer 有显著区别。同时,“共享权重” 是 T5 等现代预训练模型中广泛使用的技术,在模型效率和性能迁移中发挥关键作用。
一、T5 与原始 Transformer 的核心区别
原始 Transformer(Vaswani et al., 2017)是基础架构,主要用于机器翻译等 “序列到序列”(seq2seq)任务;而 T5 是在 Transformer 基础上针对通用 NLP 任务设计的预训练模型,核心区别如下:

  1. 任务统一范式:Text-to-Text(文本到文本)
    原始 Transformer:针对特定任务设计输入输出格式(如机器翻译中,输入是源语言序列,输出是目标语言序列),不同任务需要单独调整输入输出结构。
    T5:将所有 NLP 任务统一为 “文本到文本” 格式—— 无论任务类型(分类、翻译、摘要、问答等),输入和输出都是自然语言文本。
    例如:
    文本分类:输入 “情感分析:这部电影太精彩了!”,输出 “正面”;
    机器翻译:输入 “翻译为法语:Hello world”,输出 “Bonjour le monde”;
    问答:输入 “问题:地球的卫星是什么?答案:”,输出 “月球”。
    这种设计让模型可以用同一套架构处理所有任务,简化了迁移学习流程。
  2. 预训练目标:Span Corruption(片段损坏)
    原始 Transformer:无预训练阶段,直接针对特定任务训练(如机器翻译模型在平行语料上端到端训练)。
    T5:通过大规模无标注文本预训练,核心目标是 “Span Corruption”:
    随机选择输入文本中的连续片段(“Span”),用一个特殊符号(如<extra_id_0>)替换,然后让模型预测被替换的片段内容。
    例如:输入 “我 [昨天] 去 [公园] 玩”→ 替换为 “我 < extra_id_0 > 去 < extra_id_1 > 玩”,模型需要预测<extra_id_0>对应 “昨天”,<extra_id_1>对应 “公园”。
    这种目标比 BERT 的 “单 token 掩码” 或 GPT 的 “自回归生成” 更灵活,能更好地学习上下文语义和长距离依赖。
  3. 架构细节差异
    原始 Transformer:标准的 “编码器 - 解码器” 结构,编码器输出上下文向量,解码器基于此生成目标序列;位置编码使用绝对位置编码(正弦余弦函数)。
    T5:
    基础架构仍是编码器 - 解码器,但支持多种变体(如 “仅解码器”“仅编码器”);
    位置编码使用相对位置编码(更适合长序列,避免绝对位置对长文本的限制);
    层归一化(Layer Norm)的位置调整(原始 Transformer 在残差连接后,T5 在残差连接前),训练稳定性更好。
  4. 适用场景
    原始 Transformer:更适合作为 “基础组件” 搭建特定任务模型(如机器翻译、文本生成),需针对任务定制输入输出和训练目标。
    T5:预训练后可通过 “微调” 快速适配几乎所有 NLP 任务,尤其适合需要跨任务迁移的场景(如低资源任务、多任务学习)。
    二、共享权重的作用
    “共享权重” 是指模型中不同模块(如嵌入层、输出层)使用同一套参数,T5 等现代模型广泛采用这一设计,核心作用如下:
  5. 减少参数数量,提升效率
    原始 Transformer 中,输入嵌入层(将 token 映射为向量)和输出投影层(将解码器输出映射到词汇表)是两个独立的线性层,参数数量为2×vocab_size×d_model(vocab_size为词汇表大小,d_model为隐藏层维度)。
    T5 中,嵌入层与输出层共享权重,参数数量减少为vocab_size×d_model,显著降低了模型参数量(尤其当词汇表较大时,如 T5 的词汇表含 32k token),节省内存和计算资源。
  6. 增强输入与输出的语义一致性
    嵌入层将 “输入 token” 映射到语义空间,输出层将 “模型输出的隐藏向量” 映射回 “输出 token”。共享权重强制这两个过程使用同一套 “语义映射规则”,使得输入和输出在语义空间中保持一致。
    例如:输入 token “猫” 的嵌入向量,与模型生成 “猫” 时输出层的权重是关联的,避免了输入和输出语义空间的割裂,有助于提升生成任务的准确性(如翻译、摘要)。
  7. 促进预训练与微调的迁移
    T5 的预训练阶段学习了嵌入层的语义表示,微调时共享权重能让输出层直接复用预训练学到的语义知识,无需重新学习 “隐藏向量→token” 的映射,加速微调收敛,尤其在小数据集上效果更明显。
  8. 避免过拟合
    减少参数数量本质上降低了模型的复杂度,在有限数据上训练时,更不容易过拟合,提升模型的泛化能力。
    总结
    T5 vs 原始 Transformer:T5 是基于 Transformer 的 “通用任务模型”,通过统一的文本到文本范式、更灵活的预训练目标和优化的架构,实现了跨任务的高效迁移;而原始 Transformer 是基础架构,更侧重特定 seq2seq 任务。
    共享权重的作用:核心是通过参数复用提升效率、增强语义一致性、促进知识迁移,是 T5 等模型在大规模预训练中保持高效性和泛化能力的关键设计。

代码示例

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, classification_report
from collections import defaultdict
import redef plot_loss_curve(losses, title="Training Loss Curve"):"""绘制训练损失曲线"""plt.figure(figsize=(10, 6))plt.plot(range(1, len(losses) + 1), losses, marker='o', linestyle='-', color='#1f77b4')plt.xlabel("Epoch")plt.ylabel("Average Loss")plt.title(title)plt.grid(True, alpha=0.3)plt.show()# -------------------------- 1. T5核心模块 --------------------------
class T5PositionalEncoding(nn.Module):"""T5使用的相对位置编码"""# 修正:相对位置偏差应基于注意力头数,而非模型维度def __init__(self, nhead, max_len=5000):super().__init__()self.nhead = nhead  # 每个注意力头有独立的相对位置偏差self.max_len = max_len# 相对位置编码参数:嵌入维度改为注意力头数nheadself.relative_attention_bias = nn.Embedding(2 * max_len - 1, nhead)def forward(self, seq_len_q, seq_len_k, device):"""计算相对位置偏差,返回形状为[seq_len_q, seq_len_k, nhead]"""range_vec_q = torch.arange(seq_len_q, device=device)range_vec_k = torch.arange(seq_len_k, device=device)distance_mat = range_vec_k[None, :] - range_vec_q[:, None]  # [seq_len_q, seq_len_k]distance_mat_clamped = torch.clamp(distance_mat, -self.max_len + 1, self.max_len - 1)final_mat = distance_mat_clamped + self.max_len - 1  # 偏移到非负索引return self.relative_attention_bias(final_mat)class T5FeedForward(nn.Module):"""T5的前馈网络,使用SwiGLU激活函数"""def __init__(self, d_model, ffn_dim, dropout=0.1):super().__init__()self.w1 = nn.Linear(d_model, ffn_dim)self.w2 = nn.Linear(ffn_dim, d_model)self.w3 = nn.Linear(d_model, ffn_dim)self.dropout = nn.Dropout(dropout)self.activation = nn.SiLU()  # SwiGLU中的Sigmoid加权线性单元def forward(self, x):return self.dropout(self.w2(self.activation(self.w1(x)) * self.w3(x)))class T5Attention(nn.Module):"""T5的注意力机制,支持自注意力和交叉注意力"""def __init__(self, d_model, nhead, dropout=0.1, is_cross_attention=False):super().__init__()self.d_model = d_modelself.nhead = nheadself.head_dim = d_model // nheadself.is_cross_attention = is_cross_attention# 确保维度可分assert self.head_dim * nhead == d_model, "d_model must be divisible by nhead"# 线性投影层self.q_proj = nn.Linear(d_model, d_model)self.k_proj = nn.Linear(d_model, d_model)self.v_proj = nn.Linear(d_model, d_model)self.out_proj = nn.Linear(d_model, d_model)# 修正:相对位置编码器使用注意力头数nhead初始化self.pos_encoder = T5PositionalEncoding(nhead)self.dropout = nn.Dropout(dropout)def forward(self, query, key, value, mask=None):batch_size, seq_len_q, _ = query.size()seq_len_k = key.size(1)# 线性投影并分多头q = self.q_proj(query).view(batch_size, seq_len_q, self.nhead, self.head_dim).transpose(1, 2)k = self.k_proj(key).view(batch_size, seq_len_k, self.nhead, self.head_dim).transpose(1, 2)v = self.v_proj(value).view(batch_size, seq_len_k, self.nhead, self.head_dim).transpose(1, 2)# 计算注意力分数attn_scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)# 修正:获取相对位置偏差并调整形状# 1. 生成相对位置偏差 [seq_len_q, seq_len_k, nhead]relative_bias = self.pos_encoder(seq_len_q, seq_len_k, query.device)# 2. 扩展到batch维度 [batch_size, seq_len_q, seq_len_k, nhead]relative_bias = relative_bias.unsqueeze(0).repeat(batch_size, 1, 1, 1)# 3. 调整维度顺序以匹配注意力分数 [batch_size, nhead, seq_len_q, seq_len_k]relative_bias = relative_bias.permute(0, 3, 1, 2)# 添加相对位置偏差(现在形状匹配)attn_scores += relative_bias# 应用掩码if mask is not None:attn_scores = attn_scores.masked_fill(mask == 0, -1e9)# 注意力加权平均attn_probs = F.softmax(attn_scores, dim=-1)attn_probs = self.dropout(attn_probs)output = attn_probs @ v# 合并多头output = output.transpose(1, 2).contiguous().view(batch_size, seq_len_q, self.d_model)output = self.out_proj(output)return output, attn_probs# 以下模块保持不变,但需要确保正确引用上述修改
class T5EncoderLayer(nn.Module):"""T5编码器层"""def __init__(self, d_model, nhead, ffn_dim, dropout=0.1):super().__init__()self.self_attn = T5Attention(d_model, nhead, dropout)self.ffn = T5FeedForward(d_model, ffn_dim, dropout)self.norm1 = nn.LayerNorm(d_model)self.norm2 = nn.LayerNorm(d_model)self.dropout1 = nn.Dropout(dropout)self.dropout2 = nn.Dropout(dropout)def forward(self, src, src_mask=None):# 自注意力子层src2, _ = self.self_attn(src, src, src, src_mask)src = src + self.dropout1(src2)src = self.norm1(src)# 前馈子层src2 = self.ffn(src)src = src + self.dropout2(src2)src = self.norm2(src)return srcclass T5DecoderLayer(nn.Module):"""T5解码器层"""def __init__(self, d_model, nhead, ffn_dim, dropout=0.1):super().__init__()self.self_attn = T5Attention(d_model, nhead, dropout)  # 解码器自注意力self.cross_attn = T5Attention(d_model, nhead, dropout, is_cross_attention=True)  # 编码器-解码器注意力self.ffn = T5FeedForward(d_model, ffn_dim, dropout)self.norm1 = nn.LayerNorm(d_model)self.norm2 = nn.LayerNorm(d_model)self.norm3 = nn.LayerNorm(d_model)self.dropout1 = nn.Dropout(dropout)self.dropout2 = nn.Dropout(dropout)self.dropout3 = nn.Dropout(dropout)def forward(self, tgt, memory, tgt_mask=None, memory_mask=None):# 解码器自注意力tgt2, _ = self.self_attn(tgt, tgt, tgt, tgt_mask)tgt = tgt + self.dropout1(tgt2)tgt = self.norm1(tgt)# 编码器-解码器交叉注意力tgt2, _ = self.cross_attn(tgt, memory, memory, memory_mask)tgt = tgt + self.dropout2(tgt2)tgt = self.norm2(tgt)# 前馈子层tgt2 = self.ffn(tgt)tgt = tgt + self.dropout3(tgt2)tgt = self.norm3(tgt)return tgtclass T5Model(nn.Module):"""T5模型主体"""def __init__(self, vocab_size, d_model=512, nhead=8, num_layers=6,ffn_dim=2048, dropout=0.1, max_seq_len=5000):super().__init__()self.d_model = d_modelself.vocab_size = vocab_size# 嵌入层self.embedding = nn.Embedding(vocab_size, d_model)# 编码器self.encoder_layers = nn.ModuleList([T5EncoderLayer(d_model, nhead, ffn_dim, dropout)for _ in range(num_layers)])# 解码器self.decoder_layers = nn.ModuleList([T5DecoderLayer(d_model, nhead, ffn_dim, dropout)for _ in range(num_layers)])# 输出层(共享嵌入权重)self.output_projection = nn.Linear(d_model, vocab_size, bias=False)self.output_projection.weight = self.embedding.weightself.init_weights()def init_weights(self):"""初始化权重"""for p in self.parameters():if p.dim() > 1:nn.init.xavier_uniform_(p)def encode(self, src, src_mask=None):"""编码器前向传播"""src = self.embedding(src)  # [batch_size, seq_len, d_model]src = src * math.sqrt(self.d_model)  # 缩放嵌入for layer in self.encoder_layers:src = layer(src, src_mask)return srcdef decode(self, tgt, memory, tgt_mask=None, memory_mask=None):"""解码器前向传播"""tgt = self.embedding(tgt)  # [batch_size, seq_len, d_model]tgt = tgt * math.sqrt(self.d_model)  # 缩放嵌入for layer in self.decoder_layers:tgt = layer(tgt, memory, tgt_mask, memory_mask)return tgtdef forward(self, src, tgt, src_mask=None, tgt_mask=None, memory_mask=None):"""完整前向传播"""memory = self.encode(src, src_mask)tgt_output = self.decode(tgt, memory, tgt_mask, memory_mask)logits = self.output_projection(tgt_output)return logits# -------------------------- 2. T5特有的数据集和工具函数 --------------------------
class T5TextToTextDataset(Dataset):"""T5文本到文本数据集,将所有任务统一为输入-输出文本对"""def __init__(self, tokenizer, task_type='classification', num_samples=1000, max_length=50):self.tokenizer = tokenizerself.task_type = task_typeself.max_length = max_lengthself.samples = self._generate_samples(num_samples)def _generate_samples(self, num_samples):"""生成样本,模拟不同任务类型"""samples = []if self.task_type == 'classification':# 分类任务: 输入"分类:文本",输出"类别"categories = ["科技", "教育", "娱乐", "体育", "政治"]tech_texts = ["人工智能是未来的发展方向", "机器学习算法取得新突破", "量子计算研究获得进展"]edu_texts = ["新的教育方法提高了学习效率", "在线教育平台用户数量激增", "教师培训计划取得成效"]ent_texts = ["电影市场迎来新的高峰", "音乐颁奖典礼吸引全球关注", "小说畅销榜出现新面孔"]sport_texts = ["奥运会打破多项世界纪录", "足球联赛冠军诞生", "运动员训练方法革新"]pol_texts = ["新的政策将影响经济发展", "国际会议达成多项共识", "政府发布新的发展规划"]text_category = {"科技": tech_texts,"教育": edu_texts,"娱乐": ent_texts,"体育": sport_texts,"政治": pol_texts}for _ in range(num_samples):category = categories[torch.randint(0, len(categories), (1,)).item()]text = text_category[category][torch.randint(0, len(text_category[category]), (1,)).item()]# 构造T5风格的输入输出input_text = f"分类:{text}"output_text = categorysamples.append((input_text, output_text))elif self.task_type == 'translation':# 翻译任务: 输入"翻译:文本",输出"翻译结果"en_zh_pairs = [("Hello world", "你好世界"),("I love machine learning", "我喜欢机器学习"),("Natural language processing is interesting", "自然语言处理很有趣"),("Artificial intelligence will change the world", "人工智能将改变世界"),("Transformer models are powerful", "Transformer模型很强大")]for _ in range(num_samples):en, zh = en_zh_pairs[torch.randint(0, len(en_zh_pairs), (1,)).item()]input_text = f"翻译:{en}"output_text = zhsamples.append((input_text, output_text))elif self.task_type == 'summarization':# 摘要任务: 输入"摘要:长文本",输出"短摘要"text_summ_pairs = [("自然语言处理是人工智能的一个重要分支,它研究如何使计算机能够理解、解释和生成人类语言。近年来,随着深度学习技术的发展,自然语言处理取得了显著进步。","自然语言处理是人工智能的分支,近年因深度学习而进步显著。"),("Transformer模型是一种基于自注意力机制的神经网络架构,在自然语言处理领域取得了巨大成功。它能够有效捕捉文本中的长距离依赖关系。","Transformer模型基于自注意力机制,在NLP领域很成功。")]for _ in range(num_samples):text, summ = text_summ_pairs[torch.randint(0, len(text_summ_pairs), (1,)).item()]input_text = f"摘要:{text}"output_text = summsamples.append((input_text, output_text))return samplesdef __len__(self):return len(self.samples)def __getitem__(self, idx):input_text, output_text = self.samples[idx]# 编码输入和输出文本input_ids = self.tokenizer.encode(input_text, add_eos=True)target_ids = self.tokenizer.encode(output_text, add_eos=True)# 截断或填充到最大长度input_ids = input_ids[:self.max_length]target_ids = target_ids[:self.max_length]input_ids += [self.tokenizer.special_tokens["<PAD>"]] * (self.max_length - len(input_ids))target_ids += [self.tokenizer.special_tokens["<PAD>"]] * (self.max_length - len(target_ids))return torch.tensor(input_ids), torch.tensor(target_ids)def create_mask(src, tgt, pad_token_id):"""创建注意力掩码"""batch_size, src_len = src.size()tgt_len = tgt.size(1)# 源序列掩码src_mask = (src != pad_token_id).unsqueeze(1).unsqueeze(2)  # [batch, 1, 1, src_len]# 目标序列掩码(防止关注未来位置)tgt_mask = (tgt != pad_token_id).unsqueeze(1).unsqueeze(2)  # [batch, 1, 1, tgt_len]subsequent_mask = torch.tril(torch.ones((tgt_len, tgt_len), device=src.device)).bool()  # [tgt_len, tgt_len]tgt_mask = tgt_mask & subsequent_mask  # [batch, 1, tgt_len, tgt_len]# 编码器-解码器掩码memory_mask = (src != pad_token_id).unsqueeze(1).unsqueeze(2)  # [batch, 1, tgt_len, src_len]return src_mask, tgt_mask, memory_mask# -------------------------- 3. 训练与评估函数 --------------------------
def train_t5(model, train_loader, tokenizer, criterion, optimizer, device, epochs=10):"""训练T5模型"""model.to(device)model.train()train_losses = []pad_token_id = tokenizer.special_tokens["<PAD>"]for epoch in range(epochs):total_loss = 0.0for batch_idx, (src, tgt) in enumerate(train_loader):src = src.to(device)tgt = tgt.to(device)# 构建输入和目标(目标移位一位)tgt_input = tgt[:, :-1]  # 解码器输入:移除最后一个tokentgt_output = tgt[:, 1:]  # 解码器目标:移除第一个token# 创建掩码src_mask, tgt_mask, memory_mask = create_mask(src, tgt_input, pad_token_id)# 前向传播logits = model(src, tgt_input, src_mask, tgt_mask, memory_mask)# 计算损失(忽略PAD token)loss = criterion(logits.reshape(-1, logits.size(-1)), tgt_output.reshape(-1))mask = (tgt_output != pad_token_id).reshape(-1)loss = (loss * mask).sum() / mask.sum()# 反向传播和优化optimizer.zero_grad()loss.backward()optimizer.step()total_loss += loss.item() * src.size(0)# 打印批次信息if (batch_idx + 1) % 10 == 0:print(f"Batch [{batch_idx + 1}/{len(train_loader)}], Loss: {loss.item():.4f}")# 计算平均损失avg_loss = total_loss / len(train_loader.dataset)train_losses.append(avg_loss)print(f"Epoch [{epoch + 1}/{epochs}], Average Loss: {avg_loss:.4f}")return train_lossesdef evaluate_t5(model, val_loader, tokenizer, device):"""评估T5模型"""model.to(device)model.eval()total_loss = 0.0pad_token_id = tokenizer.special_tokens["<PAD>"]criterion = nn.CrossEntropyLoss(reduction="none")all_preds = []all_labels = []with torch.no_grad():for src, tgt in val_loader:src = src.to(device)tgt = tgt.to(device)tgt_input = tgt[:, :-1]tgt_output = tgt[:, 1:]src_mask, tgt_mask, memory_mask = create_mask(src, tgt_input, pad_token_id)logits = model(src, tgt_input, src_mask, tgt_mask, memory_mask)# 计算损失loss = criterion(logits.reshape(-1, logits.size(-1)), tgt_output.reshape(-1))mask = (tgt_output != pad_token_id).reshape(-1)loss = (loss * mask).sum() / mask.sum()total_loss += loss.item() * src.size(0)# 收集预测结果preds = torch.argmax(logits, dim=-1)all_preds.extend(preds.cpu().numpy().flatten())all_labels.extend(tgt_output.cpu().numpy().flatten())avg_loss = total_loss / len(val_loader.dataset)# 过滤PAD token计算准确率valid_indices = [i for i, label in enumerate(all_labels) if label != pad_token_id]valid_preds = [all_preds[i] for i in valid_indices]valid_labels = [all_labels[i] for i in valid_indices]accuracy = accuracy_score(valid_labels, valid_preds)return avg_loss, accuracydef generate_text(model, tokenizer, input_text, max_length=50, device='cpu'):"""使用T5模型生成文本"""model.to(device)model.eval()# 编码输入文本input_ids = tokenizer.encode(input_text, add_eos=True)input_ids = torch.tensor(input_ids).unsqueeze(0).to(device)  # [1, seq_len]# 初始化生成序列(从<BOS>或空开始)generated = [tokenizer.special_tokens["<CLS>"]]  # 使用CLS作为起始tokenfor _ in range(max_length):generated_tensor = torch.tensor(generated).unsqueeze(0).to(device)# 创建掩码src_mask = (input_ids != tokenizer.special_tokens["<PAD>"]).unsqueeze(1).unsqueeze(2)tgt_mask = torch.tril(torch.ones((len(generated), len(generated)), device=device)).bool()tgt_mask = tgt_mask.unsqueeze(0).unsqueeze(0)# 预测下一个tokenwith torch.no_grad():logits = model(input_ids, generated_tensor, src_mask, tgt_mask)# 获取最后一个位置的预测next_token_logits = logits[0, -1, :]next_token_id = torch.argmax(next_token_logits).item()# 添加到生成序列generated.append(next_token_id)# 如果生成了结束token,停止if next_token_id == tokenizer.special_tokens["<EOS>"]:break# 解码生成的文本generated_text = tokenizer.decode(generated, skip_special_tokens=True)return generated_text# -------------------------- 4. 自定义Tokenizer --------------------------
class CustomTokenizer:def __init__(self, vocab=None, max_vocab_size=10000):# 定义特殊tokenself.special_tokens = {"<PAD>": 0,  # 填充token"<UNK>": 1,  # 未知token"<EOS>": 2,  # 结束token"<CLS>": 3,  # 分类token"<SEP>": 4,  # 分隔token# T5风格的特殊token"<extra_id_0>": 5,"<extra_id_1>": 6,"<extra_id_2>": 7}# 停用词token集合self.stop_tokens = {"<STOP>", ".", "!", "?", ","}# 初始化词汇表self.vocab = self.special_tokens.copy() if vocab is None else vocabself.inv_vocab = {v: k for k, v in self.vocab.items()}self.max_vocab_size = max_vocab_sizeself.word_counts = defaultdict(int)def add_vocab(self, text):"""从文本中构建词汇表"""words = self._tokenize_text(text)for word in words:if word in self.stop_tokens:continueself.word_counts[word] += 1# 按词频排序并添加到词汇表sorted_words = sorted(self.word_counts.keys(), key=lambda x: -self.word_counts[x])for word in sorted_words:if word not in self.vocab and len(self.vocab) < self.max_vocab_size:self.vocab[word] = len(self.vocab)self.inv_vocab[len(self.vocab) - 1] = worddef encode(self, text, add_eos=True):"""将文本编码为token ID序列"""tokens = self._tokenize_text(text)token_ids = []for token in tokens:if token in self.stop_tokens:continuetoken_ids.append(self.vocab.get(token, self.special_tokens["<UNK>"]))if add_eos:token_ids.append(self.special_tokens["<EOS>"])return token_idsdef decode(self, token_ids, skip_special_tokens=False):"""将token ID序列解码为文本"""tokens = []for idx in token_ids:token = self.inv_vocab.get(idx, "<UNK>")if skip_special_tokens and token in self.special_tokens:continuetokens.append(token)return " ".join(tokens)def _tokenize_text(self, text):"""基础文本分词"""return re.findall(r"\w+|[^\w\s]", text, re.UNICODE)# -------------------------- 5. 测试T5模型 --------------------------
if __name__ == "__main__":device = torch.device("cuda" if torch.cuda.is_available() else "cpu")print(f"使用设备:{device}")# 1. 初始化Tokenizer并构建词汇表print("初始化Tokenizer...")tokenizer = CustomTokenizer(max_vocab_size=5000)# 生成示例文本用于构建词汇表sample_texts = """自然语言处理 人工智能 机器学习 深度学习 神经网络 分类 科技 教育 娱乐 体育 政治翻译 Hello world 你好世界 machine learning 摘要 Transformer模型 自注意力机制奥运会 电影 音乐 政策 经济 量子计算 在线教育 自然语言处理是人工智能的一个重要分支近年来随着深度学习技术的发展自然语言处理取得了显著进步"""tokenizer.add_vocab(sample_texts)vocab_size = len(tokenizer.vocab)print(f"词汇表大小:{vocab_size}")# 2. 配置参数MODEL_CONFIG = {"vocab_size": vocab_size,"d_model": 256,"nhead": 4,"num_layers": 2,"ffn_dim": 512,"dropout": 0.1,"max_seq_len": 50}TRAIN_CONFIG = {"batch_size": 16,"epochs": 10,"lr": 1e-4}# 3. 创建数据集和数据加载器print("创建数据集...")train_dataset = T5TextToTextDataset(tokenizer, task_type='classification', num_samples=500, max_length=50)val_dataset = T5TextToTextDataset(tokenizer, task_type='classification', num_samples=100, max_length=50)train_loader = DataLoader(dataset=train_dataset,batch_size=TRAIN_CONFIG["batch_size"],shuffle=True,num_workers=0)val_loader = DataLoader(dataset=val_dataset,batch_size=TRAIN_CONFIG["batch_size"],shuffle=False,num_workers=0)# 4. 初始化并训练T5模型print("初始化T5模型...")model = T5Model(**MODEL_CONFIG)# 损失函数和优化器criterion = nn.CrossEntropyLoss(reduction="none")optimizer = torch.optim.Adam(model.parameters(), lr=TRAIN_CONFIG["lr"])# 训练模型print("开始训练T5模型...")train_losses = train_t5(model=model,train_loader=train_loader,tokenizer=tokenizer,criterion=criterion,optimizer=optimizer,device=device,epochs=TRAIN_CONFIG["epochs"])# 绘制损失曲线plot_loss_curve(train_losses, title="T5 Training Loss Curve")# 5. 评估模型print("评估模型...")val_loss, val_acc = evaluate_t5(model, val_loader, tokenizer, device)print(f"验证损失: {val_loss:.4f}, 验证准确率: {val_acc:.4f}")# 6. 测试生成效果print("\n测试文本生成...")test_inputs = ["分类:人工智能是未来的发展方向","分类:新的教育方法提高了学习效率","分类:奥运会打破多项世界纪录"]for input_text in test_inputs:generated_text = generate_text(model, tokenizer, input_text, max_length=50, device=device)print(f"输入: {input_text}")print(f"输出: {generated_text}")print("-" * 50)# 7. 测试其他任务类型print("\n测试翻译任务...")translation_dataset = T5TextToTextDataset(tokenizer, task_type='translation', num_samples=300, max_length=50)translation_loader = DataLoader(dataset=translation_dataset,batch_size=TRAIN_CONFIG["batch_size"],shuffle=True,num_workers=0)# 微调模型用于翻译任务print("微调模型用于翻译任务...")trans_optimizer = torch.optim.Adam(model.parameters(), lr=TRAIN_CONFIG["lr"] / 10)trans_losses = train_t5(model=model,train_loader=translation_loader,tokenizer=tokenizer,criterion=criterion,optimizer=trans_optimizer,device=device,epochs=5)# 测试翻译效果print("\n测试翻译效果...")test_translations = ["翻译:Hello world","翻译:I love machine learning"]for input_text in test_translations:generated_text = generate_text(model, tokenizer, input_text, max_length=50, device=device)print(f"输入: {input_text}")print(f"输出: {generated_text}")print("-" * 50)
http://www.dtcms.com/a/400639.html

相关文章:

  • 贵州新站优化建站网站教程
  • 为网站开发asp.net 制作网站开发
  • 广州中学生网站制作网站建设成功案例
  • 外贸企业网站改版花都网站设计
  • 网站建设分为哪几个步骤作风建设网站首页
  • 提卡网站要怎么做哪个网站可有做投票搭建
  • 众划算网站开发广东网站设计公司电话
  • 网站数据库怎么恢复wdcp上传网站
  • 易联网站建设拍摄宣传片
  • 牧和邻宠物网站建设网站未备案可以上线吗
  • 免费做网站公司陕西做网站公司哪家好
  • 南京市住房与城乡建设局网站做网站制作挣钱吗
  • 网站开发与app开发的区别河南省住房城乡建设门户网站
  • 镇江网站设计制作wordpress企业网站DIY
  • 做外链网站有哪些宋朝网站应该怎么做
  • 广西住房和城乡建设厅培训中心官方网站网站免费源码下载
  • 山西网站建设运营公司无锡网站建设收费
  • 可以做问卷赚钱的网站浙江乐清新闻今天
  • 精品课程网站建设毕业设计论文互联网网站建设月总结
  • 网站建设公司 知道万维科技宣讲家网站 家风建设
  • wordpress 时间调用赣州seo排名
  • 建设网站如何进行网站备案百度网盘搜索引擎入口哪里
  • 网站开发没有完成 需要赔偿多少阿里云建站视频
  • 曼联vs恩波利比分沈阳百度推广优化
  • 歌曲《今天》多曲线三维表达
  • Ollama中的Modelfile文件的编写以及使用
  • 谷歌绘制的网站ui网页设计实训报告
  • 光流 | 基于光流算法的多目标跟踪技术
  • 男和男做的视频网站商城网站入驻系统
  • 设计师常用的图库网站网站维护一年多少费