【AIGC】DDPM scheduler解析:扩散模型里的“调度器”到底在调什么?
扩散模型里的“调度器”到底在调什么?
—— 以 DDPM 仓库为例,一行行拆解 scheduler 预计算代码
如果你第一次接触扩散模型(Diffusion Model),最绕不开的一个词就是 scheduler。它到底在“调度”什么?为什么论文里一大堆 α、β、ᾱ,代码里又跑出来一堆
alphas_cumprod
、sqrt_recip_alphas
?本文就带你把DDPM里的 scheduler 预计算代码彻底拆开,告诉你每一个张量背后对应的公式与直觉。读完你不仅能秒懂
forward_noising.py
在干什么,还能自己手写一个 scheduler。
1. 先给结论:scheduler 在“提前算好每一步的权重”
扩散模型的前向过程(加噪)和反向过程(去噪)都依赖大量随时间步 t 变化的系数。
如果每一步都在现场算,训练/采样就会被拖垮。于是 DDPM 的做法是:
一次性把 0~T 步要用到的所有系数都算出来,放进张量。后面直接按 t 索引即可。
这些系数就是 scheduler 的全部工作。
2. 代码全景
以下代码位于 forward_noising.py
,删掉了注释后不到 20 行,却把整个前向过程需要的张量全部算完:
import torch
import torch.nn.functional as F# 1. 线性 β 调度器
def linear_beta_schedule(timesteps, start=0.0001, end=0.02):return torch.linspace(start, end, timesteps)T = 300
betas = linear_beta_schedule(T)# 2. 预计算所有中间量
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)sqrt_recip_alphas = torch.sqrt(1.0 / alphas)
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
3. 逐行拆解:从符号到直觉
3.1 T = 300
—— 扩散“总步数”
把一张清晰图片变成纯高斯噪声,需要 300 步小步快跑。步数越多,单步扰动越小,数值稳定性越好。
3.2 betas
—— 每一步加多少噪声
代码 | 公式 | 直觉 |
---|---|---|
betas = linear_beta_schedule(T) | βt\beta_tβt 线性从 0.0001 → 0.02 | 控制第 t 步注入噪声的方差。β 越大,破坏越狠。 |
3.3 alphas
——“我还剩多少原图”
代码 | 公式 | 直觉 |
---|---|---|
alphas = 1.0 - betas | αt=1−βt\alpha_t = 1 - \beta_tαt=1−βt | 原图保留比例。因为 β 很小,α 非常接近 1。 |
3.4 alphas_cumprod
—— 一口气算到第 t 步的“累积保留率”
代码 | 公式 | 直觉 |
---|---|---|
alphas_cumprod = torch.cumprod(alphas, 0) | αˉt=∏i=1tαi\bar\alpha_t = \prod_{i=1}^{t}\alpha_iαˉt=∏i=1tαi | 如果你想直接从 x0x_0x0 跳到 xtx_txt,就靠它。DDPM 的“闭式采样”核心。 |
3.5 sqrt_alphas_cumprod
& sqrt_one_minus_alphas_cumprod
—— 前向公式里的“两根魔法棒”
代码 | 公式 | 直觉 |
---|---|---|
sqrt_alphas_cumprod | αˉt\sqrt{\bar\alpha_t}αˉt | 原图 x0x_0x0 的缩放系数 |
sqrt_one_minus_alphas_cumprod | 1−αˉt\sqrt{1 - \bar\alpha_t}1−αˉt | 噪声 ε\varepsilonε 的缩放系数 |
把两者组合起来就是 DDPM 论文最经典的前向公式:
xt=αˉt x0+1−αˉt ε,ε∼N(0,I) x_t = \sqrt{\bar\alpha_t}\, x_0 + \sqrt{1-\bar\alpha_t}\, \varepsilon,\quad \varepsilon\sim\mathcal N(0,I) xt=αˉtx0+1−αˉtε,ε∼N(0,I)
3.6 alphas_cumprod_prev
—— 上一步的 ᾱ
代码 | 直觉 |
---|---|
F.pad(..., value=1.0) | 把 αˉt−1\bar\alpha_{t-1}αˉt−1 对齐到 t 的索引,方便后续向量运算。第 0 步之前补 1.0(αˉ0=1\bar\alpha_0=1αˉ0=1)。 |
3.7 sqrt_recip_alphas
—— 反向去噪“放大器”
代码 | 直觉 |
---|---|
torch.sqrt(1.0 / alphas) | 在反向公式里把模型预测的 εθ(xt,t)\varepsilon_\theta(x_t,t)εθ(xt,t) 再乘回去,恢复信号。 |
3.8 posterior_variance
—— 反向“再抖一点点”的方差
代码 | 公式 | 直觉 |
---|---|---|
betas * (1 - ᾱ_prev) / (1 - ᾱ) | β~t=1−αˉt−11−αˉtβt\tilde\beta_t = \frac{1-\bar\alpha_{t-1}}{1-\bar\alpha_t}\beta_tβ~t=1−αˉt1−αˉt−1βt | 从 xtx_txt 预测 xt−1x_{t-1}xt−1 时,需要再采样一点点噪声,方差就是 β~t\tilde\beta_tβ~t。DDPM 论文里把它固定住,不交给网络学,简化训练。 |
4. 一张图总结
把上面所有张量按时间轴画出来(T=300):
时间步 t | β_t | α_t | ᾱ_t | √ᾱ_t | √(1-ᾱ_t) | β̃_t |
---|---|---|---|---|---|---|
0 | — | — | 1.000 | 1.000 | 0.000 | — |
1 | 0.0001 | 0.9999 | 0.9999 | 0.99995 | 0.0100 | 0.0001 |
… | … | … | … | … | … | … |
300 | 0.0200 | 0.9800 | ~0.002 | 0.045 | 0.999 | 0.0199 |
可以看到:
- ᾱ_t 从 1 一路降到接近 0,解释“原图逐渐消失”。
- √(1-ᾱ_t) 从 0 升到接近 1,解释“噪声逐渐占满”。
- β̃_t 始终跟 β_t 差不多,但略小,保证反向过程的方差不会爆炸。
5. 小结 & 下一步
- scheduler 不是玄学,只是一次性把“每一步要乘的数”算好。
- 核心就 8 个张量,对应论文里 4 个公式,背下来就能手写扩散。
- 后面不管是 DDIM、PLMS 还是 DPM-Solver,都只是在换 β 的调度策略 和 采样公式;预计算的思路一模一样。