当前位置: 首页 > news >正文

交叉熵损失函数

交叉熵

在机器学习中,损失函数用于衡量模型预测值 (prediction) 与真实标签 (label) 之间的“差距”或“不一致程度”。交叉熵(Cross Entropy)就是这种差距的一种优秀度量方式,它源于信息论,后来成为分类任务(尤其是分类)中最主流的损失函数。
直观理解:交叉熵衡量的是两个概率分布之间的差异。
一个分布是真实标签的分布(通常是 one-hot 编码,如 [0, 0, 1, 0])。
另一个分布是模型预测出的概率分布(如 [0.1, 0.2, 0.6, 0.1])。
如果模型的预测概率分布与真实分布完全一致,那么交叉熵为 0,否则交叉熵的值会大于 0,且差异越大,交叉熵的值就越大。

二分类交叉熵损失函数

  • 二元交叉熵损失函数适用于二分类问题,样本标签为为二元值:0或者1
  • 用于将模型预测值和真实值之间的差异转化为一个标量值,从而衡量模型预测的准确性。
    计算公式:
    L=−1N∑i=1N[yilog⁡(yi^)+(1−yi)log⁡(1−yi^)]L = -\frac{1}{N} \sum_{i=1}^{N} [y_i \log(\hat{y_i}) + (1-y_i) \log(1-\hat{y_i})]L=N1i=1N[yilog(yi^)+(1yi)log(1yi^)]
    其中:
    NNN是样本数量;
    yiy_iyi表示第iii个样本的真实标签:0或者1
    yi^\hat{y_i}yi^表示第iii个样本预测的概率

torch.nn.BCELoss()

nn.BCELoss() 是二元交叉熵损失函数的基本实现。

函数接口
torch.nn.BCELoss(weight=None, size_average=None, reduce=None, reduction='mean')
  • weight :用于样本加权的权重张量,给每个批次元素的损失手动设置一个权重。必须是一维张量,默认值为 None。
  • reduction :指定如何计算损失值。可选值为 ‘none’、‘mean’ 或 ‘sum’。默认值为 ‘mean’。
输入格式

Input (input): 模型输出的预测值,必须是经过 Sigmoid 激活函数后的概率值,即每个元素的范围必须在 [0, 1] 之间。
Target (target): 真实标签,其形状必须与 input 相同。元素的值可以是:0 或 1(标准的二分类标签)或者介于 0 和 1 之间的值(例如,软标签或概率值)

工作流程

  • 模型最后一层必须使用 nn.Sigmoid() 作为激活函数。
  • 将模型的原始输出(logits)转换为概率。
  • 将这些概率值和真实标签一起送入 nn.BCELoss() 进行计算

该类数值不稳定且操作繁琐。

import torch
import torch.nn as nn# 模型输出的是 logits(原始分数)
model_output_logits = torch.tensor([2.0, -1.0, 0.5])
# 真实标签
true_labels = torch.tensor([1., 0., 1.]) # 注意:target 必须是 float 类型
# 1. 手动应用 Sigmoid 得到概率
sigmoid = nn.Sigmoid()
pred_probs = sigmoid(model_output_logits) # tensor([0.8808, 0.2689, 0.6225])
# 2. 使用 BCELoss 计算损失
criterion = nn.BCELoss()
loss = criterion(pred_probs, true_labels)
print(f"Probabilities: {pred_probs}")
print(f"BCE Loss: {loss}")

torch.nn.BCEWithLogitsLoss()

将 Sigmoid 层和 BCELoss 组合成一个单独的、数值稳定的损失函数。是目前被推荐使用的标准做法。

torch.nn.BCEWithLogitsLoss(weight=None,size_average=None,                            reduce=None, reduction='mean',     pos_weight=None)
  • weight:用于对每个样本的损失值进行加权。默认值为 None。
  • reduction:指定如何计算损失值。可选值为 ‘none’、‘mean’ 和 ‘sum’。默认值为 ‘mean’。
  • pos_weight:比BCELoss增加的参数。用于对正样本的损失值进行加权。可以用于处理样本不平衡的问题。例如,如果正样本是负样本数量的 1/5,可以设置 pos_weight = torch.tensor([5]),让模型更关注正样本。
输入格式

Input (input): 模型的原始输出(logits),不需要经过任何激活函数。值可以是任意实数范围 (-inf, +inf)。
Target (target): 与 BCELoss() 的要求完全相同

工作流程

  • 模型最后一层不需要任何激活函数,直接输出 logits。
  • 将这些 logits 和真实标签直接送入 nn.BCEWithLogitsLoss()。
  • 该函数会在内部进行数值稳定的 Sigmoid + 交叉熵计算。
import torch
import torch.nn as nn# 模型直接输出 logits(无需Sigmoid)
model_output_logits = torch.tensor([2.0, -1.0, 0.5])
true_labels = torch.tensor([1., 0., 1.])# 使用 BCEWithLogitsLoss 计算损失(直接输入logits!)
criterion = nn.BCEWithLogitsLoss()
loss = criterion(model_output_logits, true_labels)
print(f"BCEWithLogits Loss: {loss}")# --- 处理类别不平衡的例子 ---
# 假设正样本(标签为1)非常少,我们为其设置更高的权重
pos_weight = torch.tensor([3.0]) # 正样本的权重是负样本的3倍
criterion_balanced = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
loss_balanced = criterion_balanced(model_output_logits, true_labels)

多分类交叉熵损失函数

用于类别数 C > 2 的任务
L=−1N∑i=1N∑c=1C[yi,clog⁡(p(yi,c))]L = -\frac{1}{N} \sum_{i=1}^{N}\sum_{c=1}^{C}[y_{i,c} \log(p(y_{i,c}))]L=N1i=1Nc=1C[yi,clog(p(yi,c))]

  • C:类别总数
  • yi,cy_{i,c}yi,c:一个独热编码(one-hot)向量,如果样板iii的真实类别是ccc,则yi,c=1y_{i,c}=1yi,c=1,否则为0
  • p(yi,c)p(y_{i,c})p(yi,c):模型预测样板i属于类别c的概率。

注意:在多分类中,由于真实标签 yyy 是 one-hot 形式,其向量中只有一项为 1,其余为 0。因此,对于单个样本,求和公式实际上简化为只计算真实类别所对应的那个预测概率的对数
Li=−log⁡(p(trueclass)]L_i= - \log(p(true class)]Li=log(p(trueclass)]

torch.nn.CrossEntropyLoss()

函数接口
torch.nn.CrossEntropyLoss(weight=None,size_average=None,ignore_index=-100,reduce=None,reduction='mean',label_smoothing=0.0
)
  • weight:手动指定每个类别的权重,用于处理类别不平衡的数据集。
    格式是一维向量(Tensor),其长度等于类别的数量C。如果某个类别样本数量少,重要性高,可以为其赋予较大的权重。通常可以设置为 weight = 1 / class_frequency(每个类别的频率的倒数)。
  • ignore_index:指定一个被忽略的标签值,该标签对应的样本将不参与损失计算。默认值-100。
  • reduction:指定损失输出的聚合方式。可选值如下:
    ‘none’:返回每个样本的损失值,形状大小[batch_size])
    ‘mean’ :默认值,所有样本损失的平均值。
    ‘sum’:所有样本损失的总和。
  • label_smoothing:标签平滑,默认值为0.0,不启用。这是防止模型过拟合、提升泛化能力的一种正则化技术。
  • size_average 和 reduce是旧版参数,为了向后兼容保留,不推荐使用。
输入格式

Input(预测值input)
形状:以下两种之一:

  • (N, C):其中 N 是批次大小(Batch Size),C 是类别数量。这是最常见的情况,比如一个图像分类批次。
  • (N, C, d_1, d_2, …, d_K):用于更高维度的输入,例如像素级分类(语义分割),其中 (d_1, d_2, …, d_K) 是每个样本的空间维度(如高度、宽度)。

内容:必须是未经softmax处理的原始分数,该函数会在内部自动计算 LogSoftmax 以获得数值稳定性。
(nn.CrossEntropyLoss() = nn.LogSoftmax(dim=1) + nn.NLLLoss()。直接在Logits上工作,不需要也不应该在网络最后一层添加 Softmax

Target (真实标签 target)
形状

  • 对应 Input 形状为 (N, C) 时:Target 的形状是 (N)。每个值是该样本的真实类别索引,取值范围在 [0, C-1]。
  • 对应 Input 形状为 (N, C, d_1, d_2, …, d_K) 时:Target 的形状是 (N, d_1, d_2, …, d_K)。每个位置的值是该点的真实类别索引。

内容:是类别的索引,而不是 One-Hot 编码。

import torch.nn as nn
# 模型输出:2个样本,3个类别(Logits)
inputs = torch.tensor([[ 3.2, -1.5,  0.8],  # 样本1[ 0.4,  2.1, -0.1]]) # 样本2
# 正确:真实标签是类别索引
targets = torch.tensor([0, 2])  # 样本1的真实类别是0,样本2的真实类别是2# 错误:真实标签是One-Hot编码
# targets = torch.tensor([[1, 0, 0],
#                        [0, 0, 1]])criterion = nn.CrossEntropyLoss()
loss = criterion(inputs, targets)
print(loss)

文章转载自:

http://VikAVjm4.nnprp.cn
http://oFdle77B.nnprp.cn
http://tdZcPgOh.nnprp.cn
http://qYFHnwHY.nnprp.cn
http://4XA59mFH.nnprp.cn
http://48Wz8oJ6.nnprp.cn
http://DqvF2buD.nnprp.cn
http://BjI2M1OP.nnprp.cn
http://PHqWgcXd.nnprp.cn
http://pioHDxuL.nnprp.cn
http://hkckr8vi.nnprp.cn
http://0lTLvSmb.nnprp.cn
http://Amcx9oQq.nnprp.cn
http://oG7QcHIO.nnprp.cn
http://CM3twDwM.nnprp.cn
http://orWelW93.nnprp.cn
http://HTW67Fak.nnprp.cn
http://soiWYeAq.nnprp.cn
http://9XiD1Vki.nnprp.cn
http://WEQxdQnm.nnprp.cn
http://Bn0V4y2W.nnprp.cn
http://zQMJh7yt.nnprp.cn
http://YnyxSzFh.nnprp.cn
http://apHjpM8E.nnprp.cn
http://RqOrVbv9.nnprp.cn
http://nlv6fKuM.nnprp.cn
http://JQMevlx8.nnprp.cn
http://i1Vd5x5G.nnprp.cn
http://Lbe7IE8C.nnprp.cn
http://5emWuAkV.nnprp.cn
http://www.dtcms.com/a/363610.html

相关文章:

  • 一文读懂 Python 【循环语句】:从基础到实战,效率提升指南
  • 零构建的快感!dagger.js 与 React Hooks 实现对比,谁更优雅?
  • 餐饮、跑腿、零售多场景下的同城外卖系统源码扩展方案
  • 基于高德地图实现后端传来两点坐标计算两点距离并显示
  • JDK16安装步骤及下载(附小白详细教程)
  • 【Spring Cloud微服务】9.一站式掌握 Seata:架构设计与 AT、TCC、Saga、XA 模式选型指南
  • Javascript》》JS》》ES6》 Map、Set、WeakSet、WeakMap
  • Java 技术支撑 AI 系统落地:从模型部署到安全合规的企业级解决方案(一)
  • SQL分类详解:掌握DQL、DML、DDL等数据库语言类型
  • Java-Spring入门指南(二)利用IDEA手把手教你如何创建第一个Spring系统
  • Python学习-day4
  • win32diskimager强行缩减TF卡镜像制作尺寸的方法
  • Zynq中级开发七项必修课-第四课:S_AXI_HP0 高速端口访问 DDR
  • 整理期初数据用到的EXCEL里面的函数操作
  • 2026届长亭科技秋招正式开始
  • 炫酷JavaScript鼠标跟随特效
  • Nano Banana 新玩法超惊艳!附教程案例提示词!
  • CMake构建学习笔记23-SQLite库的构建
  • SQL Server 数据库创建与用户权限绑定
  • 构建下一代智能金融基础设施
  • 网络编程 05:UDP 连接,UDP 与 TCP 的区别,实现 UDP 消息发送和接收,通过 URL 下载资源
  • 网络传输的实际收发情况及tcp、udp的区别
  • python 创建websocket教程
  • 异常处理小妙招——1.别把“数据库黑话”抛给用户:论异常封装的重要性
  • GitHub每日最火火火项目(9.2)
  • 使用谷歌ai models/gemini-2.5-flash-image-preview 生成图片
  • Python/JS/Go/Java同步学习(第一篇)格式化/隐藏参数一锅端 四语言输出流参数宇宙(附源码/截图/参数表/避坑指南/老板沉默术)
  • 下载速度爆表,全平台通用,免费拿走!
  • Linux中断实验
  • VibeVoice 部署全指南:Windows 下的挑战与完整解决方案