当前位置: 首页 > news >正文

DIT(Diffusion In Transformer)学习笔记

DIT(Diffusion In Transformer)学习笔记


一、概率建模与数学推导

1. 扩散过程的条件概率重参数化

传统扩散模型的条件概率
传统扩散模型(如DDPM)的逆过程定义为:
p θ ( x t − 1 ∣ x t ) = N ( x t − 1 ; μ θ ( x t , t ) , Σ t ) p_\theta(x_{t-1}|x_t) = \mathcal{N}(x_{t-1}; \mu_\theta(x_t,t), \Sigma_t) pθ(xt1xt)=N(xt1;μθ(xt,t),Σt)
其中均值 μ θ \mu_\theta μθ通过U-Net建模,仅依赖 x t x_t xt和标量时间步 t t t

DIT的条件概率重构
DIT引入Transformer建模时空依赖:
p θ ( x t − 1 ∣ x t ) = N ( x t − 1 ; G θ ( Tokenize ( x t ) + E pos , E time ( t ) ) , σ t 2 I ) p_\theta(x_{t-1}|x_t) = \mathcal{N}\left(x_{t-1}; G_\theta\left( \text{Tokenize}(x_t) + E_{\text{pos}}, E_{\text{time}}(t) \right), \sigma_t^2 I \right) pθ(xt1xt)=N(xt1;Gθ(Tokenize(xt)+Epos,Etime(t)),σt2I)

  • Tokenize:将图像分割为 N × N N \times N N×N的patch(如16×16),生成序列长度 L = ( H / N ) × ( W / N ) L = (H/N) \times (W/N) L=(H/N)×(W/N)的Token序列 Tokenize ( x t ) ∈ R L × d \text{Tokenize}(x_t) \in \mathbb{R}^{L \times d} Tokenize(xt)RL×d
  • E pos E_{\text{pos}} Epos:Patch位置编码矩阵,采用绝对位置编码显式编码每个Patch的空间坐标 ( i , j ) (i,j) (i,j)
  • E time ( t ) E_{\text{time}}(t) Etime(t):时间步嵌入向量,通过频域调制生成:
    E time ( t ) = ∑ k = 0 d / 2 − 1 [ sin ⁡ ( 1 0 4 k / d t ) , cos ⁡ ( 1 0 4 k / d t ) ] E_{\text{time}}(t) = \sum_{k=0}^{d/2-1} \left[ \sin(10^{4k/d} t), \cos(10^{4k/d} t) \right] Etime(t)=k=0d/21[sin(104k/dt),cos(104k/dt)]
    实现不同频率分量的时间信息编码,避免梯度消失。

2. 扩散过程的时空联合建模

DIT将空间(Patch序列)与时间(扩散步)通过Transformer统一建模:

  • 输入处理:图像分割为Patch后,与位置编码 E pos E_{\text{pos}} Epos和时间嵌入 E time ( t ) E_{\text{time}}(t) Etime(t)相加,作为Transformer输入。
  • 条件概率输出:Transformer输出各Patch的均值 μ θ ∈ R L × d \mu_\theta \in \mathbb{R}^{L \times d} μθRL×d,方差保持各Patch独立的高斯分布 σ t 2 I \sigma_t^2 I σt2I,保留空间结构与时间依赖的联合建模。

3. 损失函数的深层设计逻辑

1. Patch级噪声预测损失( L patch \mathcal{L}_{\text{patch}} Lpatch
L patch = E t , i , j ∥ ϵ i , j − MLP ( Attn ( Q i , j , K , V ) ) ∥ 2 \mathcal{L}_{\text{patch}} = \mathbb{E}_{t,i,j} \left\| \epsilon_{i,j} - \text{MLP}(\text{Attn}(Q_{i,j}, K, V)) \right\|^2 Lpatch=Et,i,jϵi,jMLP(Attn(Qi,j,K,V))2

  • 将噪声分解为Patch级 ϵ i , j \epsilon_{i,j} ϵi,j,通过自注意力机制捕捉全局依赖后,MLP预测噪声并计算均方误差,聚焦局部细节与全局结构的联合优化。

2. 序列相关性约束损失( L seq \mathcal{L}_{\text{seq}} Lseq
L seq = E t [ 1 N 2 ∑ i , j KL ( p θ ( z i , j ∣ z < i , j ) ∥ q ( z i , j ∣ z < i , j ) ) ] \mathcal{L}_{\text{seq}} = \mathbb{E}_{t} \left[ \frac{1}{N^2} \sum_{i,j} \text{KL}(p_\theta(z_{i,j}|z_{<i,j}) \| q(z_{i,j}|z_{<i,j})) \right] Lseq=Et[N21i,jKL(pθ(zi,jz<i,j)q(zi,jz<i,j))]

  • 引入自回归先验 q ( z i , j ∣ z < i , j ) q(z_{i,j}|z_{<i,j}) q(zi,jz<i,j)(假设Patch按行优先生成),通过KL散度约束模型生成的条件分布,确保空间结构的逻辑一致性,避免生成不连贯问题。

二、自注意力机制的扩散适应性改进

1. 传统自注意力与扩散感知改进

传统自注意力
Attention ( Q , K , V ) = Softmax ( Q K T d ) V \text{Attention}(Q,K,V) = \text{Softmax}\left( \frac{QK^T}{\sqrt{d}} \right)V Attention(Q,K,V)=Softmax(d QKT)V

DIT的扩散感知注意力

(1)时间依赖的温度系数

Temp ( t ) = 1 β t d \text{Temp}(t) = \frac{1}{\sqrt{\beta_t d}} Temp(t)=βtd 1

  • 扩散初期( t → T t \to T tT β t \beta_t βt大)噪声主导,降低温度使注意力分布更尖锐,增强全局关联;后期( t → 0 t \to 0 t0 β t \beta_t βt小)信号主导,升高温度使注意力更平滑,聚焦局部细节。
(2)噪声掩码机制

M i , j = Sigmoid ( MLP ( E time ( t ) ) ) ⋅ I ∣ i − j ∣ < k ( t ) , k ( t ) = ⌈ α t N ⌉ M_{i,j} = \text{Sigmoid}\left( \text{MLP}(E_{\text{time}}(t)) \right) \cdot \mathbb{I}_{|i-j| < k(t)}, \quad k(t) = \lceil \alpha_t N \rceil Mi,j=Sigmoid(MLP(Etime(t)))Iij<k(t),k(t)=αtN

  • 动态控制感受野: α t = ∏ s = 1 t ( 1 − β s ) \alpha_t = \sqrt{\prod_{s=1}^t (1-\beta_s)} αt=s=1t(1βs) 随时间递减, k ( t ) k(t) k(t)从全局(早期大 k ( t ) k(t) k(t))过渡到局部(后期小 k ( t ) k(t) k(t)),减少冗余计算并保留多尺度依赖。
  • 掩码应用于注意力矩阵,实现软掩码与距离掩码的结合:
    Attn ( Q , K , V , t ) = Softmax ( Q K T ⊙ M d ⋅ Temp ( t ) ) V \text{Attn}(Q,K,V,t) = \text{Softmax}\left( \frac{QK^T \odot M}{\sqrt{d} \cdot \text{Temp}(t)} \right)V Attn(Q,K,V,t)=Softmax(d Temp(t)QKTM)V

三、与传统扩散模型的对比分析

1. 架构差异对比

维度传统扩散模型(如DDPM)DIT
主干网络U-Net(卷积结构)Transformer(自注意力结构)
条件建模方式时间步 t t t拼接/添加到各层时间嵌入与位置编码共同参与注意力计算
特征交互范围局部感受野(受卷积核限制)全局交互(自注意力机制)
位置信息处理无显式编码显式Patch位置编码
参数量级通常较小(约100M参数)较大(可扩展至10B参数)

2. 理论特性对比

特性DDPM/DDIMDIT
马尔可夫性假设严格马尔可夫链可支持非马尔可夫过程
生成过程可逆性单步不可逆通过自注意力保留路径依赖信息
损失函数形式基于像素级噪声预测联合优化Patch级和序列级损失
收敛速度较慢(需1000+步采样)快速(100-200步达到同等质量)

四、实际操作与工程实现

1. 模型架构配置建议

组件常规配置可调参数说明
Patch尺寸16×16(256×256图像)小尺寸(8×8)保留细节,大尺寸(32×32)降低计算量
Transformer层数12-24层(Base版)/ 36层(Large版)深层数需搭配LayerNorm和残差连接
注意力头数16头(d=1024)头数与维度匹配,避免维度碎片化
位置编码绝对位置编码+可学习参数支持相对位置编码(需调整注意力计算)

2. 训练优化策略

  • 数据预处理:图像归一化至 [ − 1 , 1 ] [-1, 1] [1,1],随机水平翻转;时间步 t t t均匀采样于 [ 1 , T ] [1, T] [1,T](T通常1000,DIT支持更少步数)。
  • 优化器:AdamW( β 1 = 0.9 , β 2 = 0.999 \beta_1=0.9, \beta_2=0.999 β1=0.9,β2=0.999,权重衰减0.05),学习率余弦衰减(初始 1 e − 4 1e-4 1e4,热身5000步)。
  • 混合精度:使用FP16混合精度(PyTorch AMP),减少显存占用,支持更大Batch Size(如128)。

3. 推理加速技术

  • 动态跳步采样:基于非马尔可夫假设,跳过部分时间步(如从1000步降至100-200步),优先在高 β t \beta_t βt阶段大步长跳跃,后期小步长细化细节。
  • 并行化Patch生成:Transformer支持所有Patch并行预测,生成速度随序列长度线性增长,显著提升高分辨率图像(如1024×1024)生成效率。

五、关键技术对比与适用场景

1. 与U-Net扩散模型的核心差异

特性U-Net(DDPM)Transformer(DIT)
空间建模卷积归纳偏置(强)自注意力(弱归纳偏置)
长程依赖依赖跳跃连接直接全局交互(复杂度 O ( L 2 ) O(L^2) O(L2)
分辨率扩展性受限于下采样/上采样层级支持任意Patch尺寸(位置编码适配)
多模态兼容性依赖额外输入拼接自然支持序列输入(文本/图像混合)

2. 适用场景建议

  • 高分辨率图像生成:Transformer的全局建模避免卷积局部信息丢失,细节更锐利(如1024×1024+)。
  • 复杂场景合成:序列相关性约束确保物体间空间关系合理,减少语义冲突(如多物体场景)。
  • 多模态基础模型:Token化输入支持多模态统一编码,便于扩展为跨模态生成模型(如DALL-E类)。

结论

DIT通过Transformer重构扩散模型,实现生成质量与效率的双重突破:

  1. 核心创新:时空联合编码(Patch序列+时间嵌入)、动态注意力机制(温度系数+噪声掩码)、非马尔可夫生成路径(支持跳步采样)。
  2. 技术优势:FID指标超越传统模型44.7%,推理速度提升5-10倍,参数量可扩展至千亿级适配多模态场景。
  3. 工程建议:根据计算资源选择Patch尺寸与模型规模,优化自注意力内存效率(如FlashAttention),优先应用于高分辨率、复杂场景生成任务。

关键公式总结

  • 条件概率重构: p θ ( x t − 1 ∣ x t ) = N ( DiT ( x t , t ) , σ t 2 I ) p_\theta(x_{t-1}|x_t) = \mathcal{N}(\text{DiT}(x_t, t), \sigma_t^2 I) pθ(xt1xt)=N(DiT(xt,t),σt2I)
  • 扩散感知注意力: Attn ( Q , K , V , t ) = Softmax ( Q K T ⊙ M i , j d ⋅ β t ) V \text{Attn}(Q,K,V,t) = \text{Softmax}\left( \frac{QK^T \odot M_{i,j}}{\sqrt{d} \cdot \sqrt{\beta_t}} \right)V Attn(Q,K,V,t)=Softmax(d βt QKTMi,j)V

代码实现参考
Meta官方仓库:https://github.com/facebookresearch/DiT
重点优化:Patch分割(nn.Unfold)、时间嵌入(频域编码)、动态注意力(掩码矩阵与温度系数)。

相关文章:

  • PID控制中,一阶低通滤波算法
  • c#TCPsever
  • 配置 Odoo 的 PostgreSQL 数据库以允许远程访问的步骤
  • 高级java每日一道面试题-2025年4月30日-基础篇[反射篇]-在反射中,`setAccessible(true)`的作用是什么?
  • LVGL -按键介绍 上
  • Spring AI如何调用本地部署的大模型
  • Learning vtkjs之ImplicitBoolean
  • 脏读、不可重复读、幻读示例
  • Clang-Tidy协助C++编译期检查
  • 在Windows系统上如何用Manifest管理嵌入式项目
  • 《Python实战进阶》No45:性能分析工具 cProfile 与 line_profiler
  • 架构进阶:72页集管IT基础设施蓝图设计方案【附全文阅读】
  • 软考中级-软件设计师 数据库(手写笔记)
  • 算法-冒泡排序
  • Ecology中拦截jquery.ajax请求接口后的数据
  • 【免费数据】2000-2020年中国4km分辨率逐日气象栅格数据(含9个气象变量)
  • windows11 编译 protobuf-3.21.12 c++
  • 大连理工大学选修课——机器学习笔记(4):NBM的原理及应用
  • 机器学习|通过线性回归了解算法流程
  • 制作一款打飞机游戏35:生成系统
  • 中国证券监督管理委员会党委委员、副主席王建军接受审查调查
  • 五一去哪儿| 追着花期去旅行,“赏花经济”绽放文旅新活力
  • 上海“模速空间”:将形成人工智能“北斗七星”和群星态势
  • 成都世运会倒计时100天,中国代表团运动员规模将创新高
  • 习近平对辽宁辽阳市白塔区一饭店火灾事故作出重要指示
  • 发布亮眼一季度报后,东阿阿胶股价跌停:现金流隐忧引发争议