当前位置: 首页 > news >正文

知识蒸馏 - 基于KL散度的知识蒸馏 HelloWorld 示例 KL散度公式对应

知识蒸馏 - 基于KL散度的知识蒸馏 HelloWorld 示例 KL散度公式对应

flyfish

KL散度的公式

KL散度用于衡量两个概率分布 PPP(教师分布)和 QQQ(学生分布)的差异,公式为:
KL(P∥Q)=∑xP(x)⋅[log⁡P(x)−log⁡Q(x)] \text{KL}(P \parallel Q) = \sum_{x} P(x) \cdot \left[ \log P(x) - \log Q(x) \right] KL(PQ)=xP(x)[logP(x)logQ(x)]

对应公式:

  1. teacher_soft = softmax(teacher_logits / T, dim=-1)
    得到教师的概率分布 P(x)=softmax(teacher_logits/T)P(x) = \text{softmax}(\text{teacher\_logits}/T)P(x)=softmax(teacher_logits/T)

  2. student_soft = log_softmax(student_logits / T, dim=-1)
    得到学生的对数概率 log⁡Q(x)=log⁡(softmax(student_logits/T))\log Q(x) = \log\left( \text{softmax}(\text{student\_logits}/T) \right)logQ(x)=log(softmax(student_logits/T))

  3. kl_loss = sum( teacher_soft * (teacher_soft.log() - student_soft) ) / batch_size
    teacher_soft.log()log⁡P(x)\log P(x)logP(x)
    student_soft` 是 log⁡Q(x)\log Q(x)logQ(x)
    整体即公式中的 ∑P(x)⋅[log⁡P(x)−log⁡Q(x)]\sum P(x) \cdot [\log P(x) - \log Q(x)]P(x)[logP(x)logQ(x)],完全匹配KL散度的定义。

教师用softmax是为了得到概率分布 P(x)P(x)P(x),学生用log_softmax是为了直接得到 log⁡Q(x)\log Q(x)logQ(x),两者组合恰好满足KL散度的公式要求,同时利用log_softmax的数值稳定性提升计算可靠性。

log_softmax 操作在数学上等价于对输入先执行 softmax 得到概率分布,再对该概率分布取对数

import torch
import torch.nn.functional as F# 1. 定义示例输入(模型输出的logits)
logits = torch.tensor([[1.0, 2.0, 3.0],  # 样本1的类别得分[4.0, 5.0, 6.0]   # 样本2的类别得分
], dtype=torch.float32)# 温度参数(此处设为1.0,不影响等价性验证)
T = 1.0
scaled_logits = logits / T  # 温度软化后的logits# 2. 两种方式计算对数概率
# 方式1:直接使用log_softmax
log_softmax_result = F.log_softmax(scaled_logits, dim=-1)# 方式2:先计算softmax,再取对数
softmax_result = F.softmax(scaled_logits, dim=-1)
log_of_softmax = torch.log(softmax_result)# 3. 打印结果对比
print("===== 原始logits(温度软化后) =====")
print(scaled_logits)
print("\n===== 方式1:log_softmax直接计算 =====")
print(log_softmax_result)
print("\n===== 方式2:softmax后取对数 =====")
print(log_of_softmax)# 4. 数值等价性验证(允许微小浮点数误差)
# 检查所有元素是否在1e-6精度内相等
is_equivalent = torch.allclose(log_softmax_result, log_of_softmax, atol=1e-6)
print("\n===== 等价性验证 =====")
print(f"log_softmax 与 softmax+log 是否等价:{is_equivalent}")
===== 原始logits(温度软化后) =====
tensor([[1., 2., 3.],[4., 5., 6.]])===== 方式1:log_softmax直接计算 =====
tensor([[-2.4076, -1.4076, -0.4076],[-2.4076, -1.4076, -0.4076]])===== 方式2:softmax后取对数 =====
tensor([[-2.4076, -1.4076, -0.4076],[-2.4076, -1.4076, -0.4076]])===== 等价性验证 =====
log_softmax 与 softmax+log 是否等价:True
http://www.dtcms.com/a/313363.html

相关文章:

  • 文件拷贝-代码
  • Doris json_contains 查询报错
  • 数据结构总纲以及单向链表详解:
  • 【LeetCode刷题指南】--对称二叉树,另一颗树的子树
  • [创业之路-531]:知识、技能、技术、科学之间的区别以及它们对于职业的选择的指导作用?
  • 【OpenGL】LearnOpenGL学习笔记02 - 绘制三角形、矩形
  • 13-day10生成式任务
  • 基于MBA与BP神经网络分类模型的特征选择方法研究(Python实现)
  • 在ANSYS Maxwell中对永磁体无线充电进行建模
  • 【大模型核心技术】Agent 理论与实战
  • 【设计模式】5.代理模式
  • Manus AI与多语言手写识别
  • 什么是“痛苦指数”(Misery Index)?
  • 如何获取网页中点击按钮跳转后的链接呢
  • 在 Cursor 中设置浅色背景和中文界面
  • 抽奖系统中 Logback 的日志配置文件说明
  • 03.一键编译安装Redis脚本
  • 【MySQL】MySQL 中的数据排序是怎么实现的?
  • 深入理解流式输出:原理、应用与大模型聊天软件中的实现
  • 跨语言模型中的翻译任务:XLM-RoBERTa在翻译任务中的应用
  • python---python中的内存分配
  • AI Agent 重塑产业发展新格局
  • 联想笔记本安装系统之后一直转圈圈的问题了?无法正常进入到系统配置界面,原来是BIOS中的VMD问题
  • Autoswagger:揭露隐藏 API 授权缺陷的开源工具
  • 使用CMake构建项目的完整指南
  • [LINUX操作系统]shell脚本之循环
  • 【Qt】QObject::startTimer: Timers cannot be started from another thread
  • 如何玩转 Kubernetes K8S
  • 【QT】概述
  • 快速搭建一个非生产k8s环境