当前位置: 首页 > news >正文

【AIGC】DDPM scheduler解析:扩散模型里的“调度器”到底在调什么?

扩散模型里的“调度器”到底在调什么?

—— 以 DDPM 仓库为例,一行行拆解 scheduler 预计算代码

如果你第一次接触扩散模型(Diffusion Model),最绕不开的一个词就是 scheduler。它到底在“调度”什么?为什么论文里一大堆 α、β、ᾱ,代码里又跑出来一堆 alphas_cumprodsqrt_recip_alphas

本文就带你把DDPM里的 scheduler 预计算代码彻底拆开,告诉你每一个张量背后对应的公式与直觉。读完你不仅能秒懂 forward_noising.py 在干什么,还能自己手写一个 scheduler。


1. 先给结论:scheduler 在“提前算好每一步的权重”

扩散模型的前向过程(加噪)和反向过程(去噪)都依赖大量随时间步 t 变化的系数
如果每一步都在现场算,训练/采样就会被拖垮。于是 DDPM 的做法是:

一次性把 0~T 步要用到的所有系数都算出来,放进张量。后面直接按 t 索引即可。

这些系数就是 scheduler 的全部工作。


2. 代码全景

以下代码位于 forward_noising.py,删掉了注释后不到 20 行,却把整个前向过程需要的张量全部算完:

import torch
import torch.nn.functional as F# 1. 线性 β 调度器
def linear_beta_schedule(timesteps, start=0.0001, end=0.02):return torch.linspace(start, end, timesteps)T = 300
betas = linear_beta_schedule(T)# 2. 预计算所有中间量
alphas             = 1.0 - betas
alphas_cumprod     = torch.cumprod(alphas, dim=0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)sqrt_recip_alphas  = torch.sqrt(1.0 / alphas)
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)

3. 逐行拆解:从符号到直觉

3.1 T = 300 —— 扩散“总步数”

把一张清晰图片变成纯高斯噪声,需要 300 步小步快跑。步数越多,单步扰动越小,数值稳定性越好。


3.2 betas —— 每一步加多少噪声

代码公式直觉
betas = linear_beta_schedule(T)βt\beta_tβt 线性从 0.0001 → 0.02控制第 t 步注入噪声的方差。β 越大,破坏越狠。

3.3 alphas ——“我还剩多少原图”

代码公式直觉
alphas = 1.0 - betasαt=1−βt\alpha_t = 1 - \beta_tαt=1βt原图保留比例。因为 β 很小,α 非常接近 1。

3.4 alphas_cumprod —— 一口气算到第 t 步的“累积保留率”

代码公式直觉
alphas_cumprod = torch.cumprod(alphas, 0)αˉt=∏i=1tαi\bar\alpha_t = \prod_{i=1}^{t}\alpha_iαˉt=i=1tαi如果你想直接x0x_0x0 跳到 xtx_txt,就靠它。DDPM 的“闭式采样”核心。

3.5 sqrt_alphas_cumprod & sqrt_one_minus_alphas_cumprod —— 前向公式里的“两根魔法棒”

代码公式直觉
sqrt_alphas_cumprodαˉt\sqrt{\bar\alpha_t}αˉt原图 x0x_0x0 的缩放系数
sqrt_one_minus_alphas_cumprod1−αˉt\sqrt{1 - \bar\alpha_t}1αˉt噪声 ε\varepsilonε 的缩放系数

把两者组合起来就是 DDPM 论文最经典的前向公式:

xt=αˉt x0+1−αˉt ε,ε∼N(0,I) x_t = \sqrt{\bar\alpha_t}\, x_0 + \sqrt{1-\bar\alpha_t}\, \varepsilon,\quad \varepsilon\sim\mathcal N(0,I) xt=αˉtx0+1αˉtε,εN(0,I)


3.6 alphas_cumprod_prev —— 上一步的 ᾱ

代码直觉
F.pad(..., value=1.0)αˉt−1\bar\alpha_{t-1}αˉt1 对齐到 t 的索引,方便后续向量运算。第 0 步之前补 1.0(αˉ0=1\bar\alpha_0=1αˉ0=1)。

3.7 sqrt_recip_alphas —— 反向去噪“放大器”

代码直觉
torch.sqrt(1.0 / alphas)在反向公式里把模型预测的 εθ(xt,t)\varepsilon_\theta(x_t,t)εθ(xt,t) 再乘回去,恢复信号。

3.8 posterior_variance —— 反向“再抖一点点”的方差

代码公式直觉
betas * (1 - ᾱ_prev) / (1 - ᾱ)β~t=1−αˉt−11−αˉtβt\tilde\beta_t = \frac{1-\bar\alpha_{t-1}}{1-\bar\alpha_t}\beta_tβ~t=1αˉt1αˉt1βtxtx_txt 预测 xt−1x_{t-1}xt1 时,需要再采样一点点噪声,方差就是 β~t\tilde\beta_tβ~t。DDPM 论文里把它固定住,不交给网络学,简化训练。

4. 一张图总结

把上面所有张量按时间轴画出来(T=300):

时间步 tβ_tα_tᾱ_t√ᾱ_t√(1-ᾱ_t)β̃_t
01.0001.0000.000
10.00010.99990.99990.999950.01000.0001
3000.02000.9800~0.0020.0450.9990.0199

可以看到:

  • ᾱ_t 从 1 一路降到接近 0,解释“原图逐渐消失”。
  • √(1-ᾱ_t) 从 0 升到接近 1,解释“噪声逐渐占满”。
  • β̃_t 始终跟 β_t 差不多,但略小,保证反向过程的方差不会爆炸。

5. 小结 & 下一步

  1. scheduler 不是玄学,只是一次性把“每一步要乘的数”算好。
  2. 核心就 8 个张量,对应论文里 4 个公式,背下来就能手写扩散。
  3. 后面不管是 DDIM、PLMS 还是 DPM-Solver,都只是在换 β 的调度策略采样公式;预计算的思路一模一样。
http://www.dtcms.com/a/336470.html

相关文章:

  • 线程的同步
  • 魔改chromium源码——解除 iframe 的同源策略
  • Go语言实战案例-使用ORM框架 GORM 入门
  • 0️⃣基础 认识Python操作文件夹(初学者)
  • E2B是一个开源基础设施,允许您在云中安全隔离的沙盒中运行AI生成的代码和e2b.dev网站
  • 基因编辑预测工具:inDelphi与Pythia
  • Linux学习记录
  • 图解简单选择排序C语言实现
  • 01数据结构-插入排序
  • 一文读懂[特殊字符] LlamaFactory 中 Loss 曲线图
  • 防火墙带宽管理
  • 使用 Python 的 `cProfile` 分析函数执行时间
  • AUTOSAR进阶图解==>AUTOSAR_SWS_EthernetStateManager
  • 【PHP】Hyperf:接入 Nacos
  • 今日Java高频难点面试题推荐(2025年8月17日)
  • Python数据类型转换详解:从基础到实践
  • 【Kubernetes系列】Kubernetes中的resources
  • Matlab数字信号处理——ECG心电信号处理心率计算
  • FreeRTOS 中的守护任务(Daemon Task)
  • 第七十七章:多模态推理与生成——开启AI“从无到有”的时代!
  • 【C++知识杂记2】free和delete区别
  • c++--文件头注释/doxygen
  • Linux应用软件编程---多任务(线程)(线程创建、消亡、回收、属性、与进程的区别、线程间通信、函数指针)
  • 工作八年记
  • 官方正版在线安装office 365安装工具
  • 数组的三种主要声明方式
  • 大模型对齐算法(二): TDPO(Token-level Direct Preference Optimization)
  • Android中使用Compose实现各种样式Dialog
  • tcp会无限次重传吗
  • Eclipse Tomcat Configuration