扩散模型的数学基础 —— 贝叶斯
- 贝叶斯定理基础
- 贝叶斯定理在条件扩散模型中的应用
- 利用贝叶斯定理分解
- 第一项:无条件得分项(prior score)
- 第二项:条件得分项(guidance term)
- 分类器引导(Classifier Guidance)
- DPS(Diffusion Posterior Sampling)
- 重建引导(Reconstruction Guidance)
- 总结:贝叶斯得分引导公式的作用
贝叶斯定理基础
给定两个事件 AAA 和 BBB,条件概率定义为:
P(A∣B)=P(A∩B)P(B),P(B∣A)=P(A∩B)P(A)P(A \mid B) = \frac{P(A \cap B)}{P(B)}, \quad P(B \mid A) = \frac{P(A \cap B)}{P(A)} P(A∣B)=P(B)P(A∩B),P(B∣A)=P(A)P(A∩B)
由此得:
P(A∩B)=P(A∣B)P(B)=P(B∣A)P(A)P(A \cap B) = P(A \mid B) P(B) = P(B \mid A) P(A) P(A∩B)=P(A∣B)P(B)=P(B∣A)P(A)
移项得贝叶斯定理:
P(A∣B)=P(B∣A)P(A)P(B)\boxed{P(A \mid B) = \frac{P(B \mid A) P(A)}{P(B)}} P(A∣B)=P(B)P(B∣A)P(A)
贝叶斯定理在条件扩散模型中的应用
在扩散模型中,设想我们希望生成满足某个条件 yyy 的样本 x0x_0x0。我们研究的是在某一时刻 ttt,扩散模型下的 条件分布的梯度(得分):
∇xtlogpt(xt∣y)\nabla_{x_t} \log p_t(x_t \mid y) ∇xtlogpt(xt∣y)
这是生成满足条件 yyy 的样本所需的梯度方向。
利用贝叶斯定理分解
对数形式的贝叶斯定理:
logpt(xt∣y)=logpt(xt)+logpt(y∣xt)−logpt(y)\log p_t(x_t \mid y) = \log p_t(x_t) + \log p_t(y \mid x_t) - \log p_t(y) logpt(xt∣y)=logpt(xt)+logpt(y∣xt)−logpt(y)
对 xtx_txt 求梯度时,注意 logpt(y)\log p_t(y)logpt(y) 与 xtx_txt 无关,因此它的梯度为 0:
∇xtlogpt(xt∣y)=∇xtlogpt(xt)+∇xtlogpt(y∣xt)\boxed{ \nabla_{x_t} \log p_t(x_t \mid y) = \nabla_{x_t} \log p_t(x_t) + \nabla_{x_t} \log p_t(y \mid x_t) } ∇xtlogpt(xt∣y)=∇xtlogpt(xt)+∇xtlogpt(y∣xt)
第一项:无条件得分项(prior score)
∇xtlogpt(xt)\nabla_{x_t} \log p_t(x_t) ∇xtlogpt(xt)
- 表示当前时刻 xtx_txt 下的 无条件得分函数;
- 通常由训练好的扩散模型(如噪声预测网络)直接提供;
- 是去噪方向的重要组成部分。
第二项:条件得分项(guidance term)
∇xtlogpt(y∣xt)\nabla_{x_t} \log p_t(y \mid x_t) ∇xtlogpt(y∣xt)
- 是使生成样本满足条件 yyy 的“引导项”;
- 可以使用不同策略近似或显式计算:
分类器引导(Classifier Guidance)
- 训练一个分类器 Cϕ(xt)≈pt(y∣xt)C_\phi(x_t) \approx p_t(y \mid x_t)Cϕ(xt)≈pt(y∣xt)
- 然后使用分类器对数输出的梯度作为引导:
∇xtlogpt(y∣xt)≈∇xtlogCϕ(xt)\nabla_{x_t} \log p_t(y \mid x_t) \approx \nabla_{x_t} \log C_\phi(x_t) ∇xtlogpt(y∣xt)≈∇xtlogCϕ(xt)
DPS(Diffusion Posterior Sampling)
- 不需要配对数据训练分类器;
- 假设已知无噪声数据下的条件概率 p(y∣x0)p(y \mid x_0)p(y∣x0)
- 使用扩散模型的 MMSE 估计 x^t≈E[x0∣xt]\hat{x}_t \approx \mathbb{E}[x_0 \mid x_t]x^t≈E[x0∣xt]
- 将梯度近似为:
∇xtlogpt(y∣xt)≈∇xtlogp(y∣x^t)\boxed{ \nabla_{x_t} \log p_t(y \mid x_t) \approx \nabla_{x_t} \log p(y \mid \hat{x}_t) } ∇xtlogpt(y∣xt)≈∇xtlogp(y∣x^t)
- 只要 p(y∣x0)p(y \mid x_0)p(y∣x0) 对 x0x_0x0 可微,这一项就对 xtx_txt 可导。
重建引导(Reconstruction Guidance)
DPS 的特例:假设 p(y∣x0)p(y \mid x_0)p(y∣x0) 为高斯分布,如:
p(y∣x0)=N(y;x0,σ2I)p(y \mid x_0) = \mathcal{N}(y; x_0, \sigma^2 I) p(y∣x0)=N(y;x0,σ2I)
则:
∇xtlogpt(y∣xt)≈x^t−yσ2⋅∂x^t∂xt\nabla_{x_t} \log p_t(y \mid x_t) \approx \frac{\hat{x}_t - y}{\sigma^2} \cdot \frac{\partial \hat{x}_t}{\partial x_t} ∇xtlogpt(y∣xt)≈σ2x^t−y⋅∂xt∂x^t
总结:贝叶斯得分引导公式的作用
最终,通过贝叶斯定理将 条件得分 分解为:
∇xtlogpt(xt∣y)=∇xtlogpt(xt)⏟模型本身+∇xtlogpt(y∣xt)⏟外部引导\nabla_{x_t} \log p_t(x_t \mid y) = \underbrace{\nabla_{x_t} \log p_t(x_t)}_{\text{模型本身}} + \underbrace{\nabla_{x_t} \log p_t(y \mid x_t)}_{\text{外部引导}} ∇xtlogpt(xt∣y)=模型本身∇xtlogpt(xt)+外部引导∇xtlogpt(y∣xt)
- 第一项由扩散模型训练得到;
- 第二项可以 用分类器、DPS 或其他方法近似;
- 总体目标是指导模型采样路径向着满足 yyy 的方向前进。