当前位置: 首页 > news >正文

【系列07】端侧AI:构建与部署高效的本地化AI模型 第6章:知识蒸馏(Knowledge Distillation

第6章:知识蒸馏(Knowledge Distillation)

在构建端侧AI模型时,我们常常面临一个两难的局面:一方面需要大模型的强大性能,另一方面又必须满足端侧设备对模型体积和计算效率的要求。知识蒸馏是一种优雅的解决方案,它允许我们用一个大型的、性能优越的“教师模型”来指导一个小型、高效的“学生模型”的学习,从而让学生模型在保持轻量化的同时,获得接近教师模型的性能。


什么是知识蒸馏?

知识蒸馏的核心思想是转移知识。它不是简单地让学生模型去学习标注好的“硬标签”(hard labels),而是让它去学习教师模型的“软标签”(soft labels)。

  • 硬标签:指数据集中明确的类别标签,例如一张图片是“猫”或“狗”。学生模型的目标是尽可能地预测出正确的硬标签。
  • 软标签:指教师模型对每个类别的预测概率分布。例如,教师模型不仅会预测图片是“猫”,还会给出“狗”的概率是0.05,“老虎”的概率是0.02。这个概率分布包含了比单一硬标签更丰富的知识,因为它体现了不同类别之间的相似性和关系。

知识蒸馏的原理就是通过损失函数让学生模型的预测概率分布尽可能地接近教师模型的预测概率分布。


如何用大模型(教师模型)指导小模型(学生模型)的学习

知识蒸馏的训练过程可以概括为以下步骤:

  1. 选择教师模型:首先,你需要一个已经训练好的、性能强大的模型,作为你的教师。这个模型通常非常大,不适合直接部署。
  2. 选择学生模型:然后,你需要一个更小、更简单的模型,它将作为你的学生。这个模型需要有足够的容量来学习教师的知识。
  3. 构建训练流程:在训练阶段,你需要同时运行教师模型和学生模型。
    • 将同一批数据输入给教师模型,得到其预测的软标签(概率分布)。
    • 将同一批数据输入给学生模型,得到其预测的概率分布
  4. 计算损失函数:知识蒸馏的损失函数通常由两部分组成:
    • 蒸馏损失(Distillation Loss):用于衡量学生模型的概率分布与教师模型的软标签之间的差异。通常使用KL散度(Kullback-Leibler divergence)来计算。
    • 学生损失(Student Loss):用于衡量学生模型与真实硬标签之间的差异。通常使用交叉熵损失。
  5. 联合优化:通过联合优化这两个损失函数,学生模型不仅学习了硬标签,还从教师模型那里“继承”了更深层次的模式和知识。

实践:构建一个学生网络,并用一个预训练好的教师模型进行蒸馏

下面是一个使用PyTorch进行知识蒸馏的简化代码示例。我们将使用一个预训练的ResNet18作为教师,并构建一个更简单的网络作为学生。

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import models# 1. 定义教师模型 (使用预训练的ResNet18)
teacher_model = models.resnet18(pretrained=True)
teacher_model.eval() # 确保教师模型处于评估模式# 2. 定义学生模型 (一个简单的全连接网络)
class StudentNet(nn.Module):def __init__(self):super().__init__()self.fc1 = nn.Linear(1000, 500)self.relu = nn.ReLU()self.fc2 = nn.Linear(500, 100)def forward(self, x):x = self.relu(self.fc1(x))x = self.fc2(x)return xstudent_model = StudentNet()# 3. 定义损失函数
# 这里我们用两个损失函数,一个用于蒸馏,一个用于学生自己的学习
distillation_loss = nn.KLDivLoss(reduction="batchmean")
student_loss = nn.CrossEntropyLoss()# 4. 定义优化器
optimizer = optim.Adam(student_model.parameters(), lr=0.001)# 5. 训练循环 (简化版)
# 假设我们有一个dataloder
# for inputs, labels in dataloader:
#     # 将数据输入教师模型
#     with torch.no_grad():
#         teacher_outputs = teacher_model(inputs)#     # 将数据输入学生模型
#     student_outputs = student_model(inputs)#     # 计算损失
#     # 温度T是一个超参数,用于平滑概率分布
#     T = 2.0 
#     loss_distillation = distillation_loss(
#         F.log_softmax(student_outputs / T, dim=1),
#         F.softmax(teacher_outputs / T, dim=1)
#     )#     # 硬标签损失
#     loss_student = student_loss(student_outputs, labels)#     # 联合损失,通常会给两个损失分配权重
#     alpha = 0.5
#     total_loss = alpha * loss_distillation + (1 - alpha) * loss_student#     # 反向传播和优化
#     optimizer.zero_grad()
#     total_loss.backward()
#     optimizer.step()

通过这样的训练流程,学生模型不仅学习了如何正确分类,还从教师模型的“软标签”中学习到了类别之间的微妙关系,从而在更小的体量下实现了更好的性能。

http://www.dtcms.com/a/359160.html

相关文章:

  • mit6.824 2024spring Lab3A Raft
  • 简说DDPM
  • C语言---零碎语法知识补充(队列、函数指针、左移右移、任务标识符)
  • 机器人控制器开发(底层模块)——rk3588s 的 CAN 配置
  • 码农特供版《消费者权益保护法》逆向工程指北——附源码级注释与异常处理方案
  • 人工智能训练师复习题目实操题2.2.1 - 2.2.5
  • 手表--带屏幕音响-时间制切换12/24小时
  • PS学习笔记
  • 【15】VisionMaster入门到精通——--通信--TCP通信、UDP通信、串口通信、PLC通信、ModBus通信
  • 计算机算术7-浮点基础知识
  • 面经分享--小米Java一面
  • 青年教师发展(中科院软件所-田丰)
  • Dify 从入门到精通(第 65/100 篇):Dify 的自动化测试(进阶篇)
  • MCP与A2A的应用
  • LightGBM(Light Gradient Boosting Machine,轻量级梯度提升机)梳理总结
  • 【AI工具】在 VSCode中安装使用Excalidraw
  • 【69页PPT】智慧工厂数字化工厂蓝图规划建设方案(附下载方式)
  • 基于 Kubernetes 的 Ollama DeepSeek-R1 模型部署
  • 内存管理(智能指针,内存对齐,野指针,悬空指针)
  • Java中Integer转String
  • 为什么企业需要项目管理
  • 安卓编程 之 线性布局
  • 树莓派4B 安装中文输入法
  • AtCoder Beginner Contest 421
  • Mysql 学习day 2 深入理解Mysql索引底层数据结构
  • 【开题答辩全过程】以 基于WEB的茶文化科普系统的设计与实现为例,包含答辩的问题和答案
  • 用简单仿真链路产生 WiFi CSI(不依赖专用工具箱,matlab实现)
  • 面试tips--MyBatis--<where> where 1=1 的区别
  • 如何查看Linux系统中文件夹或文件的大小
  • 【LeetCode - 每日1题】有效的数独