当前位置: 首页 > news >正文

Python-深度学习--1交叉熵损失函数

交叉熵损失函数(Cross-Entropy Loss)是机器学习和深度学习中用于分类任务的核心损失函数,尤其适用于解决多类别分类问题。它通过衡量模型预测概率分布与真实标签分布之间的 “差异”,指导模型参数优化,最终让预测结果更接近真实情况。

一、核心原理:衡量概率分布的差异

交叉熵(Cross-Entropy)原本是信息论中的概念,用于量化两个概率分布之间的 “不相似程度”。在机器学习中:

  • 假设模型对样本的预测结果是一个概率分布(例如,对 “猫 / 狗 / 鸟” 三分类,预测为猫的概率 0.8、狗 0.1、鸟 0.1);
  • 样本的真实标签也是一个概率分布(例如,真实是猫,则分布为 [1, 0, 0],即 “独热编码”);
  • 交叉熵损失函数计算这两个分布的差异,差异越大,损失值越高;差异越小,损失值越低。

二、数学定义

1. 二分类场景(输出为 0 或 1)

对于二分类问题(如 “垃圾邮件识别”“疾病诊断”),模型通常通过sigmoid函数输出单个概率值(属于类别 1 的概率),交叉熵损失公式为:


L=−y⋅log(y^​)−(1−y)⋅log(1−y^​)

其中:

  • y 是真实标签(0 或 1);
  • y^​ 是模型预测的概率(0~1 之间);
  • log 是自然对数。

示例
若真实标签 y=1,模型预测 y^​=0.9,则损失 L=−1⋅log(0.9)−0⋅log(0.1)≈0.105(损失小,预测准确);
若模型预测 y^​=0.1,则损失 L=−1⋅log(0.1)≈2.303(损失大,预测错误)。

2. 多分类场景(输出为多个类别)

对于多分类问题(如 “手写数字识别”“图像分类”),模型通过softmax函数输出每个类别的概率(所有类别概率和为 1),交叉熵损失公式为:

其中:

  • C 是类别总数;
  • yi​ 是真实标签的第i个值(独热编码,只有真实类别为 1,其余为 0);
  • y^​i​ 是模型预测第i类的概率。

示例
3 分类问题中,真实标签 y=[1,0,0](属于第 1 类),模型预测 y^​=[0.8,0.1,0.1],则损失 L=−1⋅log(0.8)−0⋅log(0.1)−0⋅log(0.1)≈0.223(预测准确,损失小);
若模型预测 y^​=[0.1,0.8,0.1],则损失 L=−1⋅log(0.1)≈2.303(预测错误,损失大)。

三、为什么用交叉熵损失?

  1. 与概率天然契合:分类任务的输出本质是 “属于每个类别的概率”,交叉熵直接衡量概率分布的差异,物理意义明确。
  2. 梯度特性更优:相比平方损失(MSE),交叉熵在模型预测错误时(如预测概率接近 0 但真实标签为 1),会产生更大的梯度,加速模型参数更新(收敛更快)。
    • 例如,用平方损失时,当 y^​=0.1 而 y=1,梯度为 −(y−y^​)=−0.9;
    • 用交叉熵时,梯度为 −y^​y​=−0.11​=−10,更新幅度更大。
  3. 适配概率输出函数:与sigmoid(二分类)、softmax(多分类)函数配合时,导数计算更简洁,避免梯度消失问题。

四、代码实现(PyTorch 示例)

1. 二分类交叉熵(BCEWithLogitsLoss)

PyTorch 将sigmoid和交叉熵合并为BCEWithLogitsLoss,直接输入未经过sigmoid的原始输出(logits):

2. 多分类交叉熵(CrossEntropyLoss)

PyTorch 的CrossEntropyLoss内置softmax,直接输入 logits,标签用类别索引(无需独热编码):

代码如下:

import torch
import torch.nn as nn# 模型输出(未经过sigmoid,形状:[batch_size, 1])
logits = torch.tensor([[2.0], [0.5], [-1.0]])  # 假设3个样本
# 真实标签(0或1,形状:[batch_size, 1])
labels = torch.tensor([[1.0], [1.0], [0.0]])# 定义二分类交叉熵损失
criterion = nn.BCEWithLogitsLoss()
loss = criterion(logits, labels)
print(loss.item())  # 输出损失值(约0.412)#--------------------------------------------------------------------------
# 模型输出(未经过softmax,形状:[batch_size, num_classes])
logits = torch.tensor([[3.0, 1.0, 0.2],  # 3个类别,2个样本[0.5, 2.0, 0.3]])
# 真实标签(类别索引,形状:[batch_size])
labels = torch.tensor([0, 1])  # 第1个样本属于类0,第2个属于类1# 定义多分类交叉熵损失
criterion = nn.CrossEntropyLoss()
loss = criterion(logits, labels)
print(loss.item())  # 输出损失值(约0.326)

运行结果:

五、总结

交叉熵损失函数是分类任务的 “标配”,其核心作用是量化预测概率与真实标签的差异,并通过梯度下降推动模型优化。它的优势在于:

  • 适配分类任务的概率输出特性;
  • 梯度更新更高效,收敛更快;
  • 同时支持二分类和多分类场景。
http://www.dtcms.com/a/315901.html

相关文章:

  • aab文件直接安装到手机
  • 西门子PLC基础指令4:置位指令 S、复位指令 R
  • 超越注意力机制
  • 【基础】第二篇 Java 类的定义、构造器、成员变量与方法详解
  • hyper-v实战系列:显卡虚拟化(GPU分区)--windows篇详解
  • SpringCloud学习-------Eureka详解
  • FPGA设计思想与验证方法学系列学习笔记003
  • 基于Django的计算机资源爬虫及可视化系统的设计与实现
  • 【motion】HumanML3D 的安装2:psbody-mesh安装成功
  • Django中的转发与重定向详解
  • 利用m0改造循迹模块处理笔记00
  • Windows 安装 RabbitMQ 消息队列超详细步骤(附加详细操作截屏)
  • kettle插件-kettle http post plus插件,轻松解决https post接口无法调用文件流下载问题
  • python的高校考研交流系统
  • FastAPI快速入门P2:与SpringBoot比较
  • ESP-idf框架下的HTTP服务器\HTML 485温湿度采集并长传
  • windows如何查限制每个用户占用的显存大小
  • Node.js高并发下的内存泄漏排查与解决实录
  • 微信小程序的合规检测
  • (LeetCode 每日一题) 3477. 水果成篮 II (暴力)
  • Pytorch-06 如何开启模型训练(模型训练工作流梳理)
  • 几乎不会存在Store Buffer中的指令不提交缓存的情况~
  • Rust进阶-part3-生命周期
  • AI的第一次亲密接触——你的手机相册如何认出你的猫?
  • JavaWeb(苍穹外卖)--学习笔记17(Websocket)
  • 【qt5_study】2.使用Qt Designer构造UI界面(信号与槽)
  • 使用ZYNQ芯片和LVGL框架实现用户高刷新UI设计系列教程(第二十四讲)
  • React18 严格模式下的双重渲染之谜
  • 【01】OpenCV C++实战篇——基于多项式插值的亚像素边缘定位算法
  • NLP自然语言处理 02 RNN及其变体