KL Loss
背景
KL Loss主要监督的是模型输出分布 VS 目标分布 之间的相似性
它不直接监督位置、速度等数值,而是监督模型「认为哪种可能性更大」是否和目标一致。
在多模态预测、知识蒸馏、策略学习中尤为重要。
KL 散度主要监督什么?
项目 | 监督内容 | 应用场景 |
---|---|---|
分布相似性 | 模型输出的概率分布(预测) vs 目标分布(通常是软标签) | 知识蒸馏、轨迹分布、行为克隆等 |
不确定性建模 | 模型输出多个选择的分布(如多轨迹) vs 真值分布(soft target) | 轨迹预测、多模态输出 |
知识对齐 | 学生网络预测分布 vs 教师网络的 soft 分布 | 蒸馏 |
行为模仿/规划策略 | 模型生成的动作分布 vs 专家动作分布 | 模仿学习、策略学习 |
具体例子
- 知识蒸馏(Knowledge Distillation)
监督:
KL(Teacher(logits).softmax || Student(logits).softmax)
目标:让学生网络模仿教师网络输出的“概率分布”,而不是 hard label。
- 轨迹预测(Trajectory Prediction)
如果模型预测多种未来轨迹,每种轨迹有一个概率(例如多模态轨迹):predicted_probs = [0.6, 0.3, 0.1]
ground_truth_probs = [1.0, 0.0, 0.0] # one-hot or soft label from expertKL(predicted || ground_truth)
- 行为克隆(Behavior Cloning)/模仿学习
如果从专家(如人类或 rule-based agent)采样得到 soft policy 分布,模型输出 policy logits:
expert_policy = [0.7, 0.2, 0.1]
model_output = logits → softmax → [0.4, 0.4, 0.2]loss = KL(expert_policy || model_output)
目标:让模型模仿专家的策略分布(而不是只学最优动作)。
最基础的手写 KL 散度 loss (batch-wise)
假设:
p_target 是目标分布(通常来自 ground truth,已经是 soft label,如 one-hot 或 softmax)
q_pred 是模型输出分布(经过 softmax 或 log_softmax 之后)
import torch
import torch.nn.functional as Fdef kl_loss_manual(log_q, p):"""手动实现的KL散度:KL(p || q)参数:- log_q: 模型输出的对数概率分布(log_softmax后的)- p: 目标分布(soft label 或 one-hot)返回:- 平均 KL 散度 loss"""kl = p * (torch.log(p + 1e-10) - log_q) # 避免 log(0)return kl.sum(dim=-1).mean()
# 模拟一个 batch,有3个样本,每个是3类分类任务
logits = torch.tensor([[2.0, 1.0, 0.1],[1.5, 2.0, 0.5],[0.1, 0.2, 3.0]])# 模型输出的 log_softmax
log_q = F.log_softmax(logits, dim=1)# 假设目标是 one-hot(可以是 soft label)
p = torch.tensor([[1.0, 0.0, 0.0],[0.0, 1.0, 0.0],[0.0, 0.0, 1.0]])loss = kl_loss_manual(log_q, p)
print("KL Loss:", loss.item())