网站模板框架公司网站建设平台
掩码图像建模 (MIM) 中的对数似然与交叉熵
1. 问题背景
在掩码图像建模(MIM)任务中,模型需要预测被遮蔽的图像块对应的视觉词元(可以理解为图像块的离散类别标签)。
具体来说:
- 每个被遮蔽的图像块 i ∈ M i \in M i∈M 的真实标签是 z i z_i zi(即它原本的视觉词元类别)。
- 模型通过 Transformer 编码器生成隐藏向量 h L i h_L^i hLi,然后通过一个分类器(参数为 W c , b c W_c, b_c Wc,bc)预测该位置的概率分布 p MIM ( z ′ ∣ x M ) p_{\text{MIM}}(z' | x^M) pMIM(z′∣xM)。
2. Softmax 分类器的作用
分类器的公式是:
p MIM ( z ′ ∣ x M ) = softmax z ( W c h L i + b c ) p_{\text{MIM}}(z' | x^M) = \text{softmax}_z(W_c h_L^i + b_c) pMIM(z′∣xM)=softmaxz(WchLi+bc)
- 输入:隐藏向量 h L i ∈ R D h_L^i \in \mathbb{R}^D hLi∈RD(来自 Transformer 的输出)。
- 参数:权重矩阵 W c ∈ R ∣ V ∣ × D W_c \in \mathbb{R}^{|\mathcal{V}| \times D} Wc∈R∣V∣×D 和偏置 b c ∈ R ∣ V ∣ b_c \in \mathbb{R}^{|\mathcal{V}|} bc∈R∣V∣,其中 ∣ V ∣ |\mathcal{V}| ∣V∣ 是视觉词元的总类别数。
- 输出:一个概率分布,表示模型认为被遮蔽块 i i i 属于每个视觉词元类别的概率。
具体计算步骤:
- 对每个被遮蔽位置 i i i,计算线性变换: W c h L i + b c W_c h_L^i + b_c WchLi+bc,得到一个长度为 ∣ V ∣ |\mathcal{V}| ∣V∣ 的向量(称为logits)。
- 对 logits 应用 softmax 函数,将其转换为概率分布:
p ( z ′ ) = exp ( logits [ z ′ ] ) ∑ k = 1 ∣ V ∣ exp ( logits [ k ] ) p(z') = \frac{\exp(\text{logits}[z'])}{\sum_{k=1}^{|\mathcal{V}|} \exp(\text{logits}[k])} p(z′)=∑k=1∣V∣exp(logits[k])exp(logits[z′])
其中 z ′ z' z′ 是某个可能的视觉词元类别。
3. 最大化对数似然(Maximize Log-Likelihood)
目标:让模型对真实标签 z i z_i zi 的预测概率尽可能高。
数学表达:
max θ E x ∼ D [ ∑ i ∈ M log p MIM ( z i ∣ x M ) ] \max_{\theta} \mathbb{E}_{x \sim \mathcal{D}} \left[ \sum_{i \in M} \log p_{\text{MIM}}(z_i | x^M) \right] θmaxEx∼D[i∈M∑logpMIM(zi∣xM)]
- 解释:
- 对每个被遮蔽位置 i i i,计算真实标签 z i z_i zi 的对数概率 log p MIM ( z i ∣ x M ) \log p_{\text{MIM}}(z_i | x^M) logpMIM(zi∣xM)。
- 对所有被遮蔽位置求和,再对所有训练样本 x x x 求期望。
- 目标是最大化这个总和,即让模型对真实标签的预测概率尽可能大。
4. 交叉熵损失(Cross-Entropy Loss)
交叉熵损失是分类任务中常用的损失函数,定义为:
L CE = − ∑ i ∈ M log p MIM ( z i ∣ x M ) \mathcal{L}_{\text{CE}} = - \sum_{i \in M} \log p_{\text{MIM}}(z_i | x^M) LCE=−i∈M∑logpMIM(zi∣xM)
- 解释:
- 对每个被遮蔽位置 i i i,计算真实标签 z i z_i zi 的负对数概率。
- 对所有被遮蔽位置求和,得到总损失。
- 目标是最小化这个损失,即让真实标签的预测概率尽可能高。
5. 最大化对数似然 vs. 最小化交叉熵
关键结论:
最大化对数似然和最小化交叉熵损失是完全等价的!
具体来说:
max θ ∑ i ∈ M log p MIM ( z i ∣ x M ) ⟺ min θ ( − ∑ i ∈ M log p MIM ( z i ∣ x M ) ) \max_{\theta} \sum_{i \in M} \log p_{\text{MIM}}(z_i | x^M) \quad \iff \quad \min_{\theta} \left( - \sum_{i \in M} \log p_{\text{MIM}}(z_i | x^M) \right) θmaxi∈M∑logpMIM(zi∣xM)⟺θmin(−i∈M∑logpMIM(zi∣xM))
- 左边是最大化对数似然(使正确标签的概率最大化)。
- 右边是最小化交叉熵损失(使正确标签的负对数概率最小化)。
6. 为什么等价?
- 数学本质:交叉熵损失是负的对数似然。
- 对数似然是 ∑ log p \sum \log p ∑logp,交叉熵是 − ∑ log p -\sum \log p −∑logp。
- 最大化 A A A 等价于最小化 − A -A −A。
- 直观理解:
- 如果模型对真实标签的预测概率 p ( z i ) p(z_i) p(zi) 越大,对数似然 log p ( z i ) \log p(z_i) logp(zi) 越大,交叉熵损失 − log p ( z i ) -\log p(z_i) −logp(zi) 越小。
- 例如,若真实标签的概率 p ( z i ) = 0.9 p(z_i) = 0.9 p(zi)=0.9,则交叉熵损失为 − log ( 0.9 ) ≈ 0.11 -\log(0.9) \approx 0.11 −log(0.9)≈0.11;
若概率 p ( z i ) = 0.1 p(z_i) = 0.1 p(zi)=0.1,则损失为 − log ( 0.1 ) ≈ 2.30 -\log(0.1) \approx 2.30 −log(0.1)≈2.30。
显然,概率越大,损失越小。
7. 实际训练中的计算
在代码中,通常直接使用交叉熵损失函数(如 PyTorch 的 CrossEntropyLoss
):
# 假设 logits 是模型的输出(未经过 softmax)
# targets 是被遮蔽位置的真实视觉词元标签
loss = F.cross_entropy(logits, targets)
- 内部过程:
- 对 logits 应用 softmax,得到概率分布。
- 计算真实标签的负对数概率。
- 对所有样本和位置求平均,得到最终损失。
总结
- 目标:让模型对真实标签的预测概率尽可能高。
- 数学实现:通过最大化对数似然(等价于最小化交叉熵损失)。
- 代码实现:直接使用交叉熵损失函数,无需手动计算对数似然。