注意力机制中除以Dk的方差归一化
要理解注意力机制中除以 dk\sqrt{d_k}dk 的数学本质,需从点积的方差问题切入 —— 这是避免 softmax 输出极端化、保证模型可学习性的关键。下面将分 “核心问题推导”“为什么是 dk\sqrt{d_k}dk”“为什么不是 dkd_kdk 或 dk2d_k^2dk2” 三部分解释,并结合具体数学例子验证。
一、核心问题:未缩放的点积会导致方差爆炸
注意力机制(以缩放点积注意力为例)的核心计算是查询向量 q\mathbf{q}q(维度 dkd_kdk)与键向量 k\mathbf{k}k(维度 dkd_kdk)的点积:
q⋅k=∑i=1dkqiki\mathbf{q} \cdot \mathbf{k} = \sum_{i=1}^{d_k} q_i k_iq⋅k=∑i=1dkqiki
为了分析其数值特性,我们做合理假设(符合模型初始化惯例):
-
q\mathbf{q}q 和 k\mathbf{k}k 的每个元素 qi,kiq_i, k_iqi,ki 是独立同分布(i.i.d.) 的随机变量;
-
均值为 000(消除偏移影响),方差为 111(标准化初始化,如 Xavier 初始化);
-
q\mathbf{q}q 和 k\mathbf{k}k 相互独立(查询与键来自不同分支)。
1.1 点积的方差计算
根据方差的性质:
-
独立变量之和的方差 = 各变量方差之和(即 Var(a+b)=Var(a)+Var(b)\text{Var}(a+b) = \text{Var}(a) + \text{Var}(b)Var(a+b)=Var(a)+Var(b),当 a,ba,ba,b 独立时);
-
独立变量乘积的方差:Var(qiki)=E[(qiki)2]−(E[qiki])2\text{Var}(q_i k_i) = \mathbb{E}[(q_i k_i)^2] - (\mathbb{E}[q_i k_i])^2Var(qiki)=E[(qiki)2]−(E[qiki])2。由于 E[qi]=E[ki]=0\mathbb{E}[q_i] = \mathbb{E}[k_i] = 0E[qi]=E[ki]=0,且 qi,kiq_i, k_iqi,ki 独立,故 E[qiki]=E[qi]E[ki]=0\mathbb{E}[q_i k_i] = \mathbb{E}[q_i] \mathbb{E}[k_i] = 0E[qiki]=E[qi]E[ki]=0,且 E[(qiki)2]=E[qi2]E[ki2]\mathbb{E}[(q_i k_i)^2] = \mathbb{E}[q_i^2] \mathbb{E}[k_i^2]E[(qiki)2]=E[qi2]E[ki2]。
又因为 Var(qi)=E[qi2]−(E[qi])2=E[qi2]=1\text{Var}(q_i) = \mathbb{E}[q_i^2] - (\mathbb{E}[q_i])^2 = \mathbb{E}[q_i^2] = 1Var(qi)=E[qi2]−(E[qi])2=E[qi2]=1(同理 E[ki2]=1\mathbb{E}[k_i^2] = 1E[ki2]=1),因此:
Var(qiki)=1×1=1\text{Var}(q_i k_i) = 1 \times 1 = 1Var(qiki)=1×1=1
最终,点积 q⋅k\mathbf{q} \cdot \mathbf{k}q⋅k 是 dkd_kdk 个独立变量 qikiq_i k_iqiki 之和,其方差为:
Var(q⋅k)=∑i=1dkVar(qiki)=dk×1=dk\text{Var}(\mathbf{q} \cdot \mathbf{k}) = \sum_{i=1}^{d_k} \text{Var}(q_i k_i) = d_k \times 1 = d_kVar(q⋅k)=∑i=1dkVar(qiki)=dk×1=dk
1.2 方差爆炸的危害
当 dkd_kdk 增大时(如 Transformer 中 dk=64d_k=64dk=64 或更大),点积的方差会随 **** **** 线性增长,导致点积的数值范围急剧扩大(例如 dk=64d_k=64dk=64 时,方差 = 64,3σ 范围为 [−24,24][-24, 24][−24,24])。
此时经过 softmax 函数(softmax(x)=ex∑ex\text{softmax}(x) = \frac{e^x}{\sum e^x}softmax(x)=∑exex)会出现严重问题:
-
若点积数值极大,exe^xex 会指数级增大,导致该位置的概率接近 111,其他位置接近 000(分布 “尖锐”);
-
若点积数值极小,exe^xex 接近 000,概率也接近 000。
这种极端分布会导致 梯度消失(softmax 在概率接近 000 或 111 时梯度趋近于 000),模型无法学习到有效的注意力权重。
二、为什么除以 dk\sqrt{d_k}dk:方差归一化
我们的目标是将点积的方差归一化到常数(与 dkd_kdk 无关),确保 softmax 输出分布平缓、梯度正常。
对於点积进行缩放:q⋅kdk\frac{\mathbf{q} \cdot \mathbf{k}}{\sqrt{d_k}}dkq⋅k,计算其方差:
Var(q⋅kdk)=1dk×Var(q⋅k)\text{Var}\left( \frac{\mathbf{q} \cdot \mathbf{k}}{\sqrt{d_k}} \right) = \frac{1}{d_k} \times \text{Var}(\mathbf{q} \cdot \mathbf{k})Var(dkq⋅k)=dk1×Var(q⋅k)
代入 Var(q⋅k)=dk\text{Var}(\mathbf{q} \cdot \mathbf{k}) = d_kVar(q⋅k)=dk,可得:
Var(q⋅kdk)=1dk×dk=1\text{Var}\left( \frac{\mathbf{q} \cdot \mathbf{k}}{\sqrt{d_k}} \right) = \frac{1}{d_k} \times d_k = 1Var(dkq⋅k)=dk1×dk=1
无论 dkd_kdk 取何值,缩放后的点积方差始终为 111,数值范围稳定在 [−3,3][-3, 3][−3,3](3σ 范围)。此时 softmax 能输出合理的概率分布,既可以区分不同键的重要性,又不会导致梯度消失。
三、为什么不是 dkd_kdk 或 dk2d_k^2dk2:方差过小导致注意力失效
若选择其他缩放因子,会破坏方差平衡,导致注意力机制失去作用。我们通过方差计算和具体例子验证:
3.1 若除以 dkd_kdk:方差随 dkd_kdk 减小
缩放后的方差为:
Var(q⋅kdk)=1dk2×dk=1dk\text{Var}\left( \frac{\mathbf{q} \cdot \mathbf{k}}{d_k} \right) = \frac{1}{d_k^2} \times d_k = \frac{1}{d_k}Var(dkq⋅k)=dk21×dk=dk1
当 dkd_kdk 增大时,方差会随 **** **** 反比例减小(例如 dk=64d_k=64dk=64 时,方差 = 1/64,3σ 范围为 [−0.9375,0.9375][-0.9375, 0.9375][−0.9375,0.9375])。
此时点积数值会集中在 000 附近,softmax 后所有位置的概率接近 1n\frac{1}{n}n1(nnn 为键的数量),形成 “均匀分布”—— 注意力无法突出重要键,完全失效。
3.2 若除以 dk2d_k^2dk2:方差急剧减小
缩放后的方差为:
Var(q⋅kdk2)=1dk4×dk=1dk3\text{Var}\left( \frac{\mathbf{q} \cdot \mathbf{k}}{d_k^2} \right) = \frac{1}{d_k^4} \times d_k = \frac{1}{d_k^3}Var(dk2q⋅k)=dk41×dk=dk31
当 dk=64d_k=64dk=64 时,方差 = 1/(64³)≈3.8×10⁻⁵,3σ 范围仅为 [−0.0058,0.0058][-0.0058, 0.0058][−0.0058,0.0058]。
点积数值几乎全部集中在 000,softmax 后概率完全均匀(如 3 个键的概率接近 0.3330.3330.333),注意力机制彻底失效。
四、数学例子:直观验证不同缩放因子的影响
假设 dk=16d_k=16dk=16(常见维度),q\mathbf{q}q 和 k\mathbf{k}k 的元素均服从 N(0,1)\mathcal{N}(0,1)N(0,1)(均值 0,方差 1),且有 3 个键(即 3 个点积结果)。
步骤 1:生成未缩放的点积
假设 3 个点积(未缩放)为:[12,8,10][12, 8, 10][12,8,10](符合方差 = 16,数值范围 [−12,12][-12, 12][−12,12])。
步骤 2:不同缩放因子的结果对比
缩放因子 | 缩放后的点积 | softmax 概率分布 | 分布特点 |
---|---|---|---|
不缩放 | [12,8,10][12, 8, 10][12,8,10] | [0.867,0.016,0.117][0.867, 0.016, 0.117][0.867,0.016,0.117] | 极端尖锐 |
除以 16=4\sqrt{16}=416=4 | [3,2,2.5][3, 2, 2.5][3,2,2.5] | [0.506,0.186,0.307][0.506, 0.186, 0.307][0.506,0.186,0.307] | 平缓且有区分度 |
除以 161616 | [0.75,0.5,0.625][0.75, 0.5, 0.625][0.75,0.5,0.625] | [0.376,0.293,0.332][0.376, 0.293, 0.332][0.376,0.293,0.332] | 接近均匀 |
除以 162=25616^2=256162=256 | [0.0469,0.03125,0.0391][0.0469, 0.03125, 0.0391][0.0469,0.03125,0.0391] | [0.336,0.331,0.333][0.336, 0.331, 0.333][0.336,0.331,0.333] | 完全均匀 |
结论
-
不缩放:分布极端,梯度消失;
-
除以 dk\sqrt{d_k}dk:分布合理,注意力有效;
-
除以 dkd_kdk 或 dk2d_k^2dk2:分布均匀,注意力失效。
总结
注意力机制除以 dk\sqrt{d_k}dk 的本质是方差归一化:通过该操作将点积的方差固定为 1,避免因 dkd_kdk 增大导致的数值极端化,确保 softmax 输出平缓、梯度正常。若选择 dkd_kdk 或 dk2d_k^2dk2 作为缩放因子,会导致方差过小,注意力失去区分重要信息的能力,最终失效。