知识蒸馏 - 基于KL散度的知识蒸馏 HelloWorld 示例 KL散度公式对应
知识蒸馏 - 基于KL散度的知识蒸馏 HelloWorld 示例 KL散度公式对应
flyfish
KL散度的公式
KL散度用于衡量两个概率分布 PPP(教师分布)和 QQQ(学生分布)的差异,公式为:
KL(P∥Q)=∑xP(x)⋅[logP(x)−logQ(x)]
\text{KL}(P \parallel Q) = \sum_{x} P(x) \cdot \left[ \log P(x) - \log Q(x) \right]
KL(P∥Q)=x∑P(x)⋅[logP(x)−logQ(x)]
对应公式:
-
teacher_soft = softmax(teacher_logits / T, dim=-1)
→ 得到教师的概率分布 P(x)=softmax(teacher_logits/T)P(x) = \text{softmax}(\text{teacher\_logits}/T)P(x)=softmax(teacher_logits/T)。 -
student_soft = log_softmax(student_logits / T, dim=-1)
→ 得到学生的对数概率 logQ(x)=log(softmax(student_logits/T))\log Q(x) = \log\left( \text{softmax}(\text{student\_logits}/T) \right)logQ(x)=log(softmax(student_logits/T))。 -
kl_loss = sum( teacher_soft * (teacher_soft.log() - student_soft) ) / batch_size
teacher_soft.log()
是 logP(x)\log P(x)logP(x);
student_soft` 是 logQ(x)\log Q(x)logQ(x);
整体即公式中的 ∑P(x)⋅[logP(x)−logQ(x)]\sum P(x) \cdot [\log P(x) - \log Q(x)]∑P(x)⋅[logP(x)−logQ(x)],完全匹配KL散度的定义。
教师用softmax
是为了得到概率分布 P(x)P(x)P(x),学生用log_softmax
是为了直接得到 logQ(x)\log Q(x)logQ(x),两者组合恰好满足KL散度的公式要求,同时利用log_softmax
的数值稳定性提升计算可靠性。
log_softmax 操作在数学上等价于对输入先执行 softmax 得到概率分布,再对该概率分布取对数
import torch
import torch.nn.functional as F# 1. 定义示例输入(模型输出的logits)
logits = torch.tensor([[1.0, 2.0, 3.0], # 样本1的类别得分[4.0, 5.0, 6.0] # 样本2的类别得分
], dtype=torch.float32)# 温度参数(此处设为1.0,不影响等价性验证)
T = 1.0
scaled_logits = logits / T # 温度软化后的logits# 2. 两种方式计算对数概率
# 方式1:直接使用log_softmax
log_softmax_result = F.log_softmax(scaled_logits, dim=-1)# 方式2:先计算softmax,再取对数
softmax_result = F.softmax(scaled_logits, dim=-1)
log_of_softmax = torch.log(softmax_result)# 3. 打印结果对比
print("===== 原始logits(温度软化后) =====")
print(scaled_logits)
print("\n===== 方式1:log_softmax直接计算 =====")
print(log_softmax_result)
print("\n===== 方式2:softmax后取对数 =====")
print(log_of_softmax)# 4. 数值等价性验证(允许微小浮点数误差)
# 检查所有元素是否在1e-6精度内相等
is_equivalent = torch.allclose(log_softmax_result, log_of_softmax, atol=1e-6)
print("\n===== 等价性验证 =====")
print(f"log_softmax 与 softmax+log 是否等价:{is_equivalent}")
===== 原始logits(温度软化后) =====
tensor([[1., 2., 3.],[4., 5., 6.]])===== 方式1:log_softmax直接计算 =====
tensor([[-2.4076, -1.4076, -0.4076],[-2.4076, -1.4076, -0.4076]])===== 方式2:softmax后取对数 =====
tensor([[-2.4076, -1.4076, -0.4076],[-2.4076, -1.4076, -0.4076]])===== 等价性验证 =====
log_softmax 与 softmax+log 是否等价:True