当前位置: 首页 > news >正文

BCEWithLogitsLoss

在深度学习中,BCEWithLogitsLoss 是 PyTorch 等框架中用于二分类任务的核心损失函数,其本质是“Sigmoid 激活函数 + BCELoss(二元交叉熵损失)”的结合体。它解决了单独使用 Sigmoid + BCELoss 可能出现的数值不稳定问题,同时简化了代码实现,在目标检测(如 YOLOv5 的置信度/类别损失)、图像分割、多标签分类等场景中广泛应用。

一、核心定义:为什么需要 BCEWithLogitsLoss?

在理解 BCEWithLogitsLoss 前,先回顾两个基础组件:

  1. Sigmoid 激活函数:将模型输出的“原始 logits(无界实数)”压缩到 (0,1) 区间,得到概率值,公式为:
    σ(x)=11+e−x\sigma(x) = \frac{1}{1 + e^{-x}}σ(x)=1+ex1
    常用于二分类任务中,输出“样本属于正类的概率”。

  2. BCELoss(二元交叉熵损失):基于预测概率与真实标签(0 或 1)计算损失,衡量两者的“概率分布差异”,公式为:
    BCELoss(p,y)=−[y⋅log⁡(p)+(1−y)⋅log⁡(1−p)]\text{BCELoss}(p, y) = -[y \cdot \log(p) + (1-y) \cdot \log(1-p)]BCELoss(p,y)=[ylog(p)+(1y)log(1p)]
    其中,ppp 是 Sigmoid 输出的概率,yyy 是真实标签(0 表示负类,1 表示正类)。

单独使用 Sigmoid + BCELoss 的问题

当模型输出的 logits 绝对值较大时(如 x=10x=10x=10x=−10x=-10x=10):

  • Sigmoid 输出会趋近于 1 或 0,此时 log⁡(p)\log(p)log(p)log⁡(1−p)\log(1-p)log(1p) 会趋近于 −∞-\infty,导致数值下溢(浮点数精度丢失);
  • 反向传播时,梯度可能因 Sigmoid 的导数趋近于 0 而“消失”,模型难以优化。

BCEWithLogitsLoss 通过“将 Sigmoid 激活与 BCELoss 融合”,利用 LogSumExp 技巧避免了上述问题,同时保持计算效率。

二、数学原理:BCEWithLogitsLoss 的公式推导

BCEWithLogitsLoss 直接接收模型输出的 logits(未经过 Sigmoid 激活的原始值) 作为输入,内部先计算 Sigmoid,再计算交叉熵,但通过数学变形优化了数值稳定性。

1. 前向计算公式

设模型输出的 logits 为 xxx,真实标签为 yyyy∈{0,1}y \in \{0,1\}y{0,1}),则 BCEWithLogitsLoss 的损失公式为:
BCEWithLogitsLoss(x,y)=max⁡(x,0)−x⋅y+log⁡(1+e−∣x∣)\text{BCEWithLogitsLoss}(x, y) = \max(x, 0) - x \cdot y + \log(1 + e^{-|x|})BCEWithLogitsLoss(x,y)=max(x,0)xy+log(1+ex)

公式推导过程:

p=σ(x)=11+e−xp = \sigma(x) = \frac{1}{1 + e^{-x}}p=σ(x)=1+ex1 代入 BCELoss 公式:
BCELoss(σ(x),y)=−[y⋅log⁡(σ(x))+(1−y)⋅log⁡(1−σ(x))]\text{BCELoss}(\sigma(x), y) = -[y \cdot \log(\sigma(x)) + (1-y) \cdot \log(1-\sigma(x))]BCELoss(σ(x),y)=[ylog(σ(x))+(1y)log(1σ(x))]

通过对数和指数的数学变形(利用 log⁡(σ(x))=−log⁡(1+e−x)\log(\sigma(x)) = -\log(1 + e^{-x})log(σ(x))=log(1+ex)log⁡(1−σ(x))=−log⁡(1+ex)\log(1-\sigma(x)) = -\log(1 + e^{x})log(1σ(x))=log(1+ex)),最终可化简为上述带 max⁡(x,0)\max(x,0)max(x,0)log⁡(1+e−∣x∣)\log(1 + e^{-|x|})log(1+ex) 的形式。这种形式避免了直接计算 log⁡(σ(x))\log(\sigma(x))log(σ(x))log⁡(1−σ(x))\log(1-\sigma(x))log(1σ(x)),从根本上解决了数值下溢问题。

2. 梯度计算优势

由于损失函数与 Sigmoid 融合,梯度计算也被优化:
xxx 的梯度为 σ(x)−y\sigma(x) - yσ(x)y(与“单独 Sigmoid + BCELoss”的梯度结果一致),但避免了中间步骤的梯度消失,优化更稳定。

三、适用场景

BCEWithLogitsLoss 核心用于二分类任务,但通过“多标签扩展”可支持更复杂场景,典型应用包括:

应用场景具体说明
标准二分类样本仅属于“正类”或“负类”,如“图像是否包含猫”“邮件是否为垃圾邮件”。
多标签分类样本可同时属于多个类别,如“一张图片可能同时包含猫、狗、鸟”,每个类别独立判断“是/否”。
YOLOv5 置信度损失预测“锚框内是否存在目标”(二分类:有目标=1,无目标=0)。
YOLOv5 类别损失(多类)对多类别任务,采用“One-Hot 编码 + 每个类别独立二分类”,用 BCEWithLogitsLoss 计算每个类别的损失。
图像分割(语义/实例)预测“每个像素是否属于目标类别”(二分类,如背景=0,前景=1)。

四、PyTorch 中的关键参数与使用示例

在 PyTorch 中,torch.nn.BCEWithLogitsLoss 提供了灵活的参数配置,以适配不同场景的需求。

1. 核心参数

参数名类型作用
weightTensor(可选)类别权重,形状为 (C,)(C 为类别数),用于解决类别不平衡(如正样本少,给正样本更高权重)。
pos_weightTensor(可选)正样本权重,形状为 (C,),仅对正样本(y=1y=1y=1)的损失加权,进一步调整正样本的损失贡献。
reductionstr损失聚合方式:
- 'mean'(默认):返回所有样本的平均损失;
- 'sum':返回所有样本的总损失;
- 'none':返回每个样本的损失(保留原始形状)。

2. 使用示例(以 YOLOv5 置信度损失为例)

假设在 YOLOv5 中,需要计算“锚框是否有目标”的置信度损失:

  • 模型输出的置信度 logits 形状为 (B, 25200)(B 为批量大小,25200 为总锚框数);
  • 真实标签形状为 (B, 25200)(有目标=1,无目标=0);
  • 由于数据集中“无目标锚框”远多于“有目标锚框”(类别不平衡),可通过 pos_weight 提升正样本损失权重。
import torch
import torch.nn as nn# 1. 初始化 BCEWithLogitsLoss(设置正样本权重,缓解类别不平衡)
pos_weight = torch.tensor([2.0])  # 正样本损失权重为2.0(根据数据调整)
bce_loss = nn.BCEWithLogitsLoss(pos_weight=pos_weight, reduction='mean')# 2. 模拟输入(YOLOv5 置信度 logits 和真实标签)
batch_size = 2
num_anchors = 25200  # 总锚框数(640×640输入时)
pred_logits = torch.randn(batch_size, num_anchors)  # 模型输出的置信度 logits(无界)
true_labels = torch.randint(0, 2, (batch_size, num_anchors)).float()  # 真实标签(0或1)# 3. 计算损失
loss = bce_loss(pred_logits, true_labels)
print(f"置信度损失值: {loss.item()}")  # 输出损失数值(如0.65)

五、与 BCELoss 的对比

BCEWithLogitsLoss 与 BCELoss 均用于二分类,但核心差异在于“是否包含 Sigmoid 激活”,具体对比如下:

对比维度BCEWithLogitsLossBCELoss
输入要求接收未激活的 logits(无界实数)接收已激活的概率(需在 (0,1) 区间)
数值稳定性高(融合 Sigmoid,避免下溢)低(直接计算 log§,易下溢)
代码复杂度低(一步完成激活+损失计算)高(需手动添加 Sigmoid 层)
适用场景二分类/多标签分类(推荐优先使用)仅在已获取概率输出时使用(如特殊定制场景)

结论:在绝大多数二分类任务中,优先使用 BCEWithLogitsLoss,而非 Sigmoid + BCELoss,以保证数值稳定和计算效率。

六、在 YOLOv5 中的具体应用

在 YOLOv5 中,BCEWithLogitsLoss 主要用于两个损失计算场景:

  1. 置信度损失:判断每个锚框“是否包含目标”(二分类),通过 pos_weight 缓解“无目标锚框过多”的类别不平衡问题;
  2. 类别损失:对多类别任务(如 COCO 80 类),采用“One-Hot 编码 + 每个类别独立二分类”,每个类别用 BCEWithLogitsLoss 计算损失,最终求和得到总类别损失。

这种设计既适配了 YOLOv5 的“锚框预测”逻辑,又能高效处理类别不平衡问题,是模型训练稳定收敛的关键之一。

总结

BCEWithLogitsLoss 是二分类任务的“最优解”之一,其核心价值在于融合激活与损失计算、优化数值稳定性,同时通过 weightpos_weight 适配类别不平衡场景。在目标检测(如 YOLOv5)、多标签分类、图像分割等任务中,它是实现高精度模型的重要工具,理解其原理和参数配置对模型调优至关重要。

通俗讲解:

一张表格看懂区别

类型标签形状标签内容输出激活函数损失函数例子:一张图里有猫、狗
多类单标签
(multi-class, single-label)
1 个数0~C-1 中的 一个下标SoftmaxCrossEntropyLoss只能选 1 个:cat=1, dog=2 → 标签=1 或 2
多标签
(multi-label)
C 个 0/1每个类别独立 0/1SigmoidBCEWithLogitsLoss可以同时有:cat=1, dog=1, 其余=0

一句话

“多标签”就是:一张图可以被打上多个标签,每个标签单独判断「有/没有」——于是每个输出位独立做二分类,用 BCEWithLogitsLoss。


生活例子

假设只有 4 类:猫、狗、汽车、人。

图片多类单标签标签多标签标签(4 位)
只有猫0[1, 0, 0, 0]
猫 + 狗不允许[1, 1, 0, 0]
猫 + 狗 + 人不允许[1, 1, 0, 1]

多标签允许“同时 1”,所以叫 “每位独立 0/1”


代码级对比(4 类)

多类单标签

logits = model(img)          # shape (B, 4)
loss = nn.CrossEntropyLoss()(logits, label)  # label 是整数 0~3

多标签

logits = model(img)          # shape (B, 4)
target = torch.tensor([[1, 1, 0, 0]], dtype=torch.float)  # 4 位 0/1
loss = nn.BCEWithLogitsLoss()(logits, target)  # 对 4 位分别算 BCE

再浓缩成一句

“每位独立 0/1” = 一张图可以同时拥有多个类别,每个类别单独用 sigmoid 变成概率,再单独用 BCE 算损失——这就是 BCEWithLogitsLoss 干的事。

http://www.dtcms.com/a/419772.html

相关文章:

  • 在线设计网站大全网站建设方案推销
  • CUDA框架
  • 辽阳专业建设网站公司wordpress rss 爬取
  • TypeScript 简介与项目中配置
  • 南宁seo建站seo网站优化排名
  • 【每日一问】老化测试有什么作用?
  • 广州信科做网站dede 门户网站
  • 【JDBC】系列文章第一章,怎么在idea中连接数据库,并操作插入数据?
  • 企业的网站建设朔州网站建设收费
  • 外贸上哪个网站开发客户网站建设费可分摊几年
  • 8. mutable 的用法
  • 做网站 php j2ee做网站投注员挣钱吗
  • 试玩平台网站开发录入客户信息的软件
  • 网站建设谈单情景对话wordpress外网访问错误
  • 怎么学网站开发海阳网站制作
  • 肥东建设局网站家具设计师常去的网站
  • 查网站开通时间网站设计 职业
  • 重庆网站优化搜索引擎优化包括( )方面的优化
  • 助力工业转型升级 金士顿工博会大放异彩
  • 智慧校园智能一卡通管理系统的完整架构与功能模块设计,结合技术实现与应用场景,分为核心平台、功能子系统及扩展应用三部分
  • @[TOC](【笔试强训】Day02) # 1. ⽜⽜的快递(模拟) [题⽬链接: BC64 ⽜⽜的快递]
  • 广州魔站建站3d演示中国空间站建造历程
  • MySQL数据库——13.2.2 JDBC编程-鑫哥演示使用过程
  • AWS实战:轻松创建弹性IP,实现固定公网IP地址
  • 网站制作谁家好vps可以做wordpress和ssr
  • 全能企业网站管理系统Wordpress百万访问优化
  • 东南亚日本股票数据API对接文档
  • 吴*波频道推荐书单
  • 关于排查问题的总结
  • 优雅动听的歌曲之一-小城画师