当前位置: 首页 > news >正文

【深入浅出】交叉熵损失函数——原理、公式与代码示例

1. 概述:为什么关注交叉熵?

在机器学习和深度学习的分类任务中,损失函数(Loss Function) 是衡量模型预测结果与真实标签之间差异的关键工具。在众多损失函数中,交叉熵损失(Cross-Entropy Loss) 凭借其优异的性能,成为了分类模型,特别是神经网络的首选。

本文将带你彻底搞懂交叉熵损失函数,从核心思想数学公式实际应用,并通过具体示例代码加深理解。

2. 核心思想:衡量概率分布的差异

交叉熵源于信息论,它衡量的是两个概率分布之间的差异

在分类任务中,这两个分布分别是:

  1. 真实分布 §:数据的真实标签,通常用独热编码(One-hot Encoding) 表示。例如,对于三分类问题,标签“猫”的独热编码是 [1, 0, 0]
  2. 预测分布 (Q):模型预测出的概率分布。例如,模型可能预测为 [0.7, 0.2, 0.1],表示它认为图像是“猫”、“狗”、“鸟”的概率分别为70%、20%和10%。

交叉熵损失的值越小,说明预测分布 Q 与真实分布 P 越接近。 因此,训练模型的终极目标就是最小化交叉熵损失

3. 数学公式与代码实现

交叉熵损失根据分类任务的不同,主要有两种形式。

3.1 二分类交叉熵损失 (Binary Cross-Entropy)

适用场景:只有两个类别的问题(如垃圾邮件分类、逻辑回归)。

公式
L=−1N∑i=1N[yi⋅log⁡(p(yi))+(1−yi)⋅log⁡(1−p(yi))]L = -\frac{1}{N} \sum_{i=1}^{N} \left[ y_i \cdot \log(p(y_i)) + (1 - y_i) \cdot \log(1 - p(y_i)) \right]L=N1i=1N[yilog(p(yi))+(1yi)log(1p(yi))]

  • NNN:样本数量。
  • yiy_iyi:第 iii 个样本的真实标签(0 或 1)。
  • p(yi)p(y_i)p(yi):模型预测该样本为正类 (yi=1y_i=1yi=1) 的概率

Python/PyTorch 实现

import torch
import torch.nn as nn# 真实标签 (1为正类,0为负类)
y_true = torch.tensor([1, 0, 1], dtype=torch.float32)
# 模型预测的概率值 (通常是sigmoid函数的输出)
y_pred = torch.tensor([0.9, 0.2, 0.4], dtype=torch.float32)# 手动实现
loss = - (y_true * torch.log(y_pred) + (1 - y_true) * torch.log(1 - y_pred))
loss = loss.mean()
print(f"手动计算二分类交叉熵损失: {loss:.4f}") # 输出示例: 0.3191# 使用PyTorch内置函数
bce_loss = nn.BCELoss()
loss_bce = bce_loss(y_pred, y_true)
print(f"PyTorch BCELoss: {loss_bce:.4f}") # 输出示例: 0.3191

3.2 多分类交叉熵损失 (Categorical Cross-Entropy)

适用场景:类别数大于两个的分类问题(如图像分类、情感分析)。

公式
L=−1N∑i=1N∑c=1Cyi,c⋅log⁡(p(yi,c))L = -\frac{1}{N} \sum_{i=1}^{N} \sum_{c=1}^{C} y_{i,c} \cdot \log(p(y_{i,c}))L=N1i=1Nc=1Cyi,clog(p(yi,c))

  • CCC:类别总数。
  • yi,cy_{i,c}yi,c:符号函数(样本 iii 的真实类别等于 ccc 则为 1,否则为 0)。
  • p(yi,c)p(y_{i,c})p(yi,c):模型预测样本 iii 属于类别 ccc 的概率。

关键点:由于真实标签是独热编码,只有一个位置是1,其他都是0。因此,这个公式的实质是只计算真实类别所对应的那个预测概率的对数

Python/PyTorch 实现

import torch
import torch.nn as nn# 真实标签(这里是类别索引,2表示第三个类别)
# 注意:PyTorch中通常不直接使用one-hot,而是用类别索引
y_true_index = torch.tensor([0, 2]) # 两个样本,真实类别分别是第0类和第2类# 模型的原始输出(logits,尚未经过Softmax)
logits = torch.tensor([[2.0, 0.5, 0.3],  # 第一个样本的logits[0.1, 1.0, 3.0]]) # 第二个样本的logits# 计算Softmax概率
softmax = nn.Softmax(dim=1)
probs = softmax(logits)
print("预测概率分布:\n", probs)
# 输出示例: 
# tensor([[0.7055, 0.2160, 0.0785],
#         [0.0447, 0.1214, 0.8338]])# 手动计算第一个样本的损失(真实类别为0)
# 只取真实类别0对应的概率0.7055计算-log
loss_manual = -torch.log(torch.tensor(0.7055))
print(f"手动计算第一个样本的损失: {loss_manual:.4f}") # 输出示例: 0.3488# 使用PyTorch内置函数(最常用!)
# nn.CrossEntropyLoss() = nn.LogSoftmax() + nn.NLLLoss()
# 输入:原始logits,标签:类别索引
ce_loss = nn.CrossEntropyLoss()
loss_ce = ce_loss(logits, y_true_index)
print(f"PyTorch CrossEntropyLoss: {loss_ce:.4f}") # 输出示例: 0.3488 (两个样本的平均)

4. 实例详解:通过例子直观理解

示例1:二分类(猫 vs. 狗)

场景真实标签 (y)预测概率 §损失计算损失值分析
自信正确1 (是猫)0.9−log⁡(0.9)-\log(0.9)log(0.9)~0.105损失很小,模型预测好
犹豫不决1 (是猫)0.5−log⁡(0.5)-\log(0.5)log(0.5)~0.693损失中等,模型不确定
自信错误1 (是猫)0.1−log⁡(0.1)-\log(0.1)log(0.1)~2.302损失巨大! 模型错得离谱

结论:交叉熵对“自信的错误”施加了非常严厉的惩罚,这迫使模型快速修正严重错误。

示例2:多分类(数字识别)

假设识别数字 0, 1, 2。真实数字是 “0”,独热编码为 [1, 0, 0]

场景预测概率 §损失计算分析
理想预测[0.9, 0.1, 0.0]−log⁡(0.9)≈0.105-\log(0.9) \approx 0.105log(0.9)0.105正确且自信
糟糕预测[0.1, 0.8, 0.1]−log⁡(0.1)≈2.302-\log(0.1) \approx 2.302log(0.1)2.302自信地犯错,损失巨大
保守预测[0.4, 0.3, 0.3]−log⁡(0.4)≈0.916-\log(0.4) \approx 0.916log(0.4)0.916正确但不自信,损失中等

5. 为什么是交叉熵?优势总结

  1. 梯度優雅,收斂迅速

    • 交叉熵损失函数关于模型参数的梯度计算非常简洁。
    • 对于使用 Softmax 的输出层,梯度可简化为 (预测概率 - 真实概率)
    • 误差越大,梯度越大,参数更新幅度越大,学习速度越快。这避免了使用均方误差(MSE)等损失函数可能带来的梯度消失问题。
  2. 惩罚机制合理

    • 强烈惩罚“ confidently wrong”(自信的错误预测),鼓励模型做出“ confident and correct”(自信且正确)的预测。

6. 最佳实践:如何应用

在深度学习框架中,交叉熵损失函数与输出层激活函数是黄金搭档:

  • 二分类任务 (Binary Classification)

    • 输出层激活函数Sigmoid(将输出压缩到 (0, 1) 区间,得到一个概率值)。
    • 损失函数nn.BCELoss (Binary Cross-Entropy Loss)。
  • 多分类任务 (Multi-class Classification)

    • 输出层激活函数Softmax(将输出压缩为概率分布,所有类别概率之和为1)。
    • 损失函数nn.CrossEntropyLoss注意:PyTorch中此函数已内置Softmax,因此网络的最后一层不需要再额外添加Softmax激活函数,直接输出Logits即可)。

7. 总结

特性描述
本质衡量真实分布与预测分布之间的差异
目标最小化损失,使预测分布逼近真实分布
类型二分类交叉熵 (BCE)、多分类交叉熵 (CE)
优点梯度计算高效、收敛快;对错误预测惩罚力度大
应用绝大多数分类模型(逻辑回归、神经网络、CNN等)
搭档Sigmoid(二分类)、Softmax(多分类)

一句话总结:交叉熵损失是引导分类模型从“错误”走向“正确”、从“不确定”走向“自信确定”的强大指挥棒。理解并掌握它,是构建高效分类模型的关键一步。

http://www.dtcms.com/a/392293.html

相关文章:

  • Vue实现路由守卫
  • Coze源码分析-资源库-删除工作流-前端源码-核心接口
  • 安踏集团 X OB Cloud:新零售创新如何有“底”和有“数”
  • Web3艺术品交易应用方案
  • Spring 事务管理详解:保障数据一致性的实践指南
  • 软考中级-软件设计师 答题解题思路
  • Java IDEA学习之路:第二周课程笔记归纳
  • SQL语句一文通
  • Ubuntu22.04 双显卡系统使用集显 DRM 渲染的完整流程记录
  • Coze源码分析-资源库-删除工作流-后端源码-IDL/API/应用/领域
  • MySQL库和表的操作语句
  • python、类
  • NumPy高级技巧:向量化、广播与einsum的高效使用
  • GD32VW553-IOT 基于 vscode 的 msdk 移植(基于Cmake)
  • Filter 过滤器详解与使用指南
  • 养成合成小游戏抖音快手微信小程序看广告流量主开源
  • 在 Ubuntu 系统下安装 Conda
  • ac8257 android 9 SYSTEM_LAST_KMSG
  • ARM 架构与嵌入式系统
  • ARM(14) - LCD(1)清屏和画图形
  • Linux第十九讲:传输层协议UDP
  • 计算机网络学习(四、网络层)
  • 开启科学计算之旅:《MATLAB程序设计》课程导览
  • MATLAB | 数学模型 | 传染病 SIR 模型的参数确定
  • MATLAB基本运算(2)
  • 小红书数据分析面试题及参考答案
  • SpringCloudStream:消息驱动组件
  • ret2text-CTFHub技能树
  • VirtualBox 7 虚拟机的硬盘如何扩大?
  • React新闻发布系统 权限列表开发