当前位置: 首页 > news >正文

扩散模型简介

扩散模型的基本原理

扩散模型(Diffusion Models)是一类生成模型,通过将数据逐渐加入噪声并学习逆向过程来生成新数据。其核心思想是模拟物理中的扩散过程,将数据分布逐渐转化为高斯分布,再通过学习逆向过程恢复原始数据分布。

扩散过程分为前向扩散和逆向扩散。前向扩散通过逐步添加高斯噪声破坏数据,最终使数据完全变为噪声。逆向扩散则通过学习噪声的逐步去除,从纯噪声中重建数据。

前向扩散过程

前向扩散过程是一个马尔可夫链,每一步根据固定方差调度添加高斯噪声。给定数据点 ( x_0 ),前向过程定义如下:

[ q(x_t | x_{t-1}) = \mathcal{N}(x_t; \sqrt{1-\beta_t} x_{t-1}, \beta_t \mathbf{I}) ]

其中 ( \beta_t ) 是噪声调度参数,控制每一步的噪声强度。通过重参数化技巧,可以直接从 ( x_0 ) 计算任意时间步 ( t ) 的噪声数据:

[ x_t = \sqrt{\bar{\alpha}_t} x_0 + \sqrt{1-\bar{\alpha}_t} \epsilon ]

其中 ( \alpha_t = 1 - \beta_t ),( \bar{\alpha}t = \prod{s=1}^t \alpha_s ),( \epsilon \sim \mathcal{N}(0, \mathbf{I}) )。

逆向扩散过程

逆向扩散过程通过学习一个神经网络 ( p_\theta(x_{t-1} | x_t) ) 逐步去噪。其目标是最大化对数似然的下界(ELBO),等价于最小化以下损失函数:

[ \mathcal{L}(\theta) = \mathbb{E}{t, x_0, \epsilon} \left[ | \epsilon - \epsilon\theta(x_t, t) |^2 \right] ]

其中 ( \epsilon_\theta ) 是噪声预测网络,通常采用U-Net结构。训练时,随机采样时间步 ( t ),用网络预测噪声并与真实噪声计算均方误差。

采样生成新数据

训练完成后,生成新数据的步骤如下:

  1. 从标准高斯分布采样初始噪声 ( x_T \sim \mathcal{N}(0, \mathbf{I}) )。
  2. 从 ( t=T ) 到 ( t=1 ) 逐步去噪: [ x_{t-1} = \frac{1}{\sqrt{\alpha_t}} \left( x_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}t}} \epsilon\theta(x_t, t) \right) + \sigma_t z ] 其中 ( z \sim \mathcal{N}(0, \mathbf{I}) ),( \sigma_t ) 是噪声方差。
  3. 最终得到生成数据 ( x_0 )。

改进与变体

扩散模型的性能依赖噪声调度和网络结构设计。常见改进包括:

  • DDPM(Denoising Diffusion Probabilistic Models):基础框架,采用线性噪声调度。
  • DDIM(Denoising Diffusion Implicit Models):通过非马尔可夫链加速采样。
  • Stable Diffusion:在潜空间进行扩散,降低计算成本。
  • Classifier Guidance:利用分类器梯度引导生成过程,提升生成质量。

应用场景

扩散模型广泛应用于图像生成、超分辨率、图像修复、文本到图像生成等领域。其高质量生成能力和稳定训练特性使其成为当前生成模型的重要方向。

代码示例(PyTorch)

以下是一个简化的扩散模型训练代码框架:

import torch
import torch.nn as nnclass DiffusionModel(nn.Module):def __init__(self, model, T, beta_start, beta_end):super().__init__()self.model = model  # 噪声预测网络self.T = Tself.betas = torch.linspace(beta_start, beta_end, T)self.alphas = 1 - self.betasself.alpha_bars = torch.cumprod(self.alphas, dim=0)def forward(self, x0, t, noise):alpha_bar_t = self.alpha_bars[t]xt = torch.sqrt(alpha_bar_t) * x0 + torch.sqrt(1 - alpha_bar_t) * noisepred_noise = self.model(xt, t)loss = torch.mean((noise - pred_noise) ** 2)return lossdef sample(self, shape):xt = torch.randn(shape)for t in reversed(range(self.T)):z = torch.randn(shape) if t > 0 else 0alpha_t = self.alphas[t]alpha_bar_t = self.alpha_bars[t]beta_t = self.betas[t]pred_noise = self.model(xt, torch.tensor([t]))xt = (xt - (beta_t / torch.sqrt(1 - alpha_bar_t)) * pred_noise) / torch.sqrt(alpha_t)xt += torch.sqrt(beta_t) * zreturn xt

http://www.dtcms.com/a/389907.html

相关文章:

  • [答疑]SysML模型的BDD中加了新的端口,怎样同步到IBD
  • MySQL 专题(二):索引原理与优化
  • 【脑电分析系列】第17篇:EEG特征提取与降维进阶 — 主成分分析、判别分析与黎曼几何
  • NVIDIA DOCA 环境产品使用与体验报告
  • C# Windows Service 中添加 log4net 的详细教程
  • 用 pymupdf4llm 打造 PDF → Markdown 的高效 LLM 数据管道(附实战对比)
  • 机械设备钢材建材网站 网站模版
  • Mysql8 SQLSTATE[42000] sql_mode=only_full_group_by错误解决办法
  • 【第五章:计算机视觉-项目实战之图像分类实战】2.图像分类实战-(3)批量归一化(Batch Normalization)和权重初始化的重要性
  • SQL Server 多用户读写随机超时?从问题分析到根治方案
  • 2.css的继承性,层叠性,优先级
  • OpenStack 学习笔记(四):编排管理与存储管理实践(上)
  • list_for_each_entry 详解
  • Perplexity AI Agent原生浏览器Comet
  • 颈椎按摩器方案开发,智能按摩仪方案设计
  • Sui 学习日志 1
  • 六、Java—IO流
  • 数据库 事务隔离级别 深入理解数据库事务隔离级别:脏读、不可重复读、幻读与串行化
  • 从“纸面”到“人本”:劳务合同管理的数字化蜕变
  • ARM架构——学习时钟7.2
  • VS Code 调试配置详解:占位符与语言差异
  • 锁 相关知识总结
  • caffeine 发生缓存内容被修改以及解决方案-深度克隆
  • rust编写web服务06-JWT身份认证
  • 《怪猎:荒野》制作人:PC平台对日本游戏非常重要
  • 大模型训练框架(二)FSDP
  • MySQL——系统数据库、常用工具
  • 蓝桥杯题目讲解_Python(转载)
  • 性能测试监控实践(九):性能测试时,监控docker微服务资源利用率和分析
  • TCP,UDP和ICMP