当前位置: 首页 > news >正文

扩散模型 Diffusion Model 整体流程详解

🧠 Diffusion Model 思路、疑问和代码

文章目录

🔄 一、核心思想:从噪声到图像

扩散模型是一种生成模型,目标是从纯高斯噪声一步步生成真实图像

它包含两个阶段:

阶段方向名称做了什么
正向 x 0 → x T x_0 \to x_T x0xTForward / Diffusion不断加噪,让图像变成随机噪声
反向 x T → x 0 x_T \to x_0 xTx0Reverse / Denoising学习去噪,还原出原始图像

📦 二、正向过程:加噪

我们从原图 x 0 x_0 x0 出发,在每个时刻 t t t 加入一点噪声,最终得到 x T x_T xT,一个近似高斯噪声的图像:
x t = α ˉ t ⋅ x 0 + 1 − α ˉ t ⋅ ϵ , ϵ ∼ N ( 0 , I ) x_t = \sqrt{\bar{\alpha}_t} \cdot x_0 + \sqrt{1 - \bar{\alpha}_t} \cdot \epsilon, \quad \epsilon \sim \mathcal{N}(0, I) xt=αˉt x0+1αˉt ϵ,ϵN(0,I)

  • α ˉ t = ∏ i = 1 t α i , \bar{\alpha}_{t}=\prod_{i=1}^{t} \alpha_{i}, αˉt=i=1tαi, 其中 α i = 1 − β i \alpha_{i}=1-\beta_{i} αi=1βi ( β i β_i βi 是第 i i i 步加的噪声强度), 从第 1 步到第 t 步累计保留的图像信息量
  • t = 0 → T t = 0 \to T t=0T,图像越来越模糊;
  • 这个过程是可闭式计算的,无需一步步执行,一次公式就能生成 x t x_t xt​ ✅

🧠 三、反向过程:学习去噪

🎯 目标

x T ∼ N ( 0 , I ) x_T \sim \mathcal{N}(0, I) xTN(0,I) 出发,逐步去噪还原出 x 0 x_0 x0

🤖 学什么?

我们训练一个神经网络 ϵ θ ( x t , t ) \epsilon_\theta(x_t, t) ϵθ(xt,t) 来预测=预测加噪时用的 ϵ \epsilon ϵ


🔁 Trick:从 x t x_t xt 推出 x 0 x_0 x0,再推出 x t − 1 x_{t-1} xt1

从正向公式:
x t = α ˉ t x 0 + 1 − α ˉ t ⋅ ϵ x_t = \sqrt{\bar{\alpha}_t} x_0 + \sqrt{1 - \bar{\alpha}_t} \cdot \epsilon xt=αˉt x0+1αˉt ϵ

可以反解出:
x 0 = 1 α ˉ t ( x t − 1 − α ˉ t ⋅ ϵ θ ( x t , t ) ) x_0 = \frac{1}{\sqrt{\bar{\alpha}_t}} \left( x_t - \sqrt{1 - \bar{\alpha}_t} \cdot \epsilon_\theta(x_t, t) \right) x0=αˉt 1(xt1αˉt ϵθ(xt,t))

✅ 有了 ϵ θ \epsilon_\theta ϵθ,就能估计 x 0 x_0 x0,接着继续计算 x t − 1 x_{t-1} xt1


❓为什么不直接从 x t x_t xt x t − 1 x_{t-1} xt1

虽然正向有:
x t = f ( x t − 1 ) + noise x_t = f(x_{t-1}) + \text{noise} xt=f(xt1)+noise
但反向是概率分布,不是函数。因为:

  • 多个 x t − 1 x_{t-1} xt1 可能加上不同噪声后变成同一个 x t x_t xt
  • 所以 q ( x t − 1 ∣ x t ) q(x_{t-1} | x_t) q(xt1xt)​ 是个复杂分布,没法显式表示

我们没有办法得到 q ( x t − 1 ) q(x_{t-1}) q(xt1) q ( x t ) q(x_t) q(xt)的明确表达式,因为它们涉及从 x 0 x_0 x0 积分过来的所有路径:
q ( x t ) = ∫ q ( x t ∣ x t − 1 ) q ( x t − 1 ) d x t − 1 q(x_t) = \int q(x_t \mid x_{t-1}) q(x_{t-1}) dx_{t-1} q(xt)=q(xtxt1)q(xt1)dxt1
q ( x t − 1 ) q(x_{t-1}) q(xt1)并不是一个简单的分布!因为它本身是从一系列有噪声扰动的步骤中一步步卷积出来的复杂分布,然后 q ( x t − 2 ) q(x_{t-2}) q(xt2)还要再由 q ( x t − 3 ) q(x_{t-3}) q(xt3) 推来……最终都依赖于 q ( x 0 ) q(x_0) q(x0),也就是原始数据分布。但!💥 我们根本不知道 q ( x 0 ) q(x_0) q(x0) 是什么!

因此,我们只能通过 x 0 x_0 x0估计它的均值和方差。


🔁 反向采样公式估计,引入 x 0 x_0 x0,用可解的 q ( x t − 1 ∣ x t , x 0 ) q(x_{t-1} \mid x_t, x_0) q(xt1xt,x0)

利用贝叶斯公式:
p ( x t − 1 ∣ x t , x 0 ) = p ( x t ∣ x t − 1 ) p ( x t − 1 ∣ x 0 ) p ( x t ∣ x 0 ) p\left(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_{t}, x_{0}\right)=\frac{p\left(\boldsymbol{x}_{t} \mid \boldsymbol{x}_{t-1}\right) p\left(\boldsymbol{x}_{t-1} \mid x_{0}\right)}{p\left(\boldsymbol{x}_{t} \mid \boldsymbol{x}_{0}\right)} p(xt1xt,x0)=p(xtx0)p(xtxt1)p(xt1x0)
式子中每一项都是可解的高斯分布,所以我们可以用条件高斯乘积公式,得到:
p θ ( x t − 1 ∣ x t ) = N ( μ θ ( x t , t ) , σ t 2 I ) p_\theta(x_{t-1} | x_t) = \mathcal{N}(\mu_\theta(x_t, t), \sigma_t^2 I) pθ(xt1xt)=N(μθ(xt,t),σt2I)
其中:

均值
μ θ = 1 α t ( x t − 1 − α t 1 − α ˉ t ⋅ ϵ θ ( x t , t ) ) \mu_\theta = \frac{1}{\sqrt{\alpha_t}} \left( x_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar{\alpha}_t}} \cdot \epsilon_\theta(x_t, t) \right) μθ=αt 1(xt1αˉt 1αtϵθ(xt,t))
方差

  • 选择:方差 σ t 2 \sigma_t^2 σt2设为常数(不训练),实验发现两种选择效果相似:
    • σ t 2 = β t \sigma_t^2 = \beta_t σt2=βt(对应数据初始为高斯分布)
    • σ t 2 = 1 − α ˉ t − 1 1 − α ˉ t ⋅ β t \sigma_t^2 = \frac{1 - \bar{\alpha}_{t-1}}{1 - \bar{\alpha}_t} \cdot \beta_t σt2=1αˉt1αˉt1βt(对应数据初始为单点分布)

🤔 常见问题解析

Q1:为啥不直接训练一个学 x t − 1 x_{t-1} xt1的模型?

  • 空间太大,不容易收敛;
  • 没法直接监督 x t − 1 x_{t-1} xt1,但能监督 ϵ \epsilon ϵ

Q2:为什么不直接用预测的 x 0 x_0 x0 当作最终的生成结果?

  • 预测的 x 0 x_0 x0 是近似值;
  • 多个 x t x_t xt 推出的 x 0 x_0 x0 不一致;
  • 扩散模型本质是一步步净化,不能一步到位。

Q3:为什么用 UNet 预测噪声 ϵ,而不是直接预测真实反向均值?

  • 因为噪声 ϵ 的分布固定,预测更容易,训练更稳定;
  • 通过数学推导(式10),发现可以改写为预测噪声ϵ的形式,计算更简单

🏋️‍♀️ 模型训练

训练时优化:
L simple = E x 0 , t , ϵ [ ∥ ϵ − ϵ θ ( x t , t ) ∥ 2 ] \mathcal{L}_{\text{simple}} = \mathbb{E}_{x_0, t, \epsilon} \left[ \left\| \epsilon - \epsilon_\theta(x_t, t) \right\|^2 \right] Lsimple=Ex0,t,ϵ[ϵϵθ(xt,t)2]

流程如下:

  1. 从真实图像 x 0 x_0 x0 采样 t t t,加噪得 x t x_t xt
  2. 用网络预测噪声 ϵ θ ( x t , t ) \epsilon_\theta(x_t, t) ϵθ(xt,t)
  3. 用MSE计算损失,与真实 ϵ \epsilon ϵ 对比

🎨 推理 / 采样阶段

从高斯噪声 x T x_T xT 开始,按如下公式逐步采样直到 x 0 x_0 x0
x t − 1 = μ θ ( x t , t ) + σ t ⋅ z , z ∼ N ( 0 , I ) x_{t-1} = \mu_\theta(x_t, t) + \sigma_t \cdot z, \quad z \sim \mathcal{N}(0, I) xt1=μθ(xt,t)+σtz,zN(0,I)
每步逻辑:

  • 先预测噪声 ϵ θ \epsilon_\theta ϵθ
  • 再估计 x 0 x_0 x0,推导 μ θ \mu_\theta μθ
  • 加入随机噪声 z z z 得到 x t − 1 x_{t-1} xt1
  • 不断重复,最终得到生成图像!

🧩 Diffusion 模型模板代码

参考pytorch代码:GitHub - chunyu-li/ddpm: 扩散模型的简易 PyTorch 实现

1. 初始化和必要的导入

import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms
import numpy as np
import random

2. 定义 Beta Schedule 和相关函数

# 线性 beta schedule(控制每步噪声的大小)
def linear_beta_schedule(timesteps, start=0.0001, end=0.02):
    return torch.linspace(start, end, timesteps)

# 获取累积的 alpha 值
def get_alphas(betas):
    return 1.0 - betas

# 获取累积 alpha 的乘积
def get_alphas_cumprod(alphas):
    return torch.cumprod(alphas, axis=0)

# 计算反向噪声的标准差
def get_posterior_variance(alphas_cumprod, alphas_cumprod_prev, betas):
    return betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)

3. 定义 Diffusion 模型(UNet)

class SimpleUnet(nn.Module):
    def __init__(self):
        super(SimpleUnet, self).__init__()
        # 这里定义一个简单的卷积网络作为示例,可以替换成更复杂的UNet
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(128 * 64 * 64, 256)
        self.fc2 = nn.Linear(256, 3 * 64 * 64)

    def forward(self, x, t):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = x.view(x.size(0), -1)  # Flatten for fully connected layers
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x.view(x.size(0), 3, 64, 64)  # Reshape back to image shape

4. 扩散过程和去噪过程

正向扩散过程:从 x 0 x_0 x0 x t x_t xt
def forward_diffusion_sample(x_0, t, betas, alphas_cumprod):
    noise = torch.randn_like(x_0)
    sqrt_alphas_cumprod_t = alphas_cumprod[t].view(-1, 1, 1, 1)  # 广播至批次维度
    sqrt_one_minus_alphas_cumprod_t = torch.sqrt(1.0 - alphas_cumprod[t]).view(-1, 1, 1, 1)
    
    x_t = sqrt_alphas_cumprod_t * x_0 + sqrt_one_minus_alphas_cumprod_t * noise
    return x_t, noise
反向扩散过程:从 x T x_T xT x 0 x_0 x0
def sample_timestep(x_t, t, model, alphas_cumprod, betas):
    # 模型预测噪声
    epsilon_pred = model(x_t, t)

    # 计算当前时刻的均值和标准差
    sqrt_alphas_cumprod_t = alphas_cumprod[t].view(-1, 1, 1, 1)
    sqrt_one_minus_alphas_cumprod_t = torch.sqrt(1.0 - alphas_cumprod[t]).view(-1, 1, 1, 1)
    posterior_variance_t = betas[t].view(-1, 1, 1, 1)

    # 计算预测的x_0
    x_0_pred = (x_t - sqrt_one_minus_alphas_cumprod_t * epsilon_pred) / sqrt_alphas_cumprod_t

    # 反向采样
    noise = torch.randn_like(x_t) if t > 0 else torch.zeros_like(x_t)
    x_t_minus_1 = sqrt_alphas_cumprod[t - 1] * x_0_pred + sqrt_one_minus_alphas_cumprod[t - 1] * noise
    return x_t_minus_1

5. 损失函数(训练时)

def get_loss(model, x_0, t, betas, alphas_cumprod):
    x_t, noise = forward_diffusion_sample(x_0, t, betas, alphas_cumprod)
    noise_pred = model(x_t, t)
    return F.mse_loss(noise, noise_pred)  # MSE 损失

6. 数据加载和预处理

def load_transformed_dataset(img_size=64, batch_size=128):
    data_transforms = [
        transforms.Resize((img_size, img_size)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Lambda(lambda t: (t * 2) - 1),  # [0,1] -> [-1,1]
    ]
    data_transform = transforms.Compose(data_transforms)

    train = torchvision.datasets.ImageFolder(root="./stanford_cars/cars_train", transform=data_transform)
    test = torchvision.datasets.ImageFolder(root="./stanford_cars/cars_test", transform=data_transform)

    dataset = torch.utils.data.ConcatDataset([train, test])
    return DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)

7. 训练循环

if __name__ == "__main__":
    # 初始化
    model = SimpleUnet()
    T = 300  # 扩散步数
    betas = linear_beta_schedule(T)
    alphas = get_alphas(betas)
    alphas_cumprod = get_alphas_cumprod(alphas)

    BATCH_SIZE = 128
    epochs = 100

    dataloader = load_transformed_dataset(batch_size=BATCH_SIZE)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    for epoch in range(epochs):
        for batch_idx, (batch, _) in enumerate(dataloader):
            optimizer.zero_grad()
            batch = batch.to(device)

            t = torch.randint(0, T, (BATCH_SIZE,), device=device).long()
            loss = get_loss(model, batch, t, betas, alphas_cumprod)
            loss.backward()
            optimizer.step()

            if batch_idx % 10 == 0:
                print(f"Epoch [{epoch+1}/{epochs}], Step [{batch_idx+1}/{len(dataloader)}], Loss: {loss.item()}")

8. 采样(生成图像)

def generate_samples(model, T=300):
    # 从噪声开始
    x_t = torch.randn((BATCH_SIZE, 3, 64, 64)).to(device)
    
    for t in reversed(range(T)):
        x_t = sample_timestep(x_t, t, model, alphas_cumprod, betas)
    
    return x_t

9. 显示图像

def show_tensor_image(image):
    image = image.squeeze().cpu().numpy().transpose(1, 2, 0)
    image = (image + 1.0) / 2.0  # [-1, 1] -> [0, 1]
    plt.imshow(image)
    plt.axis('off')
    plt.show()

x_t = sample_timestep(x_t, t, model, alphas_cumprod, betas)

return x_t

### 9. 显示图像

```python
def show_tensor_image(image):
    image = image.squeeze().cpu().numpy().transpose(1, 2, 0)
    image = (image + 1.0) / 2.0  # [-1, 1] -> [0, 1]
    plt.imshow(image)
    plt.axis('off')
    plt.show()

相关文章:

  • 我拿Cursor复现了Manus的效果
  • 上层 Makefile 控制下层 Makefile ---- 第二部分(补充一些例子与细节)
  • URL结构、HTTP协议报文
  • Redis for Windows 后台服务运行
  • 【6】深入学习http模块(万字)-Nodejs开发入门
  • javascript专题2 ---- 在 JavaScript 列表(数组)的第一个位置插入数据
  • 【Linux C】简单bash设计
  • 重返JAVA之路——面向对象
  • 论文:Generalized Category Discovery with Large Language Models in the Loop
  • 玩转ChatGPT:使用深入研究功能梳理思路
  • 最大公约数和最小倍数 java
  • 【Linux实践系列】:匿名管道收尾+完善shell外壳程序
  • redis linux 安装简单教程(redis 3.0.4)
  • Spring Boot(二十一):RedisTemplate的String和Hash类型操作
  • 基于XGBoost的异烟酸生产收率预测:冠军解决方案解析
  • 七大寻址方式
  • ubuntu 系统安装Mysql
  • 【代码安全】spotbugs编写自定义规则(一) 快速开始
  • 【数据可视化艺术·实战篇】视频AI+人流可视化:如何让数据“动”起来?
  • 每日OJ_牛客_ruby和薯条_排序+二分/滑动窗口_C++_Java
  • 以开放促发展,以发展促开放,浙江加快建设高能级开放强省
  • 读懂城市|成都高新区:打造“人尽其才”的“理想之城”
  • 私家车跑“顺风”出事故,意外险赔不赔?
  • 法律顾问被控配合他人诈骗酒店资产一审判8年,二审辩称无罪
  • 孟夏韵评《无序的学科》丨误读与重构的文化漂流
  • 美联储计划裁员约10%