当前位置: 首页 > news >正文

通俗易懂的知识蒸馏

1️⃣ 背景介绍

大模型性能好,但是参数量大;而直接训练小模型(只有硬标签),性能往往不好;

因此提出知识蒸馏,其损失函数包括两部分:

  • 蒸馏损失:让小模型学习大模型的“软标签”【softmax输出后的结果,表示类别之间的相似性】,让小模型学习到类别之间的相似性,而不是只知道正确答案
  • 学生损失:硬标签,让小模型知道正确答案

2️⃣ 核心概念

1. 软标签 vs 硬标签

硬标签(Hard Labels)

图片是数字"3" → [0, 0, 0, 1, 0, 0, 0, 0, 0, 0]
  • 只告诉模型正确答案是什么
  • 没有额外信息

软标签(Soft Labels)

图片是数字"3" → [0.01, 0.01, 0.02, 0.90, 0.01, 0.01, 0.01, 0.01, 0.02, 0.01]
  • 告诉模型这主要是"3"
  • 但也有一点像"2"和"8"
  • 包含了类别之间的相似性信息!

2. 温度参数(Temperature)

温度参数T用于控制概率分布的"软化"程度:

# 正常的softmax
p_i = exp(z_i) / Σ exp(z_j)# 带温度的softmax
p_i = exp(z_i/T) / Σ exp(z_j/T)
  • T = 1:正常的概率分布
  • T > 1(如T=3):小概率类别的信息被放大,大概率类别的信息被缩小,概率分布更平滑,
  • T → ∞:接近均匀分布

为什么需要温度?

  • 正常的softmax输出往往是 [0.001, 0.997, 0.002, …],小概率类别的信息被"压制"
  • 提高温度后变成 [0.05, 0.85, 0.10, …],学生模型能学到更多信息

3. 蒸馏损失函数

总损失 = α × 蒸馏损失 + (1-α) × 学生损失# 蒸馏损失:让学生输出接近教师
distillation_loss = KL_divergence(student_soft, teacher_soft)# 学生损失:让学生能正确分类
student_loss = CrossEntropy(student_output, true_labels)
  • α:蒸馏损失的权重(通常0.5-0.9)
  • 1-α:学生损失的权重

3️⃣ 代码

"""
知识蒸馏 (Knowledge Distillation) 示例知识蒸馏是什么?
----------------
知识蒸馏是一种模型压缩技术,通过让一个小模型(学生模型)学习一个大模型(教师模型)的知识,
使得小模型能够达到接近大模型的性能,同时保持更小的模型尺寸和更快的推理速度。核心思想:
1. 教师模型(Teacher):一个已经训练好的大型、高性能模型
2. 学生模型(Student):一个较小的模型,我们希望它学习教师模型的知识
3. 软标签(Soft Labels):教师模型输出的概率分布,包含了类别之间的相似性信息
4. 温度(Temperature):用于软化概率分布,使得学生模型能学到更多信息蒸馏的优势:
- 软标签包含了比硬标签(one-hot)更多的信息
- 例如:对于手写数字识别,教师模型可能对"3"给出 [0.9, 0, 0.08, 0.02, ...]这告诉学生模型:"这主要是3,但看起来有点像8和2"
"""import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import numpy as np# ==================== 1. 定义教师模型(大模型)====================
class TeacherModel(nn.Module):"""教师模型 - 一个较大的深度神经网络包含更多的层数和参数,性能更好但推理速度慢"""def __init__(self):super(TeacherModel, self).__init__()self.conv1 = nn.Conv2d(1, 32, 3, 1)self.conv2 = nn.Conv2d(32, 64, 3, 1)self.dropout1 = nn.Dropout2d(0.25)self.dropout2 = nn.Dropout2d(0.5)self.fc1 = nn.Linear(9216, 128)self.fc2 = nn.Linear(128, 10)def forward(self, x):x = self.conv1(x)x = F.relu(x)x = self.conv2(x)x = F.relu(x)x = F.max_pool2d(x, 2)x = self.dropout1(x)x = torch.flatten(x, 1)x = self.fc1(x)x = F.relu(x)x = self.dropout2(x)x = self.fc2(x)return x# ==================== 2. 定义学生模型(小模型)====================
class StudentModel(nn.Module):"""学生模型 - 一个较小的神经网络参数量少,推理速度快,但需要通过蒸馏来提升性能"""def __init__(self):super(StudentModel, self).__init__()self.conv1 = nn.Conv2d(1, 16, 3, 1)self.fc1 = nn.Linear(2704, 10)def forward(self, x):x = self.conv1(x)x = F.relu(x)x = F.max_pool2d(x, 2)x = torch.flatten(x, 1)x = self.fc1(x)return x# ==================== 3. 蒸馏损失函数 ====================
class DistillationLoss(nn.Module):"""知识蒸馏损失函数总损失 = α * 蒸馏损失 + (1-α) * 学生损失参数:temperature: 温度参数T,用于软化概率分布T越大,概率分布越平滑,学生模型能学到更多类别间的相似性alpha: 蒸馏损失的权重"""def __init__(self, temperature=3.0, alpha=0.7):super(DistillationLoss, self).__init__()self.temperature = temperatureself.alpha = alphaself.kl_div = nn.KLDivLoss(reduction='batchmean')self.ce_loss = nn.CrossEntropyLoss()def forward(self, student_logits, teacher_logits, labels):"""计算蒸馏损失Args:student_logits: 学生模型的输出(未经softmax)teacher_logits: 教师模型的输出(未经softmax)labels: 真实标签Returns:total_loss: 总损失"""# 1. 蒸馏损失:让学生模型的输出分布接近教师模型# 使用温度T来软化概率分布soft_teacher = F.softmax(teacher_logits / self.temperature, dim=1)soft_student = F.log_softmax(student_logits / self.temperature, dim=1)# KL散度衡量两个分布的差异distillation_loss = self.kl_div(soft_student, soft_teacher) * (self.temperature ** 2)# 2. 学生损失:让学生模型能够正确分类(传统的交叉熵损失)student_loss = self.ce_loss(student_logits, labels)# 3. 总损失:加权组合total_loss = self.alpha * distillation_loss + (1 - self.alpha) * student_lossreturn total_loss, distillation_loss, student_loss# ==================== 4. 训练函数 ====================
def train_teacher(model, device, train_loader, optimizer, epoch):"""训练教师模型"""model.train()for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)optimizer.zero_grad()output = model(data)loss = F.cross_entropy(output, target)loss.backward()optimizer.step()if batch_idx % 100 == 0:print(f'Teacher Training Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} 'f'({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}')def train_student_with_distillation(student, teacher, device, train_loader, optimizer, criterion, epoch):"""使用知识蒸馏训练学生模型"""student.train()teacher.eval()  # 教师模型保持评估模式for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)optimizer.zero_grad()# 获取学生模型输出student_output = student(data)# 获取教师模型输出(不计算梯度)with torch.no_grad():teacher_output = teacher(data)# 计算蒸馏损失total_loss, distill_loss, student_loss = criterion(student_output, teacher_output, target)total_loss.backward()optimizer.step()if batch_idx % 100 == 0:print(f'Student Training Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} 'f'({100. * batch_idx / len(train_loader):.0f}%)]\t'f'Total Loss: {total_loss.item():.4f} | 'f'Distill Loss: {distill_loss.item():.4f} | 'f'Student Loss: {student_loss.item():.4f}')def train_student_without_distillation(model, device, train_loader, optimizer, epoch):"""不使用蒸馏,直接训练学生模型(对比基准)"""model.train()for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)optimizer.zero_grad()output = model(data)loss = F.cross_entropy(output, target)loss.backward()optimizer.step()if batch_idx % 100 == 0:print(f'Baseline Student Training Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} 'f'({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}')def test(model, device, test_loader, model_name="Model"):"""测试模型性能"""model.eval()test_loss = 0correct = 0with torch.no_grad():for data, target in test_loader:data, target = data.to(device), target.to(device)output = model(data)test_loss += F.cross_entropy(output, target, reduction='sum').item()pred = output.argmax(dim=1, keepdim=True)correct += pred.eq(target.view_as(pred)).sum().item()test_loss /= len(test_loader.dataset)accuracy = 100. * correct / len(test_loader.dataset)print(f'\n{model_name} Test set: Average loss: {test_loss:.4f}, 'f'Accuracy: {correct}/{len(test_loader.dataset)} ({accuracy:.2f}%)\n')return accuracy# ==================== 5. 主函数 ====================
def main():print("=" * 70)print("知识蒸馏 (Knowledge Distillation) 演示")print("=" * 70)print("\n本示例将演示:")print("1. 训练一个大的教师模型")print("2. 使用知识蒸馏训练一个小的学生模型")print("3. 对比:不使用蒸馏直接训练小模型")print("4. 比较三个模型的性能\n")# 设置设备device = torch.device("cuda" if torch.cuda.is_available() else "cpu")print(f"使用设备: {device}\n")# 数据加载transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])# 使用MNIST数据集train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)test_dataset = datasets.MNIST('./data', train=False, transform=transform)# 为了演示,只使用部分数据train_dataset = torch.utils.data.Subset(train_dataset, range(10000))test_dataset = torch.utils.data.Subset(test_dataset, range(2000))train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)# ==================== 步骤1: 训练教师模型 ====================print("=" * 70)print("步骤 1: 训练教师模型(大模型)")print("=" * 70)teacher = TeacherModel().to(device)teacher_optimizer = optim.Adam(teacher.parameters(), lr=0.001)teacher_epochs = 3for epoch in range(1, teacher_epochs + 1):train_teacher(teacher, device, train_loader, teacher_optimizer, epoch)teacher_acc = test(teacher, device, test_loader, "教师模型")# ==================== 步骤2: 使用蒸馏训练学生模型 ====================print("=" * 70)print("步骤 2: 使用知识蒸馏训练学生模型(小模型)")print("=" * 70)student_distilled = StudentModel().to(device)student_optimizer = optim.Adam(student_distilled.parameters(), lr=0.001)distillation_criterion = DistillationLoss(temperature=3.0, alpha=0.7)student_epochs = 3for epoch in range(1, student_epochs + 1):train_student_with_distillation(student_distilled, teacher, device, train_loader, student_optimizer, distillation_criterion, epoch)student_distilled_acc = test(student_distilled, device, test_loader, "学生模型(蒸馏)")# ==================== 步骤3: 不使用蒸馏训练学生模型(对比基准)====================print("=" * 70)print("步骤 3: 不使用蒸馏,直接训练学生模型(对比基准)")print("=" * 70)student_baseline = StudentModel().to(device)baseline_optimizer = optim.Adam(student_baseline.parameters(), lr=0.001)for epoch in range(1, student_epochs + 1):train_student_without_distillation(student_baseline, device, train_loader, baseline_optimizer, epoch)student_baseline_acc = test(student_baseline, device, test_loader, "学生模型(无蒸馏)")# ==================== 结果对比 ====================print("=" * 70)print("最终结果对比")print("=" * 70)print(f"教师模型(大模型)准确率:        {teacher_acc:.2f}%")print(f"学生模型(知识蒸馏)准确率:      {student_distilled_acc:.2f}%")print(f"学生模型(无蒸馏-基准)准确率:   {student_baseline_acc:.2f}%")print(f"\n知识蒸馏提升:                    {student_distilled_acc - student_baseline_acc:.2f}%")print("=" * 70)# 统计模型参数量def count_parameters(model):return sum(p.numel() for p in model.parameters())print("\n模型参数量对比:")print(f"教师模型参数量: {count_parameters(teacher):,}")print(f"学生模型参数量: {count_parameters(student_distilled):,}")print(f"参数压缩比: {count_parameters(teacher) / count_parameters(student_distilled):.2f}x")print("=" * 70)print("\n总结:")print("知识蒸馏使得小模型在保持轻量级的同时,性能接近大模型!")print("这就是知识蒸馏的魔力 - 让小模型学习大模型的'知识'而不是简单地缩小网络。")if __name__ == '__main__':main()

KL散度

对于两个离散概率分布 P(目标分布)和 Q(预测分布),KL 散度定义为:KL(P∥Q)=∑iP(i)⋅(log⁡P(i)−log⁡Q(i))\text{KL}(P \parallel Q) = \sum_{i} P(i) \cdot \left( \log P(i) - \log Q(i) \right)KL(PQ)=iP(i)(logP(i)logQ(i))也可写成:KL(P∥Q)=∑iP(i)⋅log⁡(P(i)Q(i))\text{KL}(P \parallel Q) = \sum_{i} P(i) \cdot \log \left( \frac{P(i)}{Q(i)} \right)KL(PQ)=iP(i)log(Q(i)P(i))含义:量化,用分布 Q 近似 P 时的信息损失,值越小表示两个分布越接近

PyTorch 中F.kl_div的计算公式:

output=∑itarget(i)⋅(log⁡(target(i))−input(i))\text{output} = \sum_{i} \text{target}(i) \cdot \left( \log(\text{target}(i)) - \text{input}(i) \right)output=itarget(i)(log(target(i))input(i))

对应上面的KL散度的式子,可以知道target为目标分布 P(i)P(i)P(i);input为 logQ(i)log Q(i)logQ(i)

这也解释了为什么计算软标签的时候,学生的软标签是log概率分布,而教师的软标签直接就是概率分布
soft_teacher = F.softmax(teacher_logits / self.temperature, dim=1)
soft_student = F.log_softmax(student_logits / self.temperature, dim=1)
distillation_loss = self.kl_div(soft_student, soft_teacher) * (self.temperature ** 2)
注意: 因为我们对 logits 除以了 T,这会导致梯度也缩小 T 倍。为了补偿这个效果,需要乘以 T²

交叉熵损失函数(Cross-Entropy Loss):

假设任务C个类别:

真实标签的概率分布为y\boldsymbol{y}yy\boldsymbol{y}y用独热编码表示,只有真实类别对应的位置才是1,其余地方都是0;例如,真实类别为第k类时,y=[0,0,...,1,...,0]\boldsymbol{y} = [0, 0, ..., 1, ..., 0]y=[0,0,...,1,...,0](第 k 位为 1)

模型预测的概率分布为y^\boldsymbol{\hat{y}}y^(通过 softmax 函数输出,满足∑i=1Cy^i=1\sum_{i=1}^C \hat{y}_i = 1i=1Cy^i=1),其中y^i\hat{y}_iy^i是预测为第i类的概率。

则多分类交叉熵损失的计算公式为:Loss=−∑i=1Cyi⋅log⁡(y^i)\text{Loss} = -\sum_{i=1}^C y_i \cdot \log(\hat{y}_i)Loss=i=1Cyilog(y^i)

由于真实标签是独热编码(仅真实类别 k 对应的 yk=1y_k = 1yk=1,其余 yi=0y_i = 0yi=0),上述公式可简化为:Loss=−log⁡(y^k)\text{Loss} = -\log(\hat{y}_k)Loss=log(y^k)其中,y^k\hat{y}_ky^k 是模型对真实类别的预测概率

而在python中,nn.CrossEntropyLoss 直接接收模型输出的原始 logits(未经过 softmax 的分数) 和真实标签,内部自动完成上述交叉熵的计算。

这也解释了为什么计算硬标签的时候,直接输入学生的logits
student_loss = self.ce_loss(student_logits, labels)


4️⃣ 知识点

知识蒸馏的精髓在于:

“教师不仅告诉学生答案,更重要的是教会学生如何思考”

通过软标签,学生模型能够学习到:

  • ✅ 类别之间的相似性
  • ✅ 决策边界的平滑性
  • ✅ 特征的重要性分布
  • ✅ 模型的"暗知识"(Dark Knowledge)

这就是为什么蒸馏后的小模型能够超越直接训练的小模型!


http://www.dtcms.com/a/558617.html

相关文章:

  • 免费发布信息网站有哪些建电子商城网站
  • 10.string(下)
  • 广东省建设监理协会信息管理网站wordpress 作者简介
  • tv电视盒子企业网站模板外贸网站的特色
  • 中国石油大学网页设计与网站建设免费做字体的网站
  • 解码LVGL基础
  • 延庆长沙网站建设综合服务门户网站建设
  • AOI在风电行业制造领域中的应用
  • 保健品网站dede模板网站制作咨询公司
  • oracl19c创建不带C##用户
  • 公司做网站如何跟客户介绍wordpress适合做商城吗
  • 商用网站开发计划书wordpress 技巧
  • 广州做网站制作网站建设笔记
  • 手机网站欢迎页面设计东莞建网站哪家强
  • 做网站需要多少钱卖片可以吗网站图片被盗连怎么办啊
  • 网站域名快速备案外贸网站建设公司
  • 网站建站域名解析最后做wordpress淘宝插件
  • 做自己的网站有什么用襄樊最好网站建设价格
  • DSBridge:在原生 WebView 中实现企业级 H5 ↔ Native 通信(支持异步 / 多次回调 / 命名空间)
  • win7 网站配置缅甸新闻最新消息
  • 学习FreeRTOS(软件定时器)
  • 网站索引量下降天津网站建设优化企业
  • 招牌做的好的网站上海app开发定制
  • 重庆建站公司价钱护肤品网站优化案例
  • Prometheus实战教程 03 - 主机监控
  • 建行网站查询密码是什么东西江门关键词优化公司
  • 开源企业网站建设系统中小企业有哪些公司
  • 金融投资网站方案精湛的中山网站建设
  • 人工智能训练师——2.1.1题解
  • 大学生网页设计与制作模板seo顾问服务福建