交叉熵损失
交叉熵损失函数原理与使用
2. 二分类问题
- 网络结构
最后一层使用nn.Linear
+Sigmoid
- PyTorch 实现
criterion = nn.BCEWithLogitsLoss() # 自动处理 sigmoid loss = criterion(predictions, labels.float()) # labels 为 0/1
3. 多标签分类
- 网络结构
每个类别独立使用Sigmoid
激活 - PyTorch 实现
criterion = nn.BCEWithLogitsLoss() loss = criterion(predictions, labels) # labels 为多热编码
⚠️ 关键注意事项
-
数值稳定性
- 框架自动处理
log(0)
问题(如添加微小值eps=1e-8
) - 优先使用
LogSoftmax + NLLLoss
或框架内置函数(如CrossEntropyLoss
)
- 框架自动处理
-
类别不平衡
- 通过
weight
参数调整类别权重weights = torch.tensor([0.1, 0.9]) # 类别权重 criterion = nn.CrossEntropyLoss(weight=weights)
- 通过
-
输入格式要求
函数 预测值形状 标签格式 CrossEntropyLoss
(N, C)
类别索引 (N,)
BCEWithLogitsLoss
(N, *)
0/1 矩阵 (N,*)
🔄 与 MSE 的对比
特性 | 交叉熵损失 | 均方误差(MSE) |
---|---|---|
适用场景 | 分类任务 | 回归任务 |
梯度特性 | 错误预测时梯度大 | 预测接近极值时梯度小 |
收敛速度 | 更快(分类任务) | 较慢 |
概率解释性 | 直接优化概率分布差异 | 无明确概率意义 |
📈 代码示例(PyTorch)
单标签分类
import torch
import torch.nn as nn
# 定义模型输出和标签
logits = torch.randn(3, 5) # 3个样本,5个类别
labels = torch.tensor([2, 0, 4]) # 真实类别索引
# 计算损失
criterion = nn.CrossEntropyLoss()
loss = criterion(logits, labels)
print(loss.item())
二分类
predictions = torch.randn(3) # 3个样本的二分类 logits
labels = torch.tensor([1, 0, 1]) # 真实标签
criterion = nn.BCEWithLogitsLoss()
loss = criterion(predictions, labels.float())
print(loss.item())
📌 总结
- 核心价值:通过量化概率分布差异,驱动模型输出高置信度的正确预测
- 最佳实践:优先使用框架内置函数,正确处理类别不平衡和数值稳定性
- 延伸扩展:结合 Focal Loss 处理难样本,或与 Label Smoothing 配合提升泛化性