关键词解释:Focal Loss解决类别极度不平衡问题而设计的损失函数
简单来说就是数据集中可能有类别不平衡的问题,Focal Loss自动降低容易分类样本的权重更关注少样本的数据
Focal Loss 是一种专门为解决类别极度不平衡问题而设计的损失函数,由何恺明(Kaiming He)等人在 2017 年的论文《Focal Loss for Dense Object Detection》(用于 RetinaNet)中提出。它在目标检测、医学图像分割、罕见事件预测等场景中非常有效。
下面从动机、公式推导、性质、代码实现和使用建议五个方面详细讲解。
一、动机:为什么需要 Focal Loss?
在类别极度不平衡的数据集中(例如:99% 负样本,1% 正样本),传统交叉熵损失(Cross-Entropy, CE)存在以下问题:
- 大量简单负样本(easy negatives)主导了梯度更新,导致模型“只学会预测负类”。
- 虽然正样本少,但它们往往更重要(如癌症检测中的阳性病例)。
- 模型很快对多数类过拟合,而对少数类学习不足。
🎯 核心思想:让损失函数自动降低容易分类样本的权重,使模型更关注难分类样本(尤其是少数类中的难例)。
二、Focal Loss 公式详解
1. 回顾标准交叉熵(Binary CE)
对于二分类,预测概率为 ,真实标签
,定义:
则交叉熵损失为:
2. 引入 Focal Loss
Focal Loss 在 CE 基础上引入两个关键因子:
✅ (1) 调制因子(Modulating Factor):
- 当样本被正确分类时,
,则
,损失被大幅降低。
- 当样本难以分类时,
,则
,损失几乎不变。
是聚焦参数(focusing parameter):
→ 退化为标准 CE。
越大,对易分样本的抑制越强(通常取 1~5)。
✅ (2) 平衡因子(Balancing Factor):
- 用于调整正负样本的权重,类似 class weight。
通常设为少数类的权重(如
表示更关注正类)。
3. 最终公式
其中:
:如上定义的“目标类预测概率”
:类别权重(如正类用
,负类用
)
:聚焦参数(控制难易样本的权重差异)
三、Focal Loss 的性质
| 性质 | 说明 |
|---|---|
| 自适应加权 | 自动降低高置信度(易分)样本的损失贡献 |
| 缓解类别不平衡 | 无需过采样/欠采样,直接在损失层面处理 |
| 关注难例 | 类似于在线难例挖掘(OHEM),但更平滑、可微 |
| 超参数敏感 |
可视化对比(CE vs Focal Loss)
假设真实标签为正类(),则
:
- 当
(易分正样本):
- 当
(难分正样本):
👉 易分样本损失被压低两个数量级!
四、PyTorch 实现
方法 1:手动实现(推荐理解)
import torch
import torch.nn.functional as Fdef focal_loss(inputs, targets, alpha=0.25, gamma=2.0, reduction='mean'):"""inputs: [N, C] 或 [N](logits)targets: [N](long tensor,类别索引)"""# 如果 inputs 是 logits,先转为概率if inputs.dim() > 1:# 多分类ce_loss = F.cross_entropy(inputs, targets, reduction='none')p = torch.exp(-ce_loss) # p_t = exp(-CE) = softmax prob of true classfocal_weight = alpha * (1 - p) ** gammaloss = focal_weight * ce_losselse:# 二分类(inputs 是 logit)p = torch.sigmoid(inputs)targets = targets.float()p_t = p * targets + (1 - p) * (1 - targets)alpha_t = alpha * targets + (1 - alpha) * (1 - targets)ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')focal_weight = alpha_t * (1 - p_t) ** gammaloss = focal_weight * ce_lossif reduction == 'mean':return loss.mean()elif reduction == 'sum':return loss.sum()else:return loss
方法 2:使用 torchvision 或第三方库
# 使用 torchvision.ops.sigmoid_focal_loss(仅支持二分类/多标签)
from torchvision.ops import sigmoid_focal_loss# inputs: [N, C] logits
# targets: [N, C] one-hot 或 0/1
loss = sigmoid_focal_loss(inputs, targets, alpha=0.25, gamma=2.0, reduction='mean')
五、使用建议
| 场景 | 建议 |
|---|---|
| 目标检测(如 RetinaNet) | 默认使用 Focal Loss( |
| 医学图像分割 | 正样本(病灶)极少,强烈推荐 |
| 欺诈检测 / 故障预测 | 正例 < 1%,Focal Loss 效果显著 |
| 多分类不平衡 | 可扩展为每个类别设置不同 |
| 调参技巧 | 先固定 |
⚠️ 注意:Focal Loss 不能完全替代数据重采样。在极端不平衡(如 1:10000)时,仍需结合过采样或阈值调整。
六、扩展:Focal Loss 的变体
- Class-balanced Focal Loss:根据有效样本数动态调整
- GHM (Gradient Harmonizing Mechanism):另一种难例挖掘思路
- PolyLoss:将 CE 与多项式项结合,效果有时优于 Focal Loss
总结
Focal Loss 的核心价值在于:
“让模型把精力花在刀刃上”——自动忽略大量简单样本,聚焦于难分类的、尤其是少数类中的关键样本。
它不仅是损失函数的改进,更是一种自适应难例挖掘机制,在不平衡学习中具有里程碑意义。
