【Pytorch】分类问题交叉熵
1️⃣ 为什么分类问题不用 MSE(均方误差)?
表格
复制
场景 | 标签 | 预测 | MSE 损失 |
---|---|---|---|
分类 | [0,0,1] | [0.3,0.3,0.4] | (0.4-1)²=0.36 |
分类 | [0,0,1] | [0.1,0.2,0.7] | (0.7-1)²=0.09 |
看起来合理,但:
-
sigmoid/softmax 输出在 0/1 附近梯度几乎为零 → 梯度消失;
-
MSE 把“概率差”当数值差” → 不符合概率直觉;
-
收敛慢,还容易卡在鞍点。
2️⃣ 交叉熵(Cross Entropy)思想
一句话:衡量「真实分布 p」与「预测分布 q」之间的信息差距。
公式(离散版):
CE(p,q) = − Σ p(i) log q(i)
-
p 是 one-hot 标签(比如 [0,0,1])
-
q 是 softmax 输出(比如 [0.1,0.2,0.7])
因为 p 只有一个 1,其余为 0,所以求和只剩一项:
CE = − log q(正确类)
直观:
-
若 q(正确类)=0.7 → CE ≈ 0.36
-
若 q(正确类)=0.98 → CE ≈ 0.02
预测越准,损失越小,且梯度不饱和(后面会算给你看)。
3️⃣ 手推一条二分类例子
表格
复制
样本 | 真实 y | 预测 p |
---|---|---|
猫 | 1 | 0.8 |
狗 | 0 | 0.1 |
二元交叉熵(BCE):
L = − [y log p + (1−y) log(1−p)]
样本 1(猫):
L = − [1·log0.8 + 0·log0.2] = −log0.8 ≈ 0.223
样本 2(狗):
L = − [0·log0.1 + 1·log0.9] = −log0.9 ≈ 0.105
平均损失 ≈ 0.164,预测越离谱,值越大。
4️⃣ PyTorch 一行代码算完
Python
import torch.nn.functional as Flogits = torch.tensor([[1.0, 2.0, 0.5]]) # 模型输出(未归一化)
target = torch.tensor([1]) # 正确类别索引loss = F.cross_entropy(logits, target)
print(loss.item()) # tensor(0.8309)
内部干了啥:
-
softmax(logits)
→ 概率 -
log(softmax)
→ 对数概率 -
取
-log q(正确类)
→ 损失
5️⃣ 数值稳定性技巧
不要手写:
Python
复制
prob = F.softmax(logits)
log_prob = torch.log(prob)
loss = F.nll_loss(log_prob, target)
推荐直接用:
Python
loss = F.cross_entropy(logits, target)
内部实现 log-sum-exp 技巧,避免 log(softmax)
造成数值溢出。
6️⃣ 对比实验(直观感受)
表格
复制
方法 | 损失曲线 | 梯度大小 | 收敛速度 |
---|---|---|---|
MSE | 平坦区早 | 极小 | 慢 |
CE | 无平坦区 | 稳定 | 快 |
7️⃣ 小结口诀(背下来)
分类用 CE,回归用 MSE
CE = −log q(对类)
PyTorch:F.cross_entropy(logits, target)
别手写 softmax+log!
8️⃣ 课后 5 分钟动手
-
用
F.cross_entropy
算一条三分类样本。 -
把
logits
乘 10 再算一次,观察损失变化。 -
对比
F.mse_loss
与F.cross_entropy
的梯度大小(.grad
)。