扩散模型 Diffusion Model 整体流程详解
🧠 Diffusion Model 思路、疑问和代码
文章目录
- 🧠 Diffusion Model 思路、疑问和代码
- 🔄 一、核心思想:从噪声到图像
- 📦 二、正向过程:加噪
- 🧠 三、反向过程:学习去噪
- 🔁 反向采样公式估计,引入 x 0 x_0 x0,用可解的 q ( x t − 1 ∣ x t , x 0 ) q(x_{t-1} \mid x_t, x_0) q(xt−1∣xt,x0)
- 🤔 常见问题解析
- 🏋️♀️ 模型训练
- 🎨 推理 / 采样阶段
- 🧩 Diffusion 模型模板代码
- 1. 初始化和必要的导入
- 2. 定义 Beta Schedule 和相关函数
- 3. 定义 Diffusion 模型(UNet)
- 4. 扩散过程和去噪过程
- 5. 损失函数(训练时)
- 6. 数据加载和预处理
- 7. 训练循环
- 8. 采样(生成图像)
- 9. 显示图像
🔄 一、核心思想:从噪声到图像
扩散模型是一种生成模型,目标是从纯高斯噪声一步步生成真实图像。
它包含两个阶段:
阶段 | 方向 | 名称 | 做了什么 |
---|---|---|---|
正向 | x 0 → x T x_0 \to x_T x0→xT | Forward / Diffusion | 不断加噪,让图像变成随机噪声 |
反向 | x T → x 0 x_T \to x_0 xT→x0 | Reverse / Denoising | 学习去噪,还原出原始图像 |
📦 二、正向过程:加噪
我们从原图
x
0
x_0
x0 出发,在每个时刻
t
t
t 加入一点噪声,最终得到
x
T
x_T
xT,一个近似高斯噪声的图像:
x
t
=
α
ˉ
t
⋅
x
0
+
1
−
α
ˉ
t
⋅
ϵ
,
ϵ
∼
N
(
0
,
I
)
x_t = \sqrt{\bar{\alpha}_t} \cdot x_0 + \sqrt{1 - \bar{\alpha}_t} \cdot \epsilon, \quad \epsilon \sim \mathcal{N}(0, I)
xt=αˉt⋅x0+1−αˉt⋅ϵ,ϵ∼N(0,I)
- α ˉ t = ∏ i = 1 t α i , \bar{\alpha}_{t}=\prod_{i=1}^{t} \alpha_{i}, αˉt=∏i=1tαi, 其中 α i = 1 − β i \alpha_{i}=1-\beta_{i} αi=1−βi ( β i β_i βi 是第 i i i 步加的噪声强度), 从第 1 步到第 t 步累计保留的图像信息量
- t = 0 → T t = 0 \to T t=0→T,图像越来越模糊;
- 这个过程是可闭式计算的,无需一步步执行,一次公式就能生成 x t x_t xt ✅
🧠 三、反向过程:学习去噪
🎯 目标
从 x T ∼ N ( 0 , I ) x_T \sim \mathcal{N}(0, I) xT∼N(0,I) 出发,逐步去噪还原出 x 0 x_0 x0。
🤖 学什么?
我们训练一个神经网络 ϵ θ ( x t , t ) \epsilon_\theta(x_t, t) ϵθ(xt,t) 来预测=预测加噪时用的 ϵ \epsilon ϵ。
🔁 Trick:从 x t x_t xt 推出 x 0 x_0 x0,再推出 x t − 1 x_{t-1} xt−1
从正向公式:
x
t
=
α
ˉ
t
x
0
+
1
−
α
ˉ
t
⋅
ϵ
x_t = \sqrt{\bar{\alpha}_t} x_0 + \sqrt{1 - \bar{\alpha}_t} \cdot \epsilon
xt=αˉtx0+1−αˉt⋅ϵ
可以反解出:
x
0
=
1
α
ˉ
t
(
x
t
−
1
−
α
ˉ
t
⋅
ϵ
θ
(
x
t
,
t
)
)
x_0 = \frac{1}{\sqrt{\bar{\alpha}_t}} \left( x_t - \sqrt{1 - \bar{\alpha}_t} \cdot \epsilon_\theta(x_t, t) \right)
x0=αˉt1(xt−1−αˉt⋅ϵθ(xt,t))
✅ 有了 ϵ θ \epsilon_\theta ϵθ,就能估计 x 0 x_0 x0,接着继续计算 x t − 1 x_{t-1} xt−1。
❓为什么不直接从 x t x_t xt 算 x t − 1 x_{t-1} xt−1?
虽然正向有:
x
t
=
f
(
x
t
−
1
)
+
noise
x_t = f(x_{t-1}) + \text{noise}
xt=f(xt−1)+noise
但反向是概率分布,不是函数。因为:
- 多个 x t − 1 x_{t-1} xt−1 可能加上不同噪声后变成同一个 x t x_t xt
- 所以 q ( x t − 1 ∣ x t ) q(x_{t-1} | x_t) q(xt−1∣xt) 是个复杂分布,没法显式表示
我们没有办法得到
q
(
x
t
−
1
)
q(x_{t-1})
q(xt−1)或
q
(
x
t
)
q(x_t)
q(xt)的明确表达式,因为它们涉及从
x
0
x_0
x0 积分过来的所有路径:
q
(
x
t
)
=
∫
q
(
x
t
∣
x
t
−
1
)
q
(
x
t
−
1
)
d
x
t
−
1
q(x_t) = \int q(x_t \mid x_{t-1}) q(x_{t-1}) dx_{t-1}
q(xt)=∫q(xt∣xt−1)q(xt−1)dxt−1
但
q
(
x
t
−
1
)
q(x_{t-1})
q(xt−1)并不是一个简单的分布!因为它本身是从一系列有噪声扰动的步骤中一步步卷积出来的复杂分布,然后
q
(
x
t
−
2
)
q(x_{t-2})
q(xt−2)还要再由
q
(
x
t
−
3
)
q(x_{t-3})
q(xt−3) 推来……最终都依赖于
q
(
x
0
)
q(x_0)
q(x0),也就是原始数据分布。但!💥 我们根本不知道
q
(
x
0
)
q(x_0)
q(x0) 是什么!
因此,我们只能通过 x 0 x_0 x0估计它的均值和方差。
🔁 反向采样公式估计,引入 x 0 x_0 x0,用可解的 q ( x t − 1 ∣ x t , x 0 ) q(x_{t-1} \mid x_t, x_0) q(xt−1∣xt,x0)
利用贝叶斯公式:
p
(
x
t
−
1
∣
x
t
,
x
0
)
=
p
(
x
t
∣
x
t
−
1
)
p
(
x
t
−
1
∣
x
0
)
p
(
x
t
∣
x
0
)
p\left(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_{t}, x_{0}\right)=\frac{p\left(\boldsymbol{x}_{t} \mid \boldsymbol{x}_{t-1}\right) p\left(\boldsymbol{x}_{t-1} \mid x_{0}\right)}{p\left(\boldsymbol{x}_{t} \mid \boldsymbol{x}_{0}\right)}
p(xt−1∣xt,x0)=p(xt∣x0)p(xt∣xt−1)p(xt−1∣x0)
式子中每一项都是可解的高斯分布,所以我们可以用条件高斯乘积公式,得到:
p
θ
(
x
t
−
1
∣
x
t
)
=
N
(
μ
θ
(
x
t
,
t
)
,
σ
t
2
I
)
p_\theta(x_{t-1} | x_t) = \mathcal{N}(\mu_\theta(x_t, t), \sigma_t^2 I)
pθ(xt−1∣xt)=N(μθ(xt,t),σt2I)
其中:
均值:
μ
θ
=
1
α
t
(
x
t
−
1
−
α
t
1
−
α
ˉ
t
⋅
ϵ
θ
(
x
t
,
t
)
)
\mu_\theta = \frac{1}{\sqrt{\alpha_t}} \left( x_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar{\alpha}_t}} \cdot \epsilon_\theta(x_t, t) \right)
μθ=αt1(xt−1−αˉt1−αt⋅ϵθ(xt,t))
方差:
- 选择:方差
σ
t
2
\sigma_t^2
σt2设为常数(不训练),实验发现两种选择效果相似:
- σ t 2 = β t \sigma_t^2 = \beta_t σt2=βt(对应数据初始为高斯分布)
- σ t 2 = 1 − α ˉ t − 1 1 − α ˉ t ⋅ β t \sigma_t^2 = \frac{1 - \bar{\alpha}_{t-1}}{1 - \bar{\alpha}_t} \cdot \beta_t σt2=1−αˉt1−αˉt−1⋅βt(对应数据初始为单点分布)
🤔 常见问题解析
Q1:为啥不直接训练一个学 x t − 1 x_{t-1} xt−1的模型?
- 空间太大,不容易收敛;
- 没法直接监督 x t − 1 x_{t-1} xt−1,但能监督 ϵ \epsilon ϵ
Q2:为什么不直接用预测的 x 0 x_0 x0 当作最终的生成结果?
- 预测的 x 0 x_0 x0 是近似值;
- 多个 x t x_t xt 推出的 x 0 x_0 x0 不一致;
- 扩散模型本质是一步步净化,不能一步到位。
Q3:为什么用 UNet 预测噪声 ϵ,而不是直接预测真实反向均值?
- 因为噪声 ϵ 的分布固定,预测更容易,训练更稳定;
- 通过数学推导(式10),发现可以改写为预测噪声ϵ的形式,计算更简单
🏋️♀️ 模型训练
训练时优化:
L
simple
=
E
x
0
,
t
,
ϵ
[
∥
ϵ
−
ϵ
θ
(
x
t
,
t
)
∥
2
]
\mathcal{L}_{\text{simple}} = \mathbb{E}_{x_0, t, \epsilon} \left[ \left\| \epsilon - \epsilon_\theta(x_t, t) \right\|^2 \right]
Lsimple=Ex0,t,ϵ[∥ϵ−ϵθ(xt,t)∥2]
流程如下:
- 从真实图像 x 0 x_0 x0 采样 t t t,加噪得 x t x_t xt
- 用网络预测噪声 ϵ θ ( x t , t ) \epsilon_\theta(x_t, t) ϵθ(xt,t)
- 用MSE计算损失,与真实 ϵ \epsilon ϵ 对比
🎨 推理 / 采样阶段
从高斯噪声
x
T
x_T
xT 开始,按如下公式逐步采样直到
x
0
x_0
x0:
x
t
−
1
=
μ
θ
(
x
t
,
t
)
+
σ
t
⋅
z
,
z
∼
N
(
0
,
I
)
x_{t-1} = \mu_\theta(x_t, t) + \sigma_t \cdot z, \quad z \sim \mathcal{N}(0, I)
xt−1=μθ(xt,t)+σt⋅z,z∼N(0,I)
每步逻辑:
- 先预测噪声 ϵ θ \epsilon_\theta ϵθ
- 再估计 x 0 x_0 x0,推导 μ θ \mu_\theta μθ
- 加入随机噪声 z z z 得到 x t − 1 x_{t-1} xt−1
- 不断重复,最终得到生成图像!
🧩 Diffusion 模型模板代码
参考pytorch代码:GitHub - chunyu-li/ddpm: 扩散模型的简易 PyTorch 实现
1. 初始化和必要的导入
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms
import numpy as np
import random
2. 定义 Beta Schedule 和相关函数
# 线性 beta schedule(控制每步噪声的大小)
def linear_beta_schedule(timesteps, start=0.0001, end=0.02):
return torch.linspace(start, end, timesteps)
# 获取累积的 alpha 值
def get_alphas(betas):
return 1.0 - betas
# 获取累积 alpha 的乘积
def get_alphas_cumprod(alphas):
return torch.cumprod(alphas, axis=0)
# 计算反向噪声的标准差
def get_posterior_variance(alphas_cumprod, alphas_cumprod_prev, betas):
return betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
3. 定义 Diffusion 模型(UNet)
class SimpleUnet(nn.Module):
def __init__(self):
super(SimpleUnet, self).__init__()
# 这里定义一个简单的卷积网络作为示例,可以替换成更复杂的UNet
self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
self.fc1 = nn.Linear(128 * 64 * 64, 256)
self.fc2 = nn.Linear(256, 3 * 64 * 64)
def forward(self, x, t):
x = F.relu(self.conv1(x))
x = F.relu(self.conv2(x))
x = x.view(x.size(0), -1) # Flatten for fully connected layers
x = F.relu(self.fc1(x))
x = self.fc2(x)
return x.view(x.size(0), 3, 64, 64) # Reshape back to image shape
4. 扩散过程和去噪过程
正向扩散过程:从 x 0 x_0 x0 到 x t x_t xt
def forward_diffusion_sample(x_0, t, betas, alphas_cumprod):
noise = torch.randn_like(x_0)
sqrt_alphas_cumprod_t = alphas_cumprod[t].view(-1, 1, 1, 1) # 广播至批次维度
sqrt_one_minus_alphas_cumprod_t = torch.sqrt(1.0 - alphas_cumprod[t]).view(-1, 1, 1, 1)
x_t = sqrt_alphas_cumprod_t * x_0 + sqrt_one_minus_alphas_cumprod_t * noise
return x_t, noise
反向扩散过程:从 x T x_T xT 到 x 0 x_0 x0
def sample_timestep(x_t, t, model, alphas_cumprod, betas):
# 模型预测噪声
epsilon_pred = model(x_t, t)
# 计算当前时刻的均值和标准差
sqrt_alphas_cumprod_t = alphas_cumprod[t].view(-1, 1, 1, 1)
sqrt_one_minus_alphas_cumprod_t = torch.sqrt(1.0 - alphas_cumprod[t]).view(-1, 1, 1, 1)
posterior_variance_t = betas[t].view(-1, 1, 1, 1)
# 计算预测的x_0
x_0_pred = (x_t - sqrt_one_minus_alphas_cumprod_t * epsilon_pred) / sqrt_alphas_cumprod_t
# 反向采样
noise = torch.randn_like(x_t) if t > 0 else torch.zeros_like(x_t)
x_t_minus_1 = sqrt_alphas_cumprod[t - 1] * x_0_pred + sqrt_one_minus_alphas_cumprod[t - 1] * noise
return x_t_minus_1
5. 损失函数(训练时)
def get_loss(model, x_0, t, betas, alphas_cumprod):
x_t, noise = forward_diffusion_sample(x_0, t, betas, alphas_cumprod)
noise_pred = model(x_t, t)
return F.mse_loss(noise, noise_pred) # MSE 损失
6. 数据加载和预处理
def load_transformed_dataset(img_size=64, batch_size=128):
data_transforms = [
transforms.Resize((img_size, img_size)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Lambda(lambda t: (t * 2) - 1), # [0,1] -> [-1,1]
]
data_transform = transforms.Compose(data_transforms)
train = torchvision.datasets.ImageFolder(root="./stanford_cars/cars_train", transform=data_transform)
test = torchvision.datasets.ImageFolder(root="./stanford_cars/cars_test", transform=data_transform)
dataset = torch.utils.data.ConcatDataset([train, test])
return DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)
7. 训练循环
if __name__ == "__main__":
# 初始化
model = SimpleUnet()
T = 300 # 扩散步数
betas = linear_beta_schedule(T)
alphas = get_alphas(betas)
alphas_cumprod = get_alphas_cumprod(alphas)
BATCH_SIZE = 128
epochs = 100
dataloader = load_transformed_dataset(batch_size=BATCH_SIZE)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
for epoch in range(epochs):
for batch_idx, (batch, _) in enumerate(dataloader):
optimizer.zero_grad()
batch = batch.to(device)
t = torch.randint(0, T, (BATCH_SIZE,), device=device).long()
loss = get_loss(model, batch, t, betas, alphas_cumprod)
loss.backward()
optimizer.step()
if batch_idx % 10 == 0:
print(f"Epoch [{epoch+1}/{epochs}], Step [{batch_idx+1}/{len(dataloader)}], Loss: {loss.item()}")
8. 采样(生成图像)
def generate_samples(model, T=300):
# 从噪声开始
x_t = torch.randn((BATCH_SIZE, 3, 64, 64)).to(device)
for t in reversed(range(T)):
x_t = sample_timestep(x_t, t, model, alphas_cumprod, betas)
return x_t
9. 显示图像
def show_tensor_image(image):
image = image.squeeze().cpu().numpy().transpose(1, 2, 0)
image = (image + 1.0) / 2.0 # [-1, 1] -> [0, 1]
plt.imshow(image)
plt.axis('off')
plt.show()
x_t = sample_timestep(x_t, t, model, alphas_cumprod, betas)
return x_t
### 9. 显示图像
```python
def show_tensor_image(image):
image = image.squeeze().cpu().numpy().transpose(1, 2, 0)
image = (image + 1.0) / 2.0 # [-1, 1] -> [0, 1]
plt.imshow(image)
plt.axis('off')
plt.show()