关于FocalLoss 损失函数
🎯 Focal Loss 深度讲解
🧠 一、直观动机:为什么要 Focal Loss?
传统的 CrossEntropy Loss(交叉熵) 是分类任务的默认选择,它关注的是模型预测正确标签的概率(越高越好)。公式如下:
CE ( p t ) = − log ( p t ) \text{CE}(p_t) = -\log(p_t) CE(pt)=−log(pt)
- 如果预测得好( p t → 1 p_t \to 1 pt→1),损失就很小
- 如果预测得差( p t → 0 p_t \to 0 pt→0),损失就很大
问题来了:
在很多实际场景里,比如目标检测、语义分割、医学图像,存在严重的类别不平衡现象:
- 正类样本(比如肿瘤区域、小目标)非常稀少
- 背景类(负类)大量存在,且模型很容易把它分类正确
这导致了什么?
✅ 模型很快学会把大多数都预测成“背景”就能拿到很小的 loss
❌ 正类样本虽然难,但数量少,对 loss 的贡献低,模型懒得理它!
我们要的是什么?
专注于难样本、忽略容易的样本!
这正是 Focal Loss 的使命。
🧮 二、Focal Loss 的数学表达式和直觉解释
⭐ 标准交叉熵(二分类):
CE ( p t ) = − log ( p t ) \text{CE}(p_t) = -\log(p_t) CE(pt)=−log(pt)
其中:
- p t = p p_t = p pt=p if label is 1
- p t = 1 − p p_t = 1 - p pt=1−p if label is 0
⭐ Focal Loss:
FL ( p t ) = − α t ( 1 − p t ) γ log ( p t ) \text{FL}(p_t) = -\alpha_t (1 - p_t)^\gamma \log(p_t) FL(pt)=−αt(1−pt)γlog(pt)
参数含义:
- p t p_t pt:模型对正确类的预测概率
- α t \alpha_t αt:类别平衡系数
- γ \gamma γ:focusing parameter,用来衰减简单样本的贡献
🔍 三个部分的直觉理解:
-
log ( p t ) \log(p_t) log(pt):就是交叉熵,衡量你预测的好不好
-
( 1 − p t ) γ (1 - p_t)^\gamma (1−pt)γ:
- 当 p t p_t pt 很大(预测正确),这个因子趋近于 0 → loss 减小
- 当
p
t
p_t
pt 很小(预测错误),这个因子接近 1 → 保留大 loss
→ 关注难样本,忽略易样本!
-
α t \alpha_t αt:用于调节类别之间的权重(比如正样本少,那正样本的 α \alpha α 设大点)
🔢 三、PyTorch 实现(二分类)
import torch
import torch.nn as nn
import torch.nn.functional as F
class FocalLoss(nn.Module):
def __init__(self, alpha=0.25, gamma=2.0, reduction='mean'):
super(FocalLoss, self).__init__()
self.alpha = alpha
self.gamma = gamma
self.reduction = reduction
def forward(self, inputs, targets):
# inputs: [B, 1] or [B], logits
# targets: [B], 0 or 1
BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
probs = torch.sigmoid(inputs)
pt = torch.where(targets == 1, probs, 1 - probs) # pt = p_t
focal_weight = (1 - pt) ** self.gamma
alpha_t = torch.where(targets == 1, self.alpha, 1 - self.alpha)
loss = alpha_t * focal_weight * BCE_loss
if self.reduction == 'mean':
return loss.mean()
elif self.reduction == 'sum':
return loss.sum()
else:
return loss
注意:这里
inputs
是 logits,不是 sigmoid 之后的概率!
🧪 四、实际应用场景
应用场景 | 类别不平衡? | 是否适合用 Focal Loss? |
---|---|---|
目标检测(小目标) | ✅严重不平衡 | ✅ 推荐 |
医学图像分割(肿瘤) | ✅非常不平衡 | ✅ 强烈推荐 |
二分类、异常检测任务 | ✅有时不平衡 | ✅ 可尝试 |
多分类图像识别 | ❌样本相对平衡 | ❌ 不一定需要 |
🧩 五、和其他 Loss 的对比
损失函数 | 适用场景 | 特点 |
---|---|---|
CrossEntropy | 通用 | 简单直接 |
Weighted CE | 类别不平衡 | 人为加权正负类,但不能区分难易样本 |
Dice Loss | 图像分割 | 关注前景 IoU,但不处理类别不平衡 |
Focal Loss | 类别极度不平衡 | 聚焦难样本,动态调节 loss 大小 |
🎯 六、Focal Loss 常见问题
1. γ \gamma γ 设多少合适?
- 一般设为 2,试试 1~5 范围微调
- γ \gamma γ 越大,越“懒得理”那些简单样本
2. α \alpha α 必须设吗?
- 如果类别比例非常悬殊,比如 1:100,那建议设置
- 常用设置:正类 0.25,负类 0.75
3. 可不可以和 Dice Loss 混合?
- 可以,特别是在图像分割中常见组合:
loss = focal_loss + dice_loss
✅ 七、一句话终结
Focal Loss = 聚焦困难样本的动态加权交叉熵损失,专为不平衡场景设计,用得对了就是神器!