当前位置: 首页 > news >正文

KL Loss

背景

KL Loss主要监督的是模型输出分布 VS 目标分布 之间的相似性
它不直接监督位置、速度等数值,而是监督模型「认为哪种可能性更大」是否和目标一致。
在多模态预测、知识蒸馏、策略学习中尤为重要。

KL 散度主要监督什么?

项目监督内容应用场景
分布相似性模型输出的概率分布(预测) vs 目标分布(通常是软标签)知识蒸馏、轨迹分布、行为克隆等
不确定性建模模型输出多个选择的分布(如多轨迹) vs 真值分布(soft target)轨迹预测、多模态输出
知识对齐学生网络预测分布 vs 教师网络的 soft 分布蒸馏
行为模仿/规划策略模型生成的动作分布 vs 专家动作分布模仿学习、策略学习

具体例子

  1. 知识蒸馏(Knowledge Distillation)

监督:


KL(Teacher(logits).softmax || Student(logits).softmax)

目标:让学生网络模仿教师网络输出的“概率分布”,而不是 hard label。

  1. 轨迹预测(Trajectory Prediction)
如果模型预测多种未来轨迹,每种轨迹有一个概率(例如多模态轨迹):predicted_probs = [0.6, 0.3, 0.1]
ground_truth_probs = [1.0, 0.0, 0.0]  # one-hot or soft label from expertKL(predicted || ground_truth)
  1. 行为克隆(Behavior Cloning)/模仿学习
    如果从专家(如人类或 rule-based agent)采样得到 soft policy 分布,模型输出 policy logits:
expert_policy = [0.7, 0.2, 0.1]
model_output = logits → softmax → [0.4, 0.4, 0.2]loss = KL(expert_policy || model_output)

目标:让模型模仿专家的策略分布(而不是只学最优动作)。

最基础的手写 KL 散度 loss (batch-wise)

假设:

p_target 是目标分布(通常来自 ground truth,已经是 soft label,如 one-hot 或 softmax)

q_pred 是模型输出分布(经过 softmax 或 log_softmax 之后)

import torch
import torch.nn.functional as Fdef kl_loss_manual(log_q, p):"""手动实现的KL散度:KL(p || q)参数:- log_q: 模型输出的对数概率分布(log_softmax后的)- p: 目标分布(soft label 或 one-hot)返回:- 平均 KL 散度 loss"""kl = p * (torch.log(p + 1e-10) - log_q)  # 避免 log(0)return kl.sum(dim=-1).mean()
# 模拟一个 batch,有3个样本,每个是3类分类任务
logits = torch.tensor([[2.0, 1.0, 0.1],[1.5, 2.0, 0.5],[0.1, 0.2, 3.0]])# 模型输出的 log_softmax
log_q = F.log_softmax(logits, dim=1)# 假设目标是 one-hot(可以是 soft label)
p = torch.tensor([[1.0, 0.0, 0.0],[0.0, 1.0, 0.0],[0.0, 0.0, 1.0]])loss = kl_loss_manual(log_q, p)
print("KL Loss:", loss.item())
http://www.dtcms.com/a/361338.html

相关文章:

  • 生产者-消费者问题与 QWaitCondition
  • 深入探讨Java异常处理:受检异常与非受检异常的最佳实践
  • leetcode 1576 替换所有的问号
  • 深入Linux内核:IPC资源管理揭秘
  • Unity资源导入设置方式选择
  • 【Element Plus `el-select` 下拉菜单响应式定位问题深度解析】
  • 【数学建模学习笔记】缺失值处理
  • SRE 系列(五)| MTTK/MTTF/MTTV:故障应急机制的三板斧
  • 每周读书与学习->认识性能测试工具JMeter
  • 【开题答辩全过程】以 基于python爬虫对微博数据可视化及实现为例,包含答辩的问题和答案
  • Certificate is Signed Using a Weak Signature Algorithm漏洞解决
  • 从零到一,在GitHub上构建你的专属知识大脑:一个模块化RAG系统的开源实现
  • [VLDB 2025]阿里云大数据AI平台多篇论文被收录
  • 国别域名的SEO优势:是否更利于在当地搜索引擎排名?
  • 【赵渝强老师】阿里云大数据MaxCompute的体系架构
  • Midscenejs自然语言写测试用例
  • 设计模式在Android开发中的实战攻略(面试高频问题)
  • 基于STM32设计的宠物寄养屋控制系统(阿里云IOT)_276
  • 阿里云代理商:轻量应用服务器介绍及搭建个人博客教程参考
  • Shell 编程 —— 正则表达式与文本处理器
  • Shell脚本编程:函数、数组与正则表达式详解
  • 稳联技术的Profinet转Modbus转换网关与信捷PLC从站的连接配置进行了案例分析
  • Java全栈开发工程师面试实战:从基础到微服务的完整技术演进
  • 特征选择方法介绍
  • GPS:开启定位时代的科技魔杖
  • 趣味学RUST基础篇(String)
  • aws上创建jenkins
  • Pomian语言处理器研发笔记(三):使用组合子构建抽象语法树
  • 构建单页应用:React Router v6 核心概念与实战
  • Ubuntu22.04网络图标消失问题