当前位置: 首页 > news >正文

关于FocalLoss 损失函数

🎯 Focal Loss 深度讲解


🧠 一、直观动机:为什么要 Focal Loss?

传统的 CrossEntropy Loss(交叉熵) 是分类任务的默认选择,它关注的是模型预测正确标签的概率(越高越好)。公式如下:

CE ( p t ) = − log ⁡ ( p t ) \text{CE}(p_t) = -\log(p_t) CE(pt)=log(pt)

  • 如果预测得好( p t → 1 p_t \to 1 pt1),损失就很小
  • 如果预测得差( p t → 0 p_t \to 0 pt0),损失就很大

问题来了:
在很多实际场景里,比如目标检测、语义分割、医学图像,存在严重的类别不平衡现象:

  • 正类样本(比如肿瘤区域、小目标)非常稀少
  • 背景类(负类)大量存在,且模型很容易把它分类正确

这导致了什么?

✅ 模型很快学会把大多数都预测成“背景”就能拿到很小的 loss
❌ 正类样本虽然难,但数量少,对 loss 的贡献低,模型懒得理它!

我们要的是什么?

专注于难样本、忽略容易的样本!

这正是 Focal Loss 的使命。


🧮 二、Focal Loss 的数学表达式和直觉解释

⭐ 标准交叉熵(二分类):

CE ( p t ) = − log ⁡ ( p t ) \text{CE}(p_t) = -\log(p_t) CE(pt)=log(pt)

其中:

  • p t = p p_t = p pt=p if label is 1
  • p t = 1 − p p_t = 1 - p pt=1p if label is 0

⭐ Focal Loss:

FL ( p t ) = − α t ( 1 − p t ) γ log ⁡ ( p t ) \text{FL}(p_t) = -\alpha_t (1 - p_t)^\gamma \log(p_t) FL(pt)=αt(1pt)γlog(pt)

参数含义:

  • p t p_t pt:模型对正确类的预测概率
  • α t \alpha_t αt:类别平衡系数
  • γ \gamma γ:focusing parameter,用来衰减简单样本的贡献

🔍 三个部分的直觉理解:

  1. log ⁡ ( p t ) \log(p_t) log(pt):就是交叉熵,衡量你预测的好不好

  2. ( 1 − p t ) γ (1 - p_t)^\gamma (1pt)γ

    • p t p_t pt 很大(预测正确),这个因子趋近于 0 → loss 减小
    • p t p_t pt 很小(预测错误),这个因子接近 1 → 保留大 loss
      关注难样本,忽略易样本!
  3. α t \alpha_t αt:用于调节类别之间的权重(比如正样本少,那正样本的 α \alpha α 设大点)


🔢 三、PyTorch 实现(二分类)

import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2.0, reduction='mean'):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        # inputs: [B, 1] or [B], logits
        # targets: [B], 0 or 1

        BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        probs = torch.sigmoid(inputs)
        pt = torch.where(targets == 1, probs, 1 - probs)  # pt = p_t

        focal_weight = (1 - pt) ** self.gamma
        alpha_t = torch.where(targets == 1, self.alpha, 1 - self.alpha)

        loss = alpha_t * focal_weight * BCE_loss

        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        else:
            return loss

注意:这里 inputs 是 logits,不是 sigmoid 之后的概率!


🧪 四、实际应用场景

应用场景类别不平衡?是否适合用 Focal Loss?
目标检测(小目标)✅严重不平衡✅ 推荐
医学图像分割(肿瘤)✅非常不平衡✅ 强烈推荐
二分类、异常检测任务✅有时不平衡✅ 可尝试
多分类图像识别❌样本相对平衡❌ 不一定需要

🧩 五、和其他 Loss 的对比

损失函数适用场景特点
CrossEntropy通用简单直接
Weighted CE类别不平衡人为加权正负类,但不能区分难易样本
Dice Loss图像分割关注前景 IoU,但不处理类别不平衡
Focal Loss类别极度不平衡聚焦难样本,动态调节 loss 大小

🎯 六、Focal Loss 常见问题

1. γ \gamma γ 设多少合适?

  • 一般设为 2,试试 1~5 范围微调
  • γ \gamma γ 越大,越“懒得理”那些简单样本

2. α \alpha α 必须设吗?

  • 如果类别比例非常悬殊,比如 1:100,那建议设置
  • 常用设置:正类 0.25,负类 0.75

3. 可不可以和 Dice Loss 混合?

  • 可以,特别是在图像分割中常见组合:
loss = focal_loss + dice_loss

✅ 七、一句话终结

Focal Loss = 聚焦困难样本的动态加权交叉熵损失,专为不平衡场景设计,用得对了就是神器!

相关文章:

  • 【C++算法】54.链表_合并 K 个升序链表
  • Ansible:role企业级实战
  • 4-6记录(B树)
  • 使用ZYNQ芯片和LVGL框架实现用户高刷新UI设计系列教程(第七讲)
  • 【React】副作用 setState执行流程 内置钩子(Effect Callback Reducer)React.memo
  • 从 STP 到 RSTP 再到 MSTP:网络生成树协议的工作机制与发展
  • Docker部署.NetCore8项目
  • 【Axure视频教程】中继器表格轮播含暂停效果
  • 蓝桥杯真题:数字串个数
  • 【今日三题】小乐乐改数字 (模拟) / 十字爆破 (预处理+模拟) / 比那名居的桃子 (滑窗 / 前缀和)
  • Spring Security6 从源码慢速开始
  • 系统思考—提升解决动态性复杂问题能力
  • C++对象生命周期管理:从构造到析构的完整指南
  • Unity Addressables资源生命周期自动化监控技术详解
  • 【智能指针】—— 我与C++的不解之缘(三十三)
  • 02-redis-源码下载
  • mysql-锁的算法(记录锁、间隙锁、临键锁)
  • 【电商】基于LangChain框架将多模态大模型连接数据库实现精准识别
  • 基于CNN-GRU的深度Q网络(Deep Q-Network,DQN)求解移动机器人路径规划,MATLAB代码
  • 【js面试题】new操作做了什么?