知识蒸馏 - 基于KL散度的知识蒸馏 KL散度的方向
知识蒸馏 - 基于KL散度的知识蒸馏 KL散度的方向
flyfish
知识蒸馏是什么?
知识蒸馏是一种广泛应用的模型压缩技术,旨在将大模型(教师模型)习得的知识迁移至小模型(学生模型)。大模型通常具备更多参数与知识表征能力,但其高容量也导致部署时的计算成本显著增加。知识蒸馏可将大模型蕴含的知识压缩至小模型中,核心思想是:小模型通过学习大模型的输出,能够提升自身性能。
知识蒸馏就是个 “教徒弟” 的技术:把大模型(师父,叫教师模型)学到的本事,教给小模型(徒弟,叫学生模型)。大模型参数多、本事大,但跑起来费电脑(计算成本高);小模型跑得轻快,可惜本事弱。知识蒸馏就是让小模型 “抄师父作业”—— 学大模型的输出,把师父的本事浓缩到自己身上,这样小模型也能变厉害。
知识蒸馏如何工作?
知识从教师模型向学生模型的迁移,通过在迁移数据集上的训练实现:学生模型被引导模仿教师模型输出的token 级概率分布(即模型对每个 token 类别的预测概率)。
简化形式展示知识蒸馏的工作流程
数据输入:同一批数据(Dataset)同时喂给 教师模型 和 学生模型。
模型推理:两者分别输出 logits(模型最后一层的原始输出,未归一化的概率)。
损失计算:通过损失函数(如前向 KL 散度),衡量 学生 logits 与 教师 logits 的差异。
权重更新:用损失反向传播,更新 学生模型 的权重,让学生逐渐 “模仿” 教师的输出。
迭代优化:重复上述过程,直到学生模型的输出足够接近教师模型。
前向 KL
“前向 KL” 中的 “前向”并非指“前向传播”,而是指 KL 散度的计算方向(即两个分布的“对比方向”)。
class ForwardKLLoss(torch.nn.Module):def __init__(self, ignore_index: int = -100):super().__init__()self.ignore_index = ignore_index # 忽略的标签(如padding,不参与损失计算)def forward(self, student_logits, teacher_logits, labels) -> torch.Tensor:# 1. 教师logits → 概率分布(softmax归一化)teacher_prob = F.softmax(teacher_logits, dim=-1) # 2. 学生logits → 对数概率(log_softmax,方便计算KL的对数项)student_logprob = F.log_softmax(student_logits, dim=-1) # 3. 计算前向KL的核心项:teacher_prob * student_logprob(逐元素相乘后求和)prod_probs = teacher_prob * student_logprob x = torch.sum(prod_probs, dim=-1).view(-1) # 对“词表维度”求和,展平为样本维度# 4. 处理忽略标签:mask标记有效位置(labels != ignore_index的位置为1,否则0)mask = (labels != self.ignore_index).int().view(-1) # 5. 计算平均损失:仅对有效位置求平均(避免忽略项干扰)loss = -torch.sum(x * mask) / torch.sum(mask) return loss
1. 先理解 KL 散度的方向
KL 散度的定义是 DKL(P∥Q)D_{KL}(P \parallel Q)DKL(P∥Q),表示 用分布 P
作为“基准”,衡量分布 Q
与 P
的差异(可以理解为“从 P
看向 Q
的差异”)。
在知识蒸馏中:
P
是 教师模型的输出分布(teacher_prob = softmax(teacher_logits)
);
Q
是 学生模型的输出分布(student_prob = softmax(student_logits)
,代码中用 log_softmax
优化计算)。
因此,“前向 KL” 指 D_{KL}(教师分布 \parallel 学生分布)
,即强制学生分布向教师分布对齐(学生模仿教师)。
2. 前向传播
“前向传播”(Forward Pass)是 模型计算输出的过程(比如输入数据经过 Student/Teacher 模型,得到 logits
的过程,对应图中 Student model → Student logits
和 Teacher model → Teacher logits
的箭头)。
而 “前向 KL” 是损失函数的计算逻辑(两个分布的对比方向),和模型的前向传播是 完全不同的概念:
前向传播是 模型推理的步骤(计算 logits);
前向 KL 是 损失的计算规则(衡量两个分布的差异方向)。
3. 为什么叫“前向”?
可以类比为 “信息传递的方向”:
教师模型是“知识的提供者”(分布 P
),学生模型是“知识的学习者”(分布 Q
);
DKL(P∥Q)D_{KL}(P \parallel Q)DKL(P∥Q) 意味着 学生向教师对齐(学生模仿教师,知识从教师“流向”学生),因此称为“前向”(知识传递的方向)。
4. 反向 KL
如果反过来计算 DKL(Q∥P)D_{KL}(Q \parallel P)DKL(Q∥P)Q是学生模型的生成分布,P是教师模型的分布。该式等价于最大化学生生成序列在教师分布下的对数概率,迫使学生生成 “教师认可的高质量文本”。
强化学习视角的优化
将反向 KL 与策略梯度(Policy Gradient)结合,将学生模型视为策略网络,教师模型的输出分布作为奖励信号。这种方法允许学生在生成过程中动态调整 token 选择,平衡多样性与质量。
ignore_index: int = -100
在PyTorch的损失函数中,ignore_index: int = -100
是一个常用参数,用于指定 需要被忽略的标签值——即当标签等于这个值时,对应的样本或位置不会参与损失计算,也不会影响模型的参数更新。
作用
在文本处理等序列任务中,输入序列通常会被填充(padding)到相同长度(比如用 <pad>
符号),这些填充符号的标签本身没有实际意义,不应该参与模型训练。
此时,会将填充符号的标签设为 -100
,并通过 ignore_index=-100
告诉损失函数:忽略所有标签为 -100
的位置,不计算它们的损失。
在代码中的体现:
在的 ForwardKLLoss
中:
mask = (labels != self.ignore_index).int() # 生成掩码:标签不是-100的位置为1,是-100的位置为0
这里通过掩码 mask
过滤掉了标签为 -100
的位置,确保这些位置的损失不会被计入最终结果,避免无效的填充符号干扰模型训练。
为什么是 -100
?
这是PyTorch的一个约定俗成的默认值(很多内置损失函数如 CrossEntropyLoss
也默认用 -100
),主要是为了避免与实际标签值(通常从0开始)冲突,确保被忽略的标签不会被误判为有效标签。