Model Distillation Demos
1. Definition
- BERT distillation demo: distill a BERT teacher fine-tuned on SST-2 into a smaller DistilBERT student for binary sentiment classification.
- LLM distillation demo: distill an instruction-tuned Qwen2-0.5B-Instruct teacher into a Qwen2-0.5B base student on instruction-following data.
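Both demos optimize a weighted combination of a soft-target (KD) term and a hard-label term. With student logits $z_s$, teacher logits $z_t$, temperature $T$ and weight $\alpha$, the loss implemented below is roughly

$$
\mathcal{L} \;=\; \alpha \, T^{2}\, \mathrm{KL}\!\big(\mathrm{softmax}(z_t / T)\,\big\|\,\mathrm{softmax}(z_s / T)\big) \;+\; (1-\alpha)\,\mathcal{L}_{\mathrm{hard}}
$$

The BERT demo uses $T = 5.0$, $\alpha = 0.7$ and cross-entropy against the SST-2 labels as the hard term; the LLM demo uses $T = 2.0$, the same $\alpha$, the student's own language-modeling loss as the hard term, and omits the $T^{2}$ factor.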
2. Implementation
- BERT distillation demo
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    get_linear_schedule_with_warmup,
)
from torch.optim import AdamW
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np
import random
import torch_directml
# Set random seeds for reproducibility
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

set_seed()

# Dataset: SST-2 sentiment analysis (positive / negative labels)
# 1. Hyperparameter settings
class Config:
    # Model settings
    teacher_model_name = "textattack/bert-base-uncased-SST-2"  # BERT fine-tuned on SST-2 as the teacher
    student_model_name = "distilbert-base-uncased"              # smaller DistilBERT as the student

    # Training parameters
    batch_size = 16
    learning_rate = 5e-5
    num_epochs = 3
    max_length = 128

    # Distillation parameters
    temperature = 5.0  # temperature
    alpha = 0.7        # weight of the KD loss; the CE loss weight is 1 - alpha

    # Device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    device = torch_directml.device()  # override with a DirectML device; drop this line and the torch_directml import if not applicable

config = Config()
print(f"Using device: {config.device}")# 2. 加载和预处理数据
tokenizer = AutoTokenizer.from_pretrained(config.teacher_model_name)

def tokenize_function(examples):
    """Tokenize a batch of examples."""
    return tokenizer(
        examples["sentence"],
        padding="max_length",
        truncation=True,
        max_length=config.max_length,
    )

# Load the SST-2 dataset
print("Loading SST-2 dataset...")
dataset = load_dataset("glue", "sst2")
tokenized_datasets = dataset.map(tokenize_function, batched=True)

# Rename the label column and set the tensor format
tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
tokenized_datasets.set_format("torch", columns=["input_ids", "attention_mask", "labels"])

# Create the data loaders
train_dataloader = DataLoader(tokenized_datasets["train"], shuffle=True, batch_size=config.batch_size)
eval_dataloader = DataLoader(tokenized_datasets["validation"], batch_size=config.batch_size)

print(f"Train dataset size: {len(tokenized_datasets['train'])}")
print(f"Eval dataset size: {len(tokenized_datasets['validation'])}")

# 3. Load the models
print("Loading models...")
teacher_model = AutoModelForSequenceClassification.from_pretrained(
    config.teacher_model_name,
    use_safetensors=True,
)
teacher_model = teacher_model.to(config.device)
teacher_model.eval()  # the teacher model is not trained

student_model = AutoModelForSequenceClassification.from_pretrained(
    config.student_model_name,
    num_labels=2,                  # SST-2 is a binary classification task
    ignore_mismatched_sizes=True,  # silence the warning about the freshly initialized classification head
).to(config.device)

print(f"Teacher model: {config.teacher_model_name}")
print(f"Student model: {config.student_model_name}")
print(f"Number of teacher parameters: {sum(p.numel() for p in teacher_model.parameters()):,}")
print(f"Number of student parameters: {sum(p.numel() for p in student_model.parameters()):,}")# 4. 定义优化器和学习率调度器
optimizer = AdamW(student_model.parameters(), lr=config.learning_rate)
total_steps = len(train_dataloader) * config.num_epochs
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=total_steps,
)

# 5. Define the distillation loss function
def distillation_loss(student_logits, teacher_logits, true_labels, temperature, alpha):
    """Compute the knowledge-distillation loss."""
    # Soften the teacher and student outputs
    soft_teacher = F.softmax(teacher_logits / temperature, dim=-1)
    soft_student = F.log_softmax(student_logits / temperature, dim=-1)
    # KL-divergence loss (knowledge-distillation loss), scaled by T^2
    kd_loss = F.kl_div(soft_student, soft_teacher, reduction="batchmean") * (temperature ** 2)
    # Standard cross-entropy loss on the hard labels
    ce_loss = F.cross_entropy(student_logits, true_labels)
    # Combined loss
    total_loss = alpha * kd_loss + (1 - alpha) * ce_loss
    return total_loss, kd_loss, ce_loss

# 6. Evaluation function
def evaluate_model(model, dataloader, desc="Evaluation"):
    """Evaluate model accuracy on a dataloader."""
    model.eval()
    total_correct = 0
    total_samples = 0
    with torch.no_grad():
        for batch in tqdm(dataloader, desc=desc):
            inputs = {k: v.to(config.device) for k, v in batch.items() if k != 'labels'}
            labels = batch['labels'].to(config.device)
            outputs = model(**inputs)
            predictions = torch.argmax(outputs.logits, dim=-1)
            total_correct += (predictions == labels).sum().item()
            total_samples += labels.size(0)
    accuracy = total_correct / total_samples
    return accuracy

# 7. Baseline before training
print("\n=== 训练前基准测试 ===")
teacher_accuracy = evaluate_model(teacher_model, eval_dataloader, "Teacher Eval")
print(f"教师模型准确率: {teacher_accuracy:.4f}")# 初始学生模型准确率(随机权重)
initial_student_accuracy = evaluate_model(student_model, eval_dataloader, "Student Eval")
print(f"蒸馏前学生模型准确率: {initial_student_accuracy:.4f}")# 8. 训练循环(知识蒸馏)
print("\n=== 开始知识蒸馏训练 ===")
student_model.train()

for epoch in range(config.num_epochs):
    total_loss = 0
    total_kd_loss = 0
    total_ce_loss = 0

    progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{config.num_epochs}")
    for batch in progress_bar:
        # Prepare the data
        input_ids = batch["input_ids"].to(config.device)
        attention_mask = batch["attention_mask"].to(config.device)
        labels = batch["labels"].to(config.device)

        # Zero the gradients
        optimizer.zero_grad()

        # Teacher forward pass (no gradients)
        with torch.no_grad():
            teacher_outputs = teacher_model(input_ids=input_ids, attention_mask=attention_mask)

        # Student forward pass
        student_outputs = student_model(input_ids=input_ids, attention_mask=attention_mask)

        # Compute the distillation loss
        loss, kd_loss, ce_loss = distillation_loss(
            student_outputs.logits,
            teacher_outputs.logits,
            labels,
            config.temperature,
            config.alpha,
        )

        # Backward pass and optimization
        loss.backward()
        optimizer.step()
        scheduler.step()

        # Track the losses
        total_loss += loss.item()
        total_kd_loss += kd_loss.item()
        total_ce_loss += ce_loss.item()

        # Update the progress bar
        progress_bar.set_postfix({
            'Loss': f'{loss.item():.4f}',
            'KD Loss': f'{kd_loss.item():.4f}',
            'CE Loss': f'{ce_loss.item():.4f}',
        })

    # Average losses for this epoch
    avg_loss = total_loss / len(train_dataloader)
    avg_kd_loss = total_kd_loss / len(train_dataloader)
    avg_ce_loss = total_ce_loss / len(train_dataloader)
    print(f"\nEpoch {epoch+1} Summary:")
    print(f"Average Total Loss: {avg_loss:.4f}")
    print(f"Average KD Loss: {avg_kd_loss:.4f}")
    print(f"Average CE Loss: {avg_ce_loss:.4f}")

    # Evaluate on the validation set after each epoch
    student_accuracy = evaluate_model(student_model, eval_dataloader, f"Eval Epoch {epoch+1}")
    print(f"Student accuracy: {student_accuracy:.4f}")

# 9. Final evaluation
print("\n=== 最终评估 ===")
final_student_accuracy = evaluate_model(student_model, eval_dataloader, "Final Eval")
final_teacher_accuracy = evaluate_model(teacher_model, eval_dataloader, "Teacher Eval")

print(f"\n{'='*50}")
print(f"{'模型':<20} {'参数量':<15} {'准确率':<10}")
print(f"{'-'*50}")
print(f"{'教师 (BERT)':<20} {sum(p.numel() for p in teacher_model.parameters()):<15,} {final_teacher_accuracy:.4f}")
print(f"{'学生 (初始)':<20} {'-':<15} {initial_student_accuracy:.4f}")
print(f"{'学生 (蒸馏后)':<20} {sum(p.numel() for p in student_model.parameters()):<15,} {final_student_accuracy:.4f}")
print(f"{'='*50}")# 10. 保存蒸馏后的模型
print("\nSaving distilled model...")
student_model.save_pretrained("./distilled_sst2_model")
tokenizer.save_pretrained("./distilled_sst2_model")
print("Model saved to './distilled_sst2_model'")# 11. 演示推理
def predict_sentiment(text, model, tokenizer):
    """Predict the sentiment of a text with the given model."""
    model.eval()
    inputs = tokenizer(
        text,
        padding="max_length",
        truncation=True,
        max_length=config.max_length,
        return_tensors="pt",
    ).to(config.device)
    with torch.no_grad():
        outputs = model(**inputs)
        probs = F.softmax(outputs.logits, dim=-1)
        prediction = torch.argmax(probs, dim=-1).item()
        confidence = probs[0][prediction].item()
    sentiment = "Positive" if prediction == 1 else "Negative"
    return sentiment, confidence

# Try a few example sentences
test_texts = [
    "This movie is absolutely fantastic!",
    "I hated every minute of this film.",
    "It was okay, nothing special.",
    "A brilliant piece of cinema that everyone should see.",
]

print("\n=== Sentiment prediction demo ===")
for text in test_texts:
    sentiment, confidence = predict_sentiment(text, student_model, tokenizer)
    print(f"Text: {text}")
    print(f"Prediction: {sentiment} (Confidence: {confidence:.4f})")
    print("-" * 60)
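The saved student can also be reused on its own later. A minimal sketch, assuming the ./distilled_sst2_model directory written above (separate script):

import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Reload the distilled student and its tokenizer from the directory saved above
model = AutoModelForSequenceClassification.from_pretrained("./distilled_sst2_model")
tokenizer = AutoTokenizer.from_pretrained("./distilled_sst2_model")
model.eval()

# Single-sentence inference
inputs = tokenizer("A brilliant piece of cinema.", return_tensors="pt", truncation=True)
with torch.no_grad():
    probs = F.softmax(model(**inputs).logits, dim=-1)
print("Positive" if probs.argmax(-1).item() == 1 else "Negative", f"{probs.max().item():.4f}")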
- LLM distillation demo
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
from datasets import load_dataset
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

# 1. Device, hyperparameters, and model names
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")# 超参数
lr = 2e-5       # instruction tuning usually uses a small learning rate
batch_size = 2  # adjust according to available GPU memory
num_epochs = 3
max_length = 512  # maximum sequence length

# Model names - same architecture, different initialization
teacher_model_name = "Qwen/Qwen2-0.5B-Instruct"
student_model_name = "Qwen/Qwen2-0.5B"  # the base model is the starting point for the student

# 2. Load the tokenizer and models
tokenizer = AutoTokenizer.from_pretrained(teacher_model_name, use_fast=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

print("Loading teacher model...")
teacher_model = AutoModelForCausalLM.from_pretrained(
    teacher_model_name,
    torch_dtype=torch.float16,
    device_map="auto",
)
teacher_model.eval()

print("Loading student model...")
student_model = AutoModelForCausalLM.from_pretrained(
    student_model_name,
    torch_dtype=torch.float16 if device.type == 'cuda' else torch.float32,
)
student_model.to(device)
student_model.train()

# 3. Prepare the optimizer
optimizer = torch.optim.AdamW(student_model.parameters(), lr=lr)

# 4. Load a real instruction-tuning dataset
def load_instruction_dataset():
    """Load an instruction-tuning dataset."""
    try:
        # Use an Alpaca-format instruction dataset
        dataset = load_dataset("yahma/alpaca-cleaned", split="train")
        print(f"Loaded dataset with {len(dataset)} examples")
        return dataset
    except Exception:
        # Fallback: a couple of hand-written examples
        print("Using fallback example data")
        return [
            {
                "instruction": "解释机器学习",
                "input": "",
                "output": "机器学习是人工智能的一个分支,专注于开发能够从数据中学习并做出预测的算法。"
            },
            {
                "instruction": "写一首关于春天的诗",
                "input": "",
                "output": "春风拂面花香浓,万物复苏生机旺。蝴蝶翩翩舞花间,春天美景入画中。"
            },
        ] * 20  # repeat to create more samples

# 5. Data preprocessing functions
def format_instruction(example):
    """Format an instruction example as a single prompt string."""
    if example['input']:
        text = (f"### Instruction:\n{example['instruction']}\n\n"
                f"### Input:\n{example['input']}\n\n"
                f"### Response:\n{example['output']}")
    else:
        text = (f"### Instruction:\n{example['instruction']}\n\n"
                f"### Response:\n{example['output']}")
    return {"text": text}

def collate_fn(batch):
    """Tokenize and collate a list of examples into a padded batch."""
    texts = [item['text'] for item in batch]
    inputs = tokenizer(
        texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_length,
    )
    # Create labels (the loss on the instruction part will be ignored)
    labels = inputs['input_ids'].clone()
    # Locate the "### Response:" marker and keep the loss only from there on
    # (approximation: only the first token of the marker is matched)
    response_token = tokenizer.encode("### Response:", add_special_tokens=False)[0]
    response_positions = (inputs['input_ids'] == response_token).nonzero()
    # Build a mask that marks the response part
    labels_mask = torch.zeros_like(labels)
    for pos in response_positions:
        batch_idx, start_idx = pos[0], pos[1]
        labels_mask[batch_idx, start_idx:] = 1
    # Set labels outside the response part to -100 (ignored by the loss)
    labels[labels_mask == 0] = -100
    return {
        'input_ids': inputs['input_ids'].to(device),
        'attention_mask': inputs['attention_mask'].to(device),
        'labels': labels.to(device),
    }

# 6. Prepare the data loader
dataset = load_instruction_dataset()
formatted_dataset = (dataset.map(format_instruction)
                     if hasattr(dataset, 'map')
                     else [format_instruction(ex) for ex in dataset])

# Build a simple batching generator (indexing per example so it works for both
# a Hugging Face Dataset and a plain list)
def data_loader(dataset, batch_size=2):
    for i in range(0, len(dataset), batch_size):
        batch = [dataset[j] for j in range(i, min(i + batch_size, len(dataset)))]
        yield collate_fn(batch)

# 7. Distillation training loop
for epoch in range(num_epochs):
    epoch_loss = 0.0
    num_batches = len(formatted_dataset) // batch_size
    progress_bar = tqdm(
        data_loader(formatted_dataset, batch_size),
        desc=f"Epoch {epoch+1}/{num_epochs}",
        total=num_batches,
    )
    for batch_idx, batch in enumerate(progress_bar):
        # Student forward pass (also yields the hard-label LM loss)
        student_outputs = student_model(
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask'],
            labels=batch['labels'],
        )
        student_loss = student_outputs.loss

        # Teacher forward pass (no gradients)
        with torch.no_grad():
            teacher_outputs = teacher_model(
                input_ids=batch['input_ids'],
                attention_mask=batch['attention_mask'],
            )
            teacher_logits = teacher_outputs.logits

        # Distillation loss, computed only on the response part
        student_logits = student_outputs.logits
        response_mask = (batch['labels'] != -100).float()
        # Token-level KL divergence at temperature 2.0
        distillation_loss = F.kl_div(
            F.log_softmax(student_logits / 2.0, dim=-1),
            F.softmax(teacher_logits / 2.0, dim=-1),
            reduction='none',
        )
        distillation_loss = (distillation_loss * response_mask.unsqueeze(-1)).sum() / response_mask.sum()

        # Total loss, weighted towards the distillation term
        total_loss = 0.7 * distillation_loss + 0.3 * student_loss

        # Backward pass and optimization
        optimizer.zero_grad()
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(student_model.parameters(), 1.0)  # gradient clipping
        optimizer.step()

        epoch_loss += total_loss.item()
        progress_bar.set_postfix({"loss": f"{total_loss.item():.4f}"})

        if batch_idx >= num_batches:  # limit the number of batches per epoch
            break

    avg_loss = epoch_loss / num_batches
    print(f"Epoch {epoch+1} completed. Average Loss: {avg_loss:.4f}")

# 8. Save the fine-tuned model
output_dir = "./qwen2_instruction_tuned"
student_model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
print(f"Instruction-tuned model saved to {output_dir}")# 9. 测试指令遵循能力
student_model.eval()
test_instructions = [
    "解释人工智能的基本概念",
    "写一个关于友谊的短故事",
    "如何学习Python编程?",
]

for instruction in test_instructions:
    prompt = f"### Instruction:\n{instruction}\n\n### Response:\n"
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = student_model.generate(
            **inputs,
            max_new_tokens=100,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            pad_token_id=tokenizer.eos_token_id,
        )
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print(f"\nInstruction: {instruction}")
    print(f"Response: {generated_text[len(prompt):]}")
    print("-" * 50)
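For a qualitative check of the distillation, the same prompts can also be run through the teacher and compared side by side. A minimal sketch reusing the objects defined above (greedy decoding is an arbitrary choice here):

# Optional: side-by-side comparison with the teacher on the same prompts
for instruction in test_instructions:
    prompt = f"### Instruction:\n{instruction}\n\n### Response:\n"
    inputs = tokenizer(prompt, return_tensors="pt").to(teacher_model.device)
    with torch.no_grad():
        teacher_out = teacher_model.generate(
            **inputs,
            max_new_tokens=100,
            do_sample=False,  # greedy decoding for a stable reference
            pad_token_id=tokenizer.eos_token_id,
        )
    teacher_text = tokenizer.decode(teacher_out[0], skip_special_tokens=True)
    print(f"\n[Teacher] {instruction}\n{teacher_text[len(prompt):]}")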