《深度学习实战》第6集:扩散模型(Diffusion Models)与高质量图像生成
《深度学习实战》第6集:扩散模型(Diffusion Models)与高质量图像生成
** 2025年3月1日 更新了代码案例,增加了Modelscope 平台大模型文生图代码,确保能实际运行上手**
扩散模型(Diffusion Models)是近年来图像生成领域的明星技术,凭借其高质量的生成能力和训练稳定性,逐渐成为 GAN 的有力竞争者。本集将带你深入了解扩散模型的核心思想,并通过 一个实战项目(代码运行通过) 使用预训练的 Stable Diffusion 模型生成艺术图像。我们还将探讨 MidJourney 和 DALL·E 等前沿技术背后的原理。
1. 扩散模型的核心思想
1.1 去噪过程与逆向生成
扩散模型的核心思想是通过逐步去噪的方式生成数据。它的灵感来源于物理学中的扩散过程:
- 正向扩散:将真实数据逐步加入噪声,直到变成完全随机的噪声。
- 逆向生成:从随机噪声出发,逐步去除噪声,恢复出逼真的数据。
数学表达
- 正向扩散:
-
在每一步中,向数据
x
t
添加高斯噪声,生成
x
t
+
1
。
在每一步中,向数据 x_t 添加高斯噪声,生成x_{t+1}。
在每一步中,向数据xt添加高斯噪声,生成xt+1。
q ( x t + 1 ∣ x t ) = N ( x t + 1 ; 1 − β t x t , β t I ) q(x_{t+1} | x_t) = \mathcal{N}(x_{t+1}; \sqrt{1-\beta_t}x_t, \beta_t I) q(xt+1∣xt)=N(xt+1;1−βtxt,βtI) - β t 是噪声强度,随着步骤增加而变化。 \beta_t 是噪声强度,随着步骤增加而变化。 βt是噪声强度,随着步骤增加而变化。
- 逆向生成:
-
通过神经网络学习逆向过程
p
θ
(
x
t
−
1
∣
x
t
)
,逐步去除噪声。
通过神经网络学习逆向过程 p_\theta(x_{t-1} | x_t),逐步去除噪声。
通过神经网络学习逆向过程pθ(xt−1∣xt),逐步去除噪声。
p θ ( x t − 1 ∣ x t ) = N ( x t − 1 ; μ θ ( x t , t ) , Σ θ ( x t , t ) ) p_\theta(x_{t-1} | x_t) = \mathcal{N}(x_{t-1}; \mu_\theta(x_t, t), \Sigma_\theta(x_t, t)) pθ(xt−1∣xt)=N(xt−1;μθ(xt,t),Σθ(xt,t))
-
通过神经网络学习逆向过程
p
θ
(
x
t
−
1
∣
x
t
)
,逐步去除噪声。
通过神经网络学习逆向过程 p_\theta(x_{t-1} | x_t),逐步去除噪声。
通过神经网络学习逆向过程pθ(xt−1∣xt),逐步去除噪声。
1.2 DDPM(Denoising Diffusion Probabilistic Models)
DDPM 是扩散模型的经典实现,由 Ho 等人在 2020 年提出。它定义了以下关键点:
- 训练目标:最小化正向扩散和逆向生成之间的差异。
- 推理过程:从纯噪声开始,逐步生成清晰的图像。
改进版本
- DDIM(Denoising Diffusion Implicit Models):
- 提供更快的采样速度,减少生成时间。
- Latent Diffusion Models:
- 将扩散过程应用于潜在空间(Latent Space),降低计算成本。
2. 实战项目:使用预训练的 Stable Diffusion 模型生成艺术图像
我们将使用 Hugging Face 的 diffusers
库加载预训练的 Stable Diffusion 模型,并生成艺术图像。
- 注: Hugging Face 自2024年起高度不稳定不可用, 2.1代码仅作参考。
- 实际上手以 2.2 国内魔搭 Modelscope 代码案例为准
2.1从 Hugging Face 平台下载并运行文生图大模型 代码
安装依赖
确保安装以下库:
pip install diffusers transformers torch
加载预训练模型
from diffusers import StableDiffusionPipeline
import torch
# 加载预训练的 Stable Diffusion 模型
model_id = "CompVis/stable-diffusion-v1-4"
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu")
** 生成艺术图像**
# 输入提示词
prompt = "A futuristic cityscape with flying cars and neon lights"
# 生成图像
image = pipe(prompt).images[0]
# 保存图像
image.save("generated_image.png")
# 显示图像
image.show()
2.2 从国内 Modelscope 平台下载并运行文生图大模型 代码
from modelscope.utils.constant import Tasks
from modelscope.pipelines import pipeline
import cv2
pipe = pipeline(task=Tasks.text_to_image_synthesis,
model='AI-ModelScope/stable-diffusion-v1-5',
#model=r'd:\llm_download\modelscope_cache/hub/models/AI-ModelScope/stable-diffusion-v1-5', #从本地读取大模型,必须先把 .cache 中的modelscope模型文件夹移动到指定D盘文件夹
model_revision='v1.0.9')
#cache_dir=r'd:\llm_download\modelscope_cache') #将模型缓存放到D盘
prompt = 'Two pandas playing chess,another panda is playing kongfu'
output = pipe({'text': prompt})
cv2.imwrite(prompt[0:10]+'result.png', output['output_imgs'][0])
from PIL import Image
# 显示图像
image = Image.open(prompt[0:10]+'result.png')
image.show()
以下是几个实际生成图片的示例,可以看到效果还是不错的,大模型文件一共:
prompt = ‘Shanghai Tower,river,sunset,high-tech city,futuristic,evening,night’
3. 图示:扩散模型的训练与推理流程
3.1 训练流程
- 正向扩散:将真实图像逐步添加噪声。
- 训练模型:学习如何从噪声中恢复原始图像。
3.2 推理流程
- 从噪声开始:生成一个随机噪声图像。
- 逐步去噪:通过逆向生成过程生成清晰图像。
4. 前沿关联:MidJourney 和 DALL·E 的背后技术
4.1 MidJourney
- 特点:
- 专注于艺术风格生成,支持高度定制化的提示词。
- 输出图像具有极高的视觉冲击力。
- 技术基础:
- 基于扩散模型和潜在空间优化。
4.2 DALL·E
- 特点:
- 能够根据文本描述生成多样化图像。
- 支持复杂的场景组合和细节控制。
- 技术基础:
- 结合扩散模型和 Transformer 架构。
5. 总结
扩散模型通过逐步去噪的方式实现了高质量的图像生成,其代表作 Stable Diffusion 已经在艺术创作领域大放异彩。通过实战项目,我们学会了如何使用预训练的 Stable Diffusion 模型生成艺术图像。同时,我们也探讨了 MidJourney 和 DALL·E 等前沿技术背后的原理。
附:Modelscope 代码案例要跑起来需要的 pip 依赖库清单:
absl-py==2.1.0
accelerate==1.4.0
addict==2.4.0
aiohappyeyeballs==2.4.6
aiohttp==3.11.13
aiosignal==1.3.2
antlr4-python3-runtime==4.8
anyio==4.4.0
argon2-cffi==23.1.0
argon2-cffi-bindings==21.2.0
arrow==1.3.0
asttokens==2.4.1
astunparse==1.6.3
async-lru==2.0.4
async-timeout==5.0.1
attrs==24.2.0
audioread==3.0.1
babel==2.16.0
beautifulsoup4==4.12.3
bitarray==3.1.0
bleach==6.1.0
certifi==2024.7.4
cffi==1.17.0
charset-normalizer==3.3.2
colorama==0.4.6
comm==0.2.2
contourpy==1.3.0
cycler==0.12.1
Cython==3.0.12
datasets==3.3.2
debugpy==1.8.5
decorator==5.1.1
decord==0.6.0
defusedxml==0.7.1
diffusers==0.32.2
dill==0.3.8
exceptiongroup==1.2.2
executing==2.0.1
fairseq==0.12.2
fastjsonschema==2.20.0
filelock==3.17.0
flatbuffers==25.2.10
fonttools==4.56.0
fqdn==1.5.1
frozenlist==1.5.0
fsspec==2024.12.0
gast==0.6.0
google-pasta==0.2.0
grpcio==1.70.0
h11==0.14.0
h5py==3.13.0
httpcore==1.0.5
httpx==0.27.0
huggingface-hub==0.29.1
hydra-core==1.0.7
idna==3.7
importlib_metadata==8.3.0
importlib_resources==6.5.2
ipykernel==6.29.5
ipython==8.18.1
isoduration==20.11.0
jedi==0.19.1
Jinja2==3.1.4
joblib==1.4.2
json5==0.9.25
jsonpointer==3.0.0
jsonschema==4.23.0
jsonschema-specifications==2023.12.1
jupyter-events==0.10.0
jupyter-lsp==2.2.5
jupyter_client==8.6.2
jupyter_core==5.7.2
jupyter_server==2.14.2
jupyter_server_terminals==0.5.3
jupyterlab==4.2.4
jupyterlab_pygments==0.3.0
jupyterlab_server==2.27.3
keras==3.8.0
kiwisolver==1.4.7
lazy_loader==0.4
libclang==18.1.1
librosa==0.10.2.post1
llvmlite==0.43.0
lxml==5.3.1
Markdown==3.7
markdown-it-py==3.0.0
MarkupSafe==2.1.5
matplotlib==3.9.4
matplotlib-inline==0.1.7
mdurl==0.1.2
mistune==3.0.2
ml-dtypes==0.4.1
modelscope==1.23.1
mpmath==1.3.0
msgpack==1.1.0
multidict==6.1.0
multiprocess==0.70.16
namex==0.0.8
nbclient==0.10.0
nbconvert==7.16.4
nbformat==5.10.4
nest-asyncio==1.6.0
networkx==3.2.1
notebook==7.2.1
notebook_shim==0.2.4
numba==0.60.0
numpy==2.0.2
omegaconf==2.0.6
opencv-python==4.11.0.86
opt_einsum==3.4.0
optree==0.14.0
overrides==7.7.0
packaging==24.1
pandas==2.2.3
pandocfilters==1.5.1
parso==0.8.4
pi==0.1.2
pillow==11.1.0
platformdirs==4.2.2
pooch==1.8.2
portalocker==3.1.1
prometheus_client==0.20.0
prompt_toolkit==3.0.47
propcache==0.3.0
protobuf==5.29.3
psutil==6.0.0
pure_eval==0.2.3
pyarrow==19.0.1
pycparser==2.22
Pygments==2.18.0
pyparsing==3.2.1
python-dateutil==2.9.0.post0
python-json-logger==2.0.7
pytz==2025.1
pywin32==306
pywinpty==2.0.13
PyYAML==6.0.2
pyzmq==26.1.1
RapidFuzz==3.12.1
referencing==0.35.1
regex==2024.11.6
requests==2.32.3
rfc3339-validator==0.1.4
rfc3986-validator==0.1.1
rich==13.9.4
rpds-py==0.20.0
sacrebleu==2.5.1
safetensors==0.5.3
scikit-learn==1.6.1
scipy==1.13.1
seaborn==0.13.2
Send2Trash==1.8.3
simplejson==3.20.1
six==1.16.0
sniffio==1.3.1
sortedcontainers==2.4.0
soundfile==0.13.1
soupsieve==2.6
soxr==0.5.0.post1
stack-data==0.6.3
sympy==1.13.1
tabulate==0.9.0
tensorboard==2.18.0
tensorboard-data-server==0.7.2
tensorboardX==2.6.2.2
tensorflow==2.18.0
tensorflow-io-gcs-filesystem==0.31.0
tensorflow_intel==2.18.0
termcolor==2.5.0
terminado==0.18.1
tf_keras==2.18.0
threadpoolctl==3.5.0
timm==1.0.15
tinycss2==1.3.0
tokenizers==0.21.0
tomli==2.0.1
torch==2.5.1+cu121
torchaudio==2.5.1+cu121
torchvision==0.20.1+cu121
tornado==6.4.1
tqdm==4.67.1
traitlets==5.14.3
transformers==4.49.0
types-python-dateutil==2.9.0.20240316
typing_extensions==4.12.2
tzdata==2025.1
unicodedata2==16.0.0
uri-template==1.3.0
urllib3==2.2.2
wcwidth==0.2.13
webcolors==24.8.0
webencodings==0.5.1
websocket-client==1.8.0
Werkzeug==3.1.3
wrapt==1.17.2
xxhash==3.5.0
yarl==1.18.3
zhconv==1.4.3
zipp==3.20.0