当前位置: 首页 > news >正文

知识蒸馏 - 基于KL散度的知识蒸馏 KL散度的方向

知识蒸馏 - 基于KL散度的知识蒸馏 KL散度的方向

flyfish

知识蒸馏是什么?

知识蒸馏是一种广泛应用的模型压缩技术,旨在将大模型(教师模型)习得的知识迁移至小模型(学生模型)。大模型通常具备更多参数与知识表征能力,但其高容量也导致部署时的计算成本显著增加。知识蒸馏可将大模型蕴含的知识压缩至小模型中,核心思想是:小模型通过学习大模型的输出,能够提升自身性能。
知识蒸馏就是个 “教徒弟” 的技术:把大模型(师父,叫教师模型)学到的本事,教给小模型(徒弟,叫学生模型)。大模型参数多、本事大,但跑起来费电脑(计算成本高);小模型跑得轻快,可惜本事弱。知识蒸馏就是让小模型 “抄师父作业”—— 学大模型的输出,把师父的本事浓缩到自己身上,这样小模型也能变厉害。

知识蒸馏如何工作?

知识从教师模型向学生模型的迁移,通过在迁移数据集上的训练实现:学生模型被引导模仿教师模型输出的token 级概率分布(即模型对每个 token 类别的预测概率)。

简化形式展示知识蒸馏的工作流程

在这里插入图片描述

数据输入:同一批数据(Dataset)同时喂给 教师模型 和 学生模型。
模型推理:两者分别输出 logits(模型最后一层的原始输出,未归一化的概率)。
损失计算:通过损失函数(如前向 KL 散度),衡量 学生 logits 与 教师 logits 的差异。
权重更新:用损失反向传播,更新 学生模型 的权重,让学生逐渐 “模仿” 教师的输出。
迭代优化:重复上述过程,直到学生模型的输出足够接近教师模型。

前向 KL

“前向 KL” 中的 “前向”并非指“前向传播”,而是指 KL 散度的计算方向(即两个分布的“对比方向”)。

class ForwardKLLoss(torch.nn.Module):def __init__(self, ignore_index: int = -100):super().__init__()self.ignore_index = ignore_index  # 忽略的标签(如padding,不参与损失计算)def forward(self, student_logits, teacher_logits, labels) -> torch.Tensor:# 1. 教师logits → 概率分布(softmax归一化)teacher_prob = F.softmax(teacher_logits, dim=-1)  # 2. 学生logits → 对数概率(log_softmax,方便计算KL的对数项)student_logprob = F.log_softmax(student_logits, dim=-1)  # 3. 计算前向KL的核心项:teacher_prob * student_logprob(逐元素相乘后求和)prod_probs = teacher_prob * student_logprob  x = torch.sum(prod_probs, dim=-1).view(-1)  # 对“词表维度”求和,展平为样本维度# 4. 处理忽略标签:mask标记有效位置(labels != ignore_index的位置为1,否则0)mask = (labels != self.ignore_index).int().view(-1)  # 5. 计算平均损失:仅对有效位置求平均(避免忽略项干扰)loss = -torch.sum(x * mask) / torch.sum(mask)  return loss

1. 先理解 KL 散度的方向

KL 散度的定义是 DKL(P∥Q)D_{KL}(P \parallel Q)DKL(PQ),表示 用分布 P 作为“基准”,衡量分布 QP 的差异(可以理解为“从 P 看向 Q 的差异”)。

在知识蒸馏中:
P教师模型的输出分布teacher_prob = softmax(teacher_logits));
Q学生模型的输出分布student_prob = softmax(student_logits),代码中用 log_softmax 优化计算)。

因此,“前向 KL” 指 D_{KL}(教师分布 \parallel 学生分布),即强制学生分布向教师分布对齐(学生模仿教师)。

2. 前向传播

“前向传播”(Forward Pass)是 模型计算输出的过程(比如输入数据经过 Student/Teacher 模型,得到 logits 的过程,对应图中 Student model → Student logitsTeacher model → Teacher logits 的箭头)。

“前向 KL” 是损失函数的计算逻辑(两个分布的对比方向),和模型的前向传播是 完全不同的概念
前向传播是 模型推理的步骤(计算 logits);
前向 KL 是 损失的计算规则(衡量两个分布的差异方向)。

3. 为什么叫“前向”?

可以类比为 “信息传递的方向”
教师模型是“知识的提供者”(分布 P),学生模型是“知识的学习者”(分布 Q);
DKL(P∥Q)D_{KL}(P \parallel Q)DKL(PQ) 意味着 学生向教师对齐(学生模仿教师,知识从教师“流向”学生),因此称为“前向”(知识传递的方向)。

4. 反向 KL

如果反过来计算 DKL(Q∥P)D_{KL}(Q \parallel P)DKL(QP)Q是学生模型的生成分布,P是教师模型的分布。该式等价于最大化学生生成序列在教师分布下的对数概率,迫使学生生成 “教师认可的高质量文本”。
强化学习视角的优化
将反向 KL 与策略梯度(Policy Gradient)结合,将学生模型视为策略网络,教师模型的输出分布作为奖励信号。这种方法允许学生在生成过程中动态调整 token 选择,平衡多样性与质量。

ignore_index: int = -100

在PyTorch的损失函数中,ignore_index: int = -100 是一个常用参数,用于指定 需要被忽略的标签值——即当标签等于这个值时,对应的样本或位置不会参与损失计算,也不会影响模型的参数更新。

作用

在文本处理等序列任务中,输入序列通常会被填充(padding)到相同长度(比如用 <pad> 符号),这些填充符号的标签本身没有实际意义,不应该参与模型训练。
此时,会将填充符号的标签设为 -100,并通过 ignore_index=-100 告诉损失函数:忽略所有标签为 -100 的位置,不计算它们的损失。

在代码中的体现:

在的 ForwardKLLoss 中:

mask = (labels != self.ignore_index).int()  # 生成掩码:标签不是-100的位置为1,是-100的位置为0

这里通过掩码 mask 过滤掉了标签为 -100 的位置,确保这些位置的损失不会被计入最终结果,避免无效的填充符号干扰模型训练。

为什么是 -100

这是PyTorch的一个约定俗成的默认值(很多内置损失函数如 CrossEntropyLoss 也默认用 -100),主要是为了避免与实际标签值(通常从0开始)冲突,确保被忽略的标签不会被误判为有效标签。

http://www.dtcms.com/a/320350.html

相关文章:

  • 适配器模式及优化
  • 在NVIDIA Orin上用TensorRT对YOLO12进行多路加速并行推理时内存泄漏 (中)
  • linux系统编程
  • 使用winsw把SpringBoot项目注册成window服务
  • javaweb开发之会话_过滤器_监听器
  • 【感知机】感知机(perceptron)学习算法的收敛性
  • 【Unity3D实例-功能-镜头】第三人称视觉-镜头优化
  • 基于深度学习的污水新冠RNA测序数据分析系统
  • Linux机器可直接使用的自动化编译文件
  • AGV_ads通讯exe的创建
  • Java日志技术:从基础到实战
  • 蒙文OCR识别技术难点实现及应用场景剖析
  • Transformer:Attention is all you need
  • HCIP | BGP综合实验报告册
  • PMP项目管理:理解PMP、PMP学什么 / 适合谁学 / Project Management Professional / 项目管理专业人士
  • uat是什么
  • Day32--动态规划--509. 斐波那契数,70. 爬楼梯,746. 使用最小花费爬楼梯
  • 华为服务器如何部署Mindie镜像
  • 俄文识别技术,高精度识别,支持多场景多平台
  • 天猫商品评论API技术指南
  • 如何在NVIDIA H100 GPU上用Ollama以最高性能运行大语言模型
  • 2025数字马力一面面经(社)
  • 【2025最新版】火狐浏览器(官方版)安装-附教程
  • Ubuntu 22 下脚本登录MFA堡垒机
  • 一个自动定位并查询天气的工具(c语言)
  • 八股文智力题
  • 目标检测数据集 - 高架视角道路车辆检测数据集下载「包含VOC、COCO、YOLO三种格式」
  • 为什么会有反射
  • js中的设计模式
  • UnivNet论文分析(20210615)