ZigMa:一种DiT风格的Zigzag Mamba扩散模型
摘要
本文聚焦于扩散模型的优化,提出ZigMa(Zigzag Mamba)扩散模型。该模型借助Mamba的长序列建模能力,针对传统扩散模型的可扩展性和复杂度问题进行改进。通过创新的Zigzag Mamba块和随机插值框架的结合,在不同分辨率的图像和视频数据处理上展现出优异性能,为扩散模型的发展提供了新的思路和方法。
关键词:扩散模型;状态空间模型;Zigzag Mamba;随机插值;多维度建模
一、引言
扩散模型在图像处理、视频分析、点云处理等众多领域取得了显著进展,不少模型基于潜在扩散模型(LDM)构建,以UNet为骨干网络。然而,可扩展性问题一直制约着LDM的发展。基于Transformer的结构虽然在可扩展性和多模态训练方面有优势,但注意力机制的二次复杂度成为扩散模型的瓶颈。状态空间模型在长序列建模上潜力巨大,Mamba作为其中的代表,通过高效并行扫描等技术优化长序列建模。但将Mamba拓展到二维图像和三维视频时面临挑战,现有方法存在忽略空间连续性或增加参数负担等问题。
因此,本文提出Zigzag Mamba(ZigMa),旨在解决这些问题,提升扩散模型性能,并探索随机插值在大规模图像和视频数据中的应用。
二、相关工作
2.1 Mamba
Mamba是新型状态空间模型,在医学成像、图像恢复、自然语言处理等多个领域有广泛应用。与之相关的研究如VisionMamba、S4ND和Mamba-ND等,各有特点。VisionMamba在判别任务中使用双向SSM,计算成本高;S4ND在Mamba推理中引入局部卷积;Mamba-ND在判别任务中考虑多维性。而本文重点在于将扫描复杂度分摊到网络各层,在零参数负担下融入视觉数据归纳偏差。
2.2 扩散模型中的骨干网络
扩散模型常用的骨干网络有UNet和ViT。UNet内存需求大,ViT虽可扩展性强且利于多模态学习,但二次复杂度限制了视觉令牌处理。本文受Mamba启发,探索基于SSM的扩散骨干网络,与DiffSSM、DIS的研究重点不同,更关注骨干网络设计及在复杂视觉数据中的应用。扩散模型常用的骨干网络有UNet和ViT。UNet内存需求大,ViT虽可扩展性强且利于多模态学习,但二次复杂度限制了视觉令牌处理。本文受Mamba启发,探索基于SSM的扩散骨干网络,与DiffSSM、DIS的研究重点不同,更关注骨干网络设计及在复杂视觉数据中的应用。
2.3 扩散模型中的随机微分方程(SDE)和常微分方程(ODE)
基于分数的生成模型依赖SDE,像SMLD和DDPMs。近期研究发现,使用ODE采样器处理扩散SDE可降低采样成本。在随机插值框架中,SiT模型研究了小分辨率下插值方法的作用,本文则将其扩展到高分辨率二维图像和三维视频数据。
三、方法
3.1 背景:状态空间模型
状态空间模型(SSMs)能有效处理长距离依赖关系,计算复杂度与序列长度呈线性关系。其一般形式为:
{
x
′
(
t
)
=
A
(
t
)
x
(
t
)
+
B
(
t
)
u
(
t
)
y
(
t
)
=
C
(
t
)
x
(
t
)
+
D
(
t
)
u
(
t
)
\begin{cases} x'(t)=A(t)x(t)+B(t)u(t)\\ y(t)=C(t)x(t)+D(t)u(t) \end{cases}
{x′(t)=A(t)x(t)+B(t)u(t)y(t)=C(t)x(t)+D(t)u(t)
Mamba放宽了SSM参数的时间不变性约束,提升灵活性的同时保持计算效率。本文聚焦于Mamba在扩散模型中的扫描方案,挖掘多维视觉数据的归纳偏差。
3.2 扩散骨干网络:Zigzag Mamba
3.2.1 DiT风格的网络
选择基于AdaLN的ViT框架构建Mamba网络,核心组件为Zigzag扫描。
ZigMa的骨干网络采用了类似于DiT 的架构,共由L层构成。我们将单扫描Mamba块用作跨不同图像块的主要推理模块。为确保网络具备位置感知能力,我们基于单扫描Mamba设计了一种排列 - 重排列方案。不同的层遵循唯一的重排列操作Ω和反向重排列操作Ω¯对,从而优化了该方法的位置感知能力。
3.2.2 Mamba中的Zigzag扫描
以往方法增加扫描方向会导致内存问题。本文提出在输入前向扫描块前重排令牌,设计公式如下:
{
z
Ω
i
=
arrange
(
z
i
,
Ω
i
)
z
‾
Ω
i
=
scan
(
z
Ω
)
z
i
+
1
=
arrange
(
z
‾
Ω
i
,
Ω
‾
i
)
\begin{cases} z_{\Omega_{i}}=\text{arrange}(z_{i},\Omega_{i})\\ \overline{z}_{\Omega_{i}}=\text{scan}(z_{\Omega})\\ z_{i + 1}=\text{arrange}(\overline{z}_{\Omega_{i}},\overline{\Omega}_{i}) \end{cases}
⎩
⎨
⎧zΩi=arrange(zi,Ωi)zΩi=scan(zΩ)zi+1=arrange(zΩi,Ωi)
同时,强调空间连续性和空间填充,设计了八种空间填充连续方案,用
Ω
i
=
S
i
%
8
\Omega_{i}=S_{{i\%8}}
Ωi=Si%8 表示每一层的方案。
二维图像扫描。我们的Mamba扫描设计基于子图(a)所示的扫描(sweep-scan)方案。在此基础上,我们开发了子图(b)所示的之字形扫描(zigzag-scan)方案,以增强图像块的连续性,从而最大限度地发挥Mamba块的潜力。由于这些连续扫描存在多种可能的排列方式,在子图(c)中列出了八种最常见的之字形扫描方式。
3.2.3 在Zigzag Mamba上部署文本条件
为解决Mamba在文本条件应用中的不足,设计了基于Mamba块的交叉注意力块,实现长序列建模和多令牌条件(如文本条件),还具备可解释性。
3.2.4 通过分解空间和时间信息推广到三维视频
提出三种视频Mamba块变体:sweep - scan直接展平3D特征,不考虑连续性;3D Zigzag尝试保持2D和3D连续性,但优化效果不佳;Factorized 3D Zigzag将空间和时间相关性分解为单独的Mamba块,效果较好。计算分析表明,Zigzag Mamba相比全局自注意力和k - direction mamba,复杂度更低。计算复杂度公式如下:
ζ
(
self - attention
)
=
4
M
D
2
+
2
M
2
D
\zeta(\text{self - attention}) = 4MD^{2} + 2M^{2}D
ζ(self - attention)=4MD2+2M2D
ζ
(
k
−
m
a
m
b
a
)
=
k
×
[
3
M
(
2
D
)
N
+
M
(
2
D
)
N
2
]
\zeta(k - mamba)=k×[3M(2D)N + M(2D)N^{2}]
ζ(k−mamba)=k×[3M(2D)N+M(2D)N2]
ζ
(
zigzag
)
=
3
M
(
2
D
)
N
+
M
(
2
D
)
N
2
\zeta(\text{zigzag}) = 3M(2D)N + M(2D)N^{2}
ζ(zigzag)=3M(2D)N+M(2D)N2
3.3 扩散框架:随机插值
3.3.1 基于向量v和分数s的采样
依据相关研究,
(
x
t
)
(x_{t})
(xt)的时间相关概率分布
(
p
t
(
x
)
)
(p_{t}(x))
(pt(x))与反向时间SDE的分布一致,公式为:
d
X
t
=
v
(
X
t
,
t
)
d
t
+
1
2
w
t
s
(
X
t
,
t
)
d
t
+
w
t
d
W
‾
t
dX_{t}=v(X_{t},t)dt+\frac{1}{2}w_{t}s(X_{t},t)dt+\sqrt{w_{t}}d\overline{W}_{t}
dXt=v(Xt,t)dt+21wts(Xt,t)dt+wtdWt
只要估计出速度(v(x,t))和/或分数(s(x,t))场,就能用于采样。通过概率流ODE或反向时间SDE,从(
X
T
=
ε
N
(
0
,
I
)
X_{T}=\varepsilon ~ N(0, I)
XT=ε N(0,I))反向求解,可生成样本。选择ODE采样时,将噪声项s设为零即可。
3.3.2 估计分数s和速度v
在基于分数的扩散模型中,分数
s
θ
(
x
,
t
)
s_{\theta}(x,t)
sθ(x,t)和速度
v
θ
(
x
,
t
)
v_{\theta}(x,t)
vθ(x,t)可通过参数化估计,损失函数分别为:
L
s
(
θ
)
=
∫
0
T
E
[
∥
σ
t
s
θ
(
x
t
,
t
)
+
ε
∥
2
]
d
t
\mathcal{L}_{s}(\theta)=\int_{0}^{T} \mathbb{E}\left[\left\| \sigma_{t} s_{\theta}\left(x_{t}, t\right)+\varepsilon\right\| ^{2}\right] d t
Ls(θ)=∫0TE[∥σtsθ(xt,t)+ε∥2]dt
L
v
(
θ
)
=
∫
0
T
E
[
∥
v
θ
(
x
t
,
t
)
−
α
˙
t
x
∗
−
σ
˙
t
ε
∥
2
]
d
t
\mathcal{L}_{v}(\theta)=\int_{0}^{T} \mathbb{E}\left[\left\| v_{\theta}\left(x_{t}, t\right)-\dot{\alpha}_{t} x_{*}-\dot{\sigma}_{t} \varepsilon\right\| ^{2}\right] d t
Lv(θ)=∫0TE[∥vθ(xt,t)−α˙tx∗−σ˙tε∥2]dt
训练时采用线性路径
(
α
t
=
1
−
t
,
σ
t
=
t
)
(\alpha_{t}=1 - t, \sigma_{t}=t)
(αt=1−t,σt=t),积分中的时间相关权重在基于分数的模型中起着重要作用。
四、实验
4.1 数据集和训练细节
4.1.1 图像数据集
用FacesHQ 1024×1024探究高分辨率可扩展性,用FacesHQ进行训练和消融实验。在MultiModalCelebA和MS COCO数据集上进行文本条件生成实验,用CLIP文本编码器处理文本。
4.1.2 视频数据集
使用UCF101数据集,随机采样16帧并调整分辨率为256×256。
4.1.3 训练细节
采用AdamW优化器,学习率1e - 4,用VAE编码器提取潜在特征,使用混合精度训练、梯度裁剪和权重衰减。多数实验在4块A100 GPU上进行,部分探索扩展到16块和32块A100 GPU,采样采用ODE采样。
4.2 消融实验
4.2.1 扫描方案消融
在MultiModalCelebA数据集的实验表明,从sweep切换到zigzag扫描方案可提升性能,增加zigzag方案数量也能持续提升性能,且高分辨率下提升更显著。
方案 | FID 5k (256×256) | FID 5k (512×512) |
---|---|---|
Sweep | 158.1 | 162.3 |
Zigzag-1 | 65.7 | 121.0 |
Zigzag-8 | 45.5 | 34.9 |
表1:扫描方案消融实验结果
4.2.2 空间连续性至关重要
实验发现,增加空间连续性可提升性能,随机打乱图像块则导致性能下降,证明空间连续性在Mamba应用于二维序列时非常关键。
空间连续性分析。随着我们逐步增大图像块组的大小,图像块的连续片段也在扩展。这增强了空间连续性,我们发现这会提升在MultiModal-CelebA 256、512数据集上的弗雷歇初始距离(FID)指标表现。
4.2.3 关于网络、FPS和GPU内存的消融实验
对比发现,Zigzag Mamba在FPS和GPU利用率方面表现最佳,且随着Mamba扫描方案增加,对FPS和GPU内存的负担几乎为零。对比发现,Zigzag Mamba在FPS和GPU利用率方面表现最佳,且随着Mamba扫描方案增加,对FPS和GPU内存的负担几乎为零。
4.2.4 顺序感受野
提出“顺序感受野”概念,实验显示Zigzag Mamba在顺序感受野增加时,能保持GPU内存消耗和FPS速率,而其他基线模型则出现FPS下降。
4.2.5 补丁大小
实验表明,随着补丁大小增加,FID变差,说明较小补丁大小更有利于获得最佳性能。
4.3 主要结果
4.3.1 1024×1024 FacesHQ上的主要结果
在高分辨率FacesHQ数据集上,ZigMa性能优于Bidirectional Mamba,随着训练时间延长,有望进一步提升优势。
方法 | FID 5k | FDD 5k |
---|---|---|
Bidirectional Mamba | 51.1 | 66.3 |
ZigMa-16GPU | 37.8 | 50.5 |
ZigMa-32GPU | 26.6 | 31.2 |
表3:1024×1024高分辨率生成结果
4.3.2 COCO数据集
在MS COCO数据集上,ZigMa的Zigzag - 8方法同样优于Bidirectional Mamba和Zigzag - 1,表明分摊扫描方案可提升性能。
方案 | FID 5k |
---|---|
Sweep | 195.1 |
Zigzag-1 | 73.1 |
Bidirection Mamba | 60.2 |
Zigzag-8 | 41.8 |
Zigzag-8 16GPU | 33.8 |
表2:MS-COCO数据集主要结果
4.3.3 UCF101数据集
在UCF101视频数据集上,ZigMa的Factorized 3D Zigzag Mamba方法表现出色,优于Bidirectional Mamba。
方法 | Frame-FID 5k | FVD 5k |
---|---|---|
Bidirectional Mamba-4GPU | 256.1 | 320.2 |
3D Zigzag Mamba -4GPU | 238.1 | 282.3 |
Factorized 3D Zigzag Mamba -4GPU | 216.1 | 210.2 |
Bidirectional Mamba -16GPU | 146.2 | 201.1 |
Factorized 3D Zigzag Mamba -16GPU | 121.2 | 140.1 |
表4:UCF101数据集视频扫描方案结果
4.3.4 可视化
可视化结果显示,ZigMa在不同分辨率下生成的图像质量较高,证明了方法的有效性。
五、结论
本文提出的ZigMa扩散模型,在随机插值框架下,解决了Mamba在二维图像和三维视频建模中的空间连续性问题。通过设计Zigzag Mamba块和相关实验,验证了模型的优势,为Mamba网络设计提供了新的思路。
六、局限性和未来工作
目前ZigMa存在一定局限性,如无法穷举所有空间连续扫描方案,扫描方案多基于经验设定,可能导致性能未达最优。受GPU资源限制,训练时长受限。未来可深入研究Zigzag Mamba在不同领域的应用,充分发挥其长序列建模的可扩展性优势。
七、影响声明
ZigMa提升了Mamba在扩散模型中的可扩展性,有助于生成高保真大图像和实现文本到图像的生成。但如同其他图像合成技术,存在生成有害内容的风险,需关注伦理问题并采取相应保障措施。
参考文献
[1] 主页:ZigMa主页
[2] 文章链接:ZigMa论文
[3] 本文相关代码仓库:ZigMa官方GitHub代码仓库