当前位置: 首页 > news >正文

【Pytorch】分类问题交叉熵

1️⃣ 为什么分类问题不用 MSE(均方误差)?

表格

复制

场景标签预测MSE 损失
分类[0,0,1][0.3,0.3,0.4](0.4-1)²=0.36
分类[0,0,1][0.1,0.2,0.7](0.7-1)²=0.09

看起来合理,但:

  1. sigmoid/softmax 输出在 0/1 附近梯度几乎为零梯度消失

  2. MSE 把“概率差”当数值差” → 不符合概率直觉;

  3. 收敛慢,还容易卡在鞍点。


2️⃣ 交叉熵(Cross Entropy)思想

一句话:衡量「真实分布 p」与「预测分布 q」之间的信息差距

公式(离散版):

CE(p,q) = − Σ p(i) log q(i)

  • p 是 one-hot 标签(比如 [0,0,1])

  • q 是 softmax 输出(比如 [0.1,0.2,0.7])

因为 p 只有一个 1,其余为 0,所以求和只剩一项

CE = − log q(正确类)

直观

  • 若 q(正确类)=0.7 → CE ≈ 0.36

  • 若 q(正确类)=0.98 → CE ≈ 0.02
    预测越准,损失越小,且梯度不饱和(后面会算给你看)。


3️⃣ 手推一条二分类例子

表格

复制

样本真实 y预测 p
10.8
00.1

二元交叉熵(BCE):

L = − [y log p + (1−y) log(1−p)]

样本 1(猫):

L = − [1·log0.8 + 0·log0.2] = −log0.8 ≈ 0.223

样本 2(狗):

L = − [0·log0.1 + 1·log0.9] = −log0.9 ≈ 0.105

平均损失 ≈ 0.164,预测越离谱,值越大


4️⃣ PyTorch 一行代码算完

Python

import torch.nn.functional as Flogits = torch.tensor([[1.0, 2.0, 0.5]])   # 模型输出(未归一化)
target = torch.tensor([1])                  # 正确类别索引loss = F.cross_entropy(logits, target)
print(loss.item())          #  tensor(0.8309)

内部干了啥

  1. softmax(logits) → 概率

  2. log(softmax) → 对数概率

  3. -log q(正确类) → 损失


5️⃣ 数值稳定性技巧

不要手写:

Python

复制

prob = F.softmax(logits)
log_prob = torch.log(prob)
loss = F.nll_loss(log_prob, target)

推荐直接用:

Python

loss = F.cross_entropy(logits, target)

内部实现 log-sum-exp 技巧,避免 log(softmax) 造成数值溢出。


6️⃣ 对比实验(直观感受)

表格

复制

方法损失曲线梯度大小收敛速度
MSE平坦区早极小
CE无平坦区稳定

7️⃣ 小结口诀(背下来)

分类用 CE,回归用 MSE
CE = −log q(对类)
PyTorch:F.cross_entropy(logits, target)
别手写 softmax+log!


8️⃣ 课后 5 分钟动手

  1. F.cross_entropy 算一条三分类样本。

  2. logits 乘 10 再算一次,观察损失变化。

  3. 对比 F.mse_lossF.cross_entropy 的梯度大小(.grad)。

http://www.dtcms.com/a/486087.html

相关文章:

  • 如何轻松删除 realme 手机中的联系人
  • Altium Designer怎么制作自己的集成库?AD如何制作自己的原理图库和封装库并打包生成库文件?AD集成库制作好后如何使用丨AD集成库使用方法
  • Jackson是什么
  • 代码实例:Python 爬虫抓取与解析 JSON 数据
  • 襄阳建设网站首页百度知识营销
  • 山东住房和城乡建设厅网站电话开发软件都有哪些
  • AbMole| Yoda1( M9372;GlyT2-IN-1; Yoda 1)
  • LLM监督微调SFT实战指南(Qwen3-0.6B-Base)
  • 【基础算法】多源 BFS
  • *@UI 视角下主程序与子程序的菜单页面架构及关联设计
  • Virtio 半虚拟化技术解析
  • 网站设计怎么好看律师做网络推广哪个网站好
  • 用commons vfs 框架 替换具体的sftp 实现
  • 网站模板怎么设计软件wordpress多重筛选页面
  • 通往Docker之路:从单机到容器编排的架构演进全景
  • 分布式链路追踪:微服务可观测性的核心支柱
  • PostgreSQL 函数ARRAY_AGG详解
  • 【OpenHarmony】MSDP设备状态感知模块架构
  • RAG 多模态 API 处理系统设计解析:企业级大模型集成架构实战
  • 通过一个typescript的小游戏,使用单元测试实战(二)
  • 多物理域协同 + 三维 CAD 联动!ADS 2025 解锁射频前端、天线设计新体验
  • 前端微服务架构解析:qiankun 运行原理详解
  • linux ssh config详解
  • 内网攻防实战图谱:从红队视角构建安全对抗体系
  • 鲲鹏ARM服务器配置YUM源
  • 网站分类标准沈阳网站制作招聘网
  • 建设一个网站需要几个角色建筑工程网课心得体会
  • 基于Robosuite和Robomimic采集mujoco平台的机械臂数据微调预训练PI0模型,实现快速训练机械臂任务
  • 深度学习目标检测项目
  • SQL 窗口函数