当前位置: 首页 > news >正文

ZigMa:一种DiT风格的Zigzag Mamba扩散模型

摘要

本文聚焦于扩散模型的优化,提出ZigMa(Zigzag Mamba)扩散模型。该模型借助Mamba的长序列建模能力,针对传统扩散模型的可扩展性和复杂度问题进行改进。通过创新的Zigzag Mamba块和随机插值框架的结合,在不同分辨率的图像和视频数据处理上展现出优异性能,为扩散模型的发展提供了新的思路和方法。
关键词:扩散模型;状态空间模型;Zigzag Mamba;随机插值;多维度建模

一、引言

扩散模型在图像处理、视频分析、点云处理等众多领域取得了显著进展,不少模型基于潜在扩散模型(LDM)构建,以UNet为骨干网络。然而,可扩展性问题一直制约着LDM的发展。基于Transformer的结构虽然在可扩展性和多模态训练方面有优势,但注意力机制的二次复杂度成为扩散模型的瓶颈。状态空间模型在长序列建模上潜力巨大,Mamba作为其中的代表,通过高效并行扫描等技术优化长序列建模。但将Mamba拓展到二维图像和三维视频时面临挑战,现有方法存在忽略空间连续性或增加参数负担等问题。
在这里插入图片描述

因此,本文提出Zigzag Mamba(ZigMa),旨在解决这些问题,提升扩散模型性能,并探索随机插值在大规模图像和视频数据中的应用。

二、相关工作

2.1 Mamba

Mamba是新型状态空间模型,在医学成像、图像恢复、自然语言处理等多个领域有广泛应用。与之相关的研究如VisionMamba、S4ND和Mamba-ND等,各有特点。VisionMamba在判别任务中使用双向SSM,计算成本高;S4ND在Mamba推理中引入局部卷积;Mamba-ND在判别任务中考虑多维性。而本文重点在于将扫描复杂度分摊到网络各层,在零参数负担下融入视觉数据归纳偏差。

2.2 扩散模型中的骨干网络

扩散模型常用的骨干网络有UNet和ViT。UNet内存需求大,ViT虽可扩展性强且利于多模态学习,但二次复杂度限制了视觉令牌处理。本文受Mamba启发,探索基于SSM的扩散骨干网络,与DiffSSM、DIS的研究重点不同,更关注骨干网络设计及在复杂视觉数据中的应用。扩散模型常用的骨干网络有UNet和ViT。UNet内存需求大,ViT虽可扩展性强且利于多模态学习,但二次复杂度限制了视觉令牌处理。本文受Mamba启发,探索基于SSM的扩散骨干网络,与DiffSSM、DIS的研究重点不同,更关注骨干网络设计及在复杂视觉数据中的应用。

2.3 扩散模型中的随机微分方程(SDE)和常微分方程(ODE)

基于分数的生成模型依赖SDE,像SMLD和DDPMs。近期研究发现,使用ODE采样器处理扩散SDE可降低采样成本。在随机插值框架中,SiT模型研究了小分辨率下插值方法的作用,本文则将其扩展到高分辨率二维图像和三维视频数据。

三、方法

3.1 背景:状态空间模型

状态空间模型(SSMs)能有效处理长距离依赖关系,计算复杂度与序列长度呈线性关系。其一般形式为:
{ x ′ ( t ) = A ( t ) x ( t ) + B ( t ) u ( t ) y ( t ) = C ( t ) x ( t ) + D ( t ) u ( t ) \begin{cases} x'(t)=A(t)x(t)+B(t)u(t)\\ y(t)=C(t)x(t)+D(t)u(t) \end{cases} {x(t)=A(t)x(t)+B(t)u(t)y(t)=C(t)x(t)+D(t)u(t)
Mamba放宽了SSM参数的时间不变性约束,提升灵活性的同时保持计算效率。本文聚焦于Mamba在扩散模型中的扫描方案,挖掘多维视觉数据的归纳偏差。

3.2 扩散骨干网络:Zigzag Mamba

3.2.1 DiT风格的网络

选择基于AdaLN的ViT框架构建Mamba网络,核心组件为Zigzag扫描。
在这里插入图片描述
ZigMa的骨干网络采用了类似于DiT 的架构,共由L层构成。我们将单扫描Mamba块用作跨不同图像块的主要推理模块。为确保网络具备位置感知能力,我们基于单扫描Mamba设计了一种排列 - 重排列方案。不同的层遵循唯一的重排列操作Ω和反向重排列操作Ω¯对,从而优化了该方法的位置感知能力。

3.2.2 Mamba中的Zigzag扫描

以往方法增加扫描方向会导致内存问题。本文提出在输入前向扫描块前重排令牌,设计公式如下:
{ z Ω i = arrange ( z i , Ω i ) z ‾ Ω i = scan ( z Ω ) z i + 1 = arrange ( z ‾ Ω i , Ω ‾ i ) \begin{cases} z_{\Omega_{i}}=\text{arrange}(z_{i},\Omega_{i})\\ \overline{z}_{\Omega_{i}}=\text{scan}(z_{\Omega})\\ z_{i + 1}=\text{arrange}(\overline{z}_{\Omega_{i}},\overline{\Omega}_{i}) \end{cases} zΩi=arrange(zi,Ωi)zΩi=scan(zΩ)zi+1=arrange(zΩi,Ωi)
同时,强调空间连续性和空间填充,设计了八种空间填充连续方案,用 Ω i = S i % 8 \Omega_{i}=S_{{i\%8}} Ωi=Si%8 表示每一层的方案。
在这里插入图片描述
二维图像扫描。我们的Mamba扫描设计基于子图(a)所示的扫描(sweep-scan)方案。在此基础上,我们开发了子图(b)所示的之字形扫描(zigzag-scan)方案,以增强图像块的连续性,从而最大限度地发挥Mamba块的潜力。由于这些连续扫描存在多种可能的排列方式,在子图(c)中列出了八种最常见的之字形扫描方式。

3.2.3 在Zigzag Mamba上部署文本条件

为解决Mamba在文本条件应用中的不足,设计了基于Mamba块的交叉注意力块,实现长序列建模和多令牌条件(如文本条件),还具备可解释性。
在这里插入图片描述

3.2.4 通过分解空间和时间信息推广到三维视频

在这里插入图片描述

提出三种视频Mamba块变体:sweep - scan直接展平3D特征,不考虑连续性;3D Zigzag尝试保持2D和3D连续性,但优化效果不佳;Factorized 3D Zigzag将空间和时间相关性分解为单独的Mamba块,效果较好。计算分析表明,Zigzag Mamba相比全局自注意力和k - direction mamba,复杂度更低。计算复杂度公式如下:
ζ ( self - attention ) = 4 M D 2 + 2 M 2 D \zeta(\text{self - attention}) = 4MD^{2} + 2M^{2}D ζ(self - attention)=4MD2+2M2D
ζ ( k − m a m b a ) = k × [ 3 M ( 2 D ) N + M ( 2 D ) N 2 ] \zeta(k - mamba)=k×[3M(2D)N + M(2D)N^{2}] ζ(kmamba)=k×[3M(2D)N+M(2D)N2]
ζ ( zigzag ) = 3 M ( 2 D ) N + M ( 2 D ) N 2 \zeta(\text{zigzag}) = 3M(2D)N + M(2D)N^{2} ζ(zigzag)=3M(2D)N+M(2D)N2

3.3 扩散框架:随机插值

3.3.1 基于向量v和分数s的采样

依据相关研究, ( x t ) (x_{t}) (xt)的时间相关概率分布 ( p t ( x ) ) (p_{t}(x)) (pt(x))与反向时间SDE的分布一致,公式为:
d X t = v ( X t , t ) d t + 1 2 w t s ( X t , t ) d t + w t d W ‾ t dX_{t}=v(X_{t},t)dt+\frac{1}{2}w_{t}s(X_{t},t)dt+\sqrt{w_{t}}d\overline{W}_{t} dXt=v(Xt,t)dt+21wts(Xt,t)dt+wt dWt
只要估计出速度(v(x,t))和/或分数(s(x,t))场,就能用于采样。通过概率流ODE或反向时间SDE,从( X T = ε   N ( 0 , I ) X_{T}=\varepsilon ~ N(0, I) XT=ε N(0,I))反向求解,可生成样本。选择ODE采样时,将噪声项s设为零即可。

3.3.2 估计分数s和速度v

在基于分数的扩散模型中,分数 s θ ( x , t ) s_{\theta}(x,t) sθ(x,t)和速度 v θ ( x , t ) v_{\theta}(x,t) vθ(x,t)可通过参数化估计,损失函数分别为:
L s ( θ ) = ∫ 0 T E [ ∥ σ t s θ ( x t , t ) + ε ∥ 2 ] d t \mathcal{L}_{s}(\theta)=\int_{0}^{T} \mathbb{E}\left[\left\| \sigma_{t} s_{\theta}\left(x_{t}, t\right)+\varepsilon\right\| ^{2}\right] d t Ls(θ)=0TE[σtsθ(xt,t)+ε2]dt
L v ( θ ) = ∫ 0 T E [ ∥ v θ ( x t , t ) − α ˙ t x ∗ − σ ˙ t ε ∥ 2 ] d t \mathcal{L}_{v}(\theta)=\int_{0}^{T} \mathbb{E}\left[\left\| v_{\theta}\left(x_{t}, t\right)-\dot{\alpha}_{t} x_{*}-\dot{\sigma}_{t} \varepsilon\right\| ^{2}\right] d t Lv(θ)=0TE[vθ(xt,t)α˙txσ˙tε2]dt
训练时采用线性路径 ( α t = 1 − t , σ t = t ) (\alpha_{t}=1 - t, \sigma_{t}=t) (αt=1t,σt=t),积分中的时间相关权重在基于分数的模型中起着重要作用。

四、实验

4.1 数据集和训练细节

4.1.1 图像数据集

用FacesHQ 1024×1024探究高分辨率可扩展性,用FacesHQ进行训练和消融实验。在MultiModalCelebA和MS COCO数据集上进行文本条件生成实验,用CLIP文本编码器处理文本。

4.1.2 视频数据集

使用UCF101数据集,随机采样16帧并调整分辨率为256×256。

4.1.3 训练细节

采用AdamW优化器,学习率1e - 4,用VAE编码器提取潜在特征,使用混合精度训练、梯度裁剪和权重衰减。多数实验在4块A100 GPU上进行,部分探索扩展到16块和32块A100 GPU,采样采用ODE采样。

4.2 消融实验

4.2.1 扫描方案消融

在MultiModalCelebA数据集的实验表明,从sweep切换到zigzag扫描方案可提升性能,增加zigzag方案数量也能持续提升性能,且高分辨率下提升更显著。

方案FID 5k (256×256)FID 5k (512×512)
Sweep158.1162.3
Zigzag-165.7121.0
Zigzag-845.534.9

表1:扫描方案消融实验结果

4.2.2 空间连续性至关重要

实验发现,增加空间连续性可提升性能,随机打乱图像块则导致性能下降,证明空间连续性在Mamba应用于二维序列时非常关键。
在这里插入图片描述空间连续性分析。随着我们逐步增大图像块组的大小,图像块的连续片段也在扩展。这增强了空间连续性,我们发现这会提升在MultiModal-CelebA 256、512数据集上的弗雷歇初始距离(FID)指标表现。

4.2.3 关于网络、FPS和GPU内存的消融实验

在这里插入图片描述

对比发现,Zigzag Mamba在FPS和GPU利用率方面表现最佳,且随着Mamba扫描方案增加,对FPS和GPU内存的负担几乎为零。对比发现,Zigzag Mamba在FPS和GPU利用率方面表现最佳,且随着Mamba扫描方案增加,对FPS和GPU内存的负担几乎为零。

4.2.4 顺序感受野

提出“顺序感受野”概念,实验显示Zigzag Mamba在顺序感受野增加时,能保持GPU内存消耗和FPS速率,而其他基线模型则出现FPS下降。

4.2.5 补丁大小

实验表明,随着补丁大小增加,FID变差,说明较小补丁大小更有利于获得最佳性能。

4.3 主要结果

4.3.1 1024×1024 FacesHQ上的主要结果

在高分辨率FacesHQ数据集上,ZigMa性能优于Bidirectional Mamba,随着训练时间延长,有望进一步提升优势。

方法FID 5kFDD 5k
Bidirectional Mamba51.166.3
ZigMa-16GPU37.850.5
ZigMa-32GPU26.631.2

表3:1024×1024高分辨率生成结果

4.3.2 COCO数据集

在MS COCO数据集上,ZigMa的Zigzag - 8方法同样优于Bidirectional Mamba和Zigzag - 1,表明分摊扫描方案可提升性能。

方案FID 5k
Sweep195.1
Zigzag-173.1
Bidirection Mamba60.2
Zigzag-841.8
Zigzag-8 16GPU33.8

表2:MS-COCO数据集主要结果

4.3.3 UCF101数据集

在UCF101视频数据集上,ZigMa的Factorized 3D Zigzag Mamba方法表现出色,优于Bidirectional Mamba。

方法Frame-FID 5kFVD 5k
Bidirectional Mamba-4GPU256.1320.2
3D Zigzag Mamba -4GPU238.1282.3
Factorized 3D Zigzag Mamba -4GPU216.1210.2
Bidirectional Mamba -16GPU146.2201.1
Factorized 3D Zigzag Mamba -16GPU121.2140.1

表4:UCF101数据集视频扫描方案结果

4.3.4 可视化

可视化结果显示,ZigMa在不同分辨率下生成的图像质量较高,证明了方法的有效性。
在这里插入图片描述

五、结论

本文提出的ZigMa扩散模型,在随机插值框架下,解决了Mamba在二维图像和三维视频建模中的空间连续性问题。通过设计Zigzag Mamba块和相关实验,验证了模型的优势,为Mamba网络设计提供了新的思路。

六、局限性和未来工作

目前ZigMa存在一定局限性,如无法穷举所有空间连续扫描方案,扫描方案多基于经验设定,可能导致性能未达最优。受GPU资源限制,训练时长受限。未来可深入研究Zigzag Mamba在不同领域的应用,充分发挥其长序列建模的可扩展性优势。

七、影响声明

ZigMa提升了Mamba在扩散模型中的可扩展性,有助于生成高保真大图像和实现文本到图像的生成。但如同其他图像合成技术,存在生成有害内容的风险,需关注伦理问题并采取相应保障措施。

参考文献

[1] 主页:ZigMa主页
[2] 文章链接:ZigMa论文
[3] 本文相关代码仓库:ZigMa官方GitHub代码仓库

相关文章:

  • Stream 流中 flatMap 方法详解
  • ADB简单入门
  • Verilog-HDL/SystemVerilog/Bluespec SystemVerilog vscode 配置
  • 一、蓝绿、灰度、滚动发布有什么不同
  • 网络安全攻防万字全景指南 | 从协议层到应用层的降维打击手册(全程图表对比,包你看到爽)
  • 内存高级话题
  • 如何根据 CUDA 配置安装 PyTorch 和 torchvision(大模型 环境经验)
  • C++学习之nginx+fastDFS
  • 详解Springboot的启动流程
  • 【HarmonyOS NEXT】关键资产存储开发案例
  • 纯内网环境安装1Panel面板与商店应用
  • 版本控制器Git ,Gitee如何连接Linux Gitee和Github区别
  • 信号的捕捉(操作部分)
  • 在linux上启动微服务
  • 前端模块化
  • Kubernetes学习笔记-项目简单部署
  • C语言复习笔记--数组
  • 网络编程之解除udp判断客户端是否断开
  • 调研报告:Hadoop 3.x Ozone 全景解析
  • 网络安全设备配置与管理-实验4-防火墙AAA服务配置
  • 特朗普访问卡塔尔,两国签署多项合作协议
  • 博柏利上财年营收下降17%,计划裁员1700人助推股价涨超18%
  • 夜读丨读《汉书》一得
  • 广西北部湾国际港务集团副总经理潘料庭接受审查调查
  • 江西贵溪:铜板上雕出的国潮美学
  • 法治课|争议中的“行人安全距离”于法无据,考量“注意义务”才更合理