如何安全的计算softmax?
Softmax 函数广泛用于分类任务的输出层。但直接计算容易导致 数值不稳定 —— 轻则结果错误,重则出现 inf
或 nan
。
🚨 问题:为什么原始 Softmax 不安全?
标准 Softmax 公式:
softmax(xi)=exi∑j=1nexjsoftmax(xi)=∑j=1nexjexi
❌ 危险场景举例:
假设输入向量中有大数:
- x=[1000,2000,3000]
- 计算 e3000→ 远超浮点数表示范围(溢出为
inf
) - 结果变成:
inf / inf = nan
→ 模型崩溃!
即使不溢出,小数也可能被“淹没”,造成精度丢失。
✅ 解法:使用 Log-Sum-Exp 技巧 + 减去最大值
✅ 安全版 Softmax(Numerically Stable)
import numpy as npdef stable_softmax(x):# Step 1: 减去最大值(关键!)z = x - np.max(x, axis=-1, keepdims=True)# Step 2: 计算 exp(z),此时所有元素 ≤ 0,不会溢出exp_z = np.exp(z)# Step 3: 归一化return exp_z / np.sum(exp_z, axis=-1, keepdims=True)
🔍 原理说明
我们利用一个数学恒等变换:
softmax(xi)=exi∑jexj=exi−c∑jexj−c,对任意 csoftmax(xi)=∑jexjexi=∑jexj−cexi−c,对任意 c
选择 c=max(x1,x2,...,xn)c=max(x1,x2,...,xn),则:
- 所有 xi−c≤0xi−c≤0
- exi−c≤1exi−c≤1
- 最大指数项为 e0=1e0=1,避免了上溢
✅ 这是几乎所有深度学习框架(PyTorch、TensorFlow)内部实现的方式。
原则 | 说明 |
---|---|
1. 减去最大值 | xi←xi−max(x),防止指数溢出 |
2. 使用双精度 | 用 float64 替代 float32 提高稳定性 |
3. 在 log 空间操作 | 如需取对数,优先使用 logsumexp 技巧 |