扩散模型简介
扩散模型的基本原理
扩散模型(Diffusion Models)是一类生成模型,通过将数据逐渐加入噪声并学习逆向过程来生成新数据。其核心思想是模拟物理中的扩散过程,将数据分布逐渐转化为高斯分布,再通过学习逆向过程恢复原始数据分布。
扩散过程分为前向扩散和逆向扩散。前向扩散通过逐步添加高斯噪声破坏数据,最终使数据完全变为噪声。逆向扩散则通过学习噪声的逐步去除,从纯噪声中重建数据。
前向扩散过程
前向扩散过程是一个马尔可夫链,每一步根据固定方差调度添加高斯噪声。给定数据点 ( x_0 ),前向过程定义如下:
[ q(x_t | x_{t-1}) = \mathcal{N}(x_t; \sqrt{1-\beta_t} x_{t-1}, \beta_t \mathbf{I}) ]
其中 ( \beta_t ) 是噪声调度参数,控制每一步的噪声强度。通过重参数化技巧,可以直接从 ( x_0 ) 计算任意时间步 ( t ) 的噪声数据:
[ x_t = \sqrt{\bar{\alpha}_t} x_0 + \sqrt{1-\bar{\alpha}_t} \epsilon ]
其中 ( \alpha_t = 1 - \beta_t ),( \bar{\alpha}t = \prod{s=1}^t \alpha_s ),( \epsilon \sim \mathcal{N}(0, \mathbf{I}) )。
逆向扩散过程
逆向扩散过程通过学习一个神经网络 ( p_\theta(x_{t-1} | x_t) ) 逐步去噪。其目标是最大化对数似然的下界(ELBO),等价于最小化以下损失函数:
[ \mathcal{L}(\theta) = \mathbb{E}{t, x_0, \epsilon} \left[ | \epsilon - \epsilon\theta(x_t, t) |^2 \right] ]
其中 ( \epsilon_\theta ) 是噪声预测网络,通常采用U-Net结构。训练时,随机采样时间步 ( t ),用网络预测噪声并与真实噪声计算均方误差。
采样生成新数据
训练完成后,生成新数据的步骤如下:
- 从标准高斯分布采样初始噪声 ( x_T \sim \mathcal{N}(0, \mathbf{I}) )。
- 从 ( t=T ) 到 ( t=1 ) 逐步去噪: [ x_{t-1} = \frac{1}{\sqrt{\alpha_t}} \left( x_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}t}} \epsilon\theta(x_t, t) \right) + \sigma_t z ] 其中 ( z \sim \mathcal{N}(0, \mathbf{I}) ),( \sigma_t ) 是噪声方差。
- 最终得到生成数据 ( x_0 )。
改进与变体
扩散模型的性能依赖噪声调度和网络结构设计。常见改进包括:
- DDPM(Denoising Diffusion Probabilistic Models):基础框架,采用线性噪声调度。
- DDIM(Denoising Diffusion Implicit Models):通过非马尔可夫链加速采样。
- Stable Diffusion:在潜空间进行扩散,降低计算成本。
- Classifier Guidance:利用分类器梯度引导生成过程,提升生成质量。
应用场景
扩散模型广泛应用于图像生成、超分辨率、图像修复、文本到图像生成等领域。其高质量生成能力和稳定训练特性使其成为当前生成模型的重要方向。
代码示例(PyTorch)
以下是一个简化的扩散模型训练代码框架:
import torch
import torch.nn as nnclass DiffusionModel(nn.Module):def __init__(self, model, T, beta_start, beta_end):super().__init__()self.model = model # 噪声预测网络self.T = Tself.betas = torch.linspace(beta_start, beta_end, T)self.alphas = 1 - self.betasself.alpha_bars = torch.cumprod(self.alphas, dim=0)def forward(self, x0, t, noise):alpha_bar_t = self.alpha_bars[t]xt = torch.sqrt(alpha_bar_t) * x0 + torch.sqrt(1 - alpha_bar_t) * noisepred_noise = self.model(xt, t)loss = torch.mean((noise - pred_noise) ** 2)return lossdef sample(self, shape):xt = torch.randn(shape)for t in reversed(range(self.T)):z = torch.randn(shape) if t > 0 else 0alpha_t = self.alphas[t]alpha_bar_t = self.alpha_bars[t]beta_t = self.betas[t]pred_noise = self.model(xt, torch.tensor([t]))xt = (xt - (beta_t / torch.sqrt(1 - alpha_bar_t)) * pred_noise) / torch.sqrt(alpha_t)xt += torch.sqrt(beta_t) * zreturn xt