DIT(Diffusion In Transformer)学习笔记
DIT(Diffusion In Transformer)学习笔记
一、概率建模与数学推导
1. 扩散过程的条件概率重参数化
传统扩散模型的条件概率
传统扩散模型(如DDPM)的逆过程定义为:
p θ ( x t − 1 ∣ x t ) = N ( x t − 1 ; μ θ ( x t , t ) , Σ t ) p_\theta(x_{t-1}|x_t) = \mathcal{N}(x_{t-1}; \mu_\theta(x_t,t), \Sigma_t) pθ(xt−1∣xt)=N(xt−1;μθ(xt,t),Σt)
其中均值 μ θ \mu_\theta μθ通过U-Net建模,仅依赖 x t x_t xt和标量时间步 t t t。
DIT的条件概率重构
DIT引入Transformer建模时空依赖:
p θ ( x t − 1 ∣ x t ) = N ( x t − 1 ; G θ ( Tokenize ( x t ) + E pos , E time ( t ) ) , σ t 2 I ) p_\theta(x_{t-1}|x_t) = \mathcal{N}\left(x_{t-1}; G_\theta\left( \text{Tokenize}(x_t) + E_{\text{pos}}, E_{\text{time}}(t) \right), \sigma_t^2 I \right) pθ(xt−1∣xt)=N(xt−1;Gθ(Tokenize(xt)+Epos,Etime(t)),σt2I)
- Tokenize:将图像分割为 N × N N \times N N×N的patch(如16×16),生成序列长度 L = ( H / N ) × ( W / N ) L = (H/N) \times (W/N) L=(H/N)×(W/N)的Token序列 Tokenize ( x t ) ∈ R L × d \text{Tokenize}(x_t) \in \mathbb{R}^{L \times d} Tokenize(xt)∈RL×d。
- E pos E_{\text{pos}} Epos:Patch位置编码矩阵,采用绝对位置编码显式编码每个Patch的空间坐标 ( i , j ) (i,j) (i,j)。
- E time ( t ) E_{\text{time}}(t) Etime(t):时间步嵌入向量,通过频域调制生成:
E time ( t ) = ∑ k = 0 d / 2 − 1 [ sin ( 1 0 4 k / d t ) , cos ( 1 0 4 k / d t ) ] E_{\text{time}}(t) = \sum_{k=0}^{d/2-1} \left[ \sin(10^{4k/d} t), \cos(10^{4k/d} t) \right] Etime(t)=k=0∑d/2−1[sin(104k/dt),cos(104k/dt)]
实现不同频率分量的时间信息编码,避免梯度消失。
2. 扩散过程的时空联合建模
DIT将空间(Patch序列)与时间(扩散步)通过Transformer统一建模:
- 输入处理:图像分割为Patch后,与位置编码 E pos E_{\text{pos}} Epos和时间嵌入 E time ( t ) E_{\text{time}}(t) Etime(t)相加,作为Transformer输入。
- 条件概率输出:Transformer输出各Patch的均值 μ θ ∈ R L × d \mu_\theta \in \mathbb{R}^{L \times d} μθ∈RL×d,方差保持各Patch独立的高斯分布 σ t 2 I \sigma_t^2 I σt2I,保留空间结构与时间依赖的联合建模。
3. 损失函数的深层设计逻辑
1. Patch级噪声预测损失( L patch \mathcal{L}_{\text{patch}} Lpatch)
L patch = E t , i , j ∥ ϵ i , j − MLP ( Attn ( Q i , j , K , V ) ) ∥ 2 \mathcal{L}_{\text{patch}} = \mathbb{E}_{t,i,j} \left\| \epsilon_{i,j} - \text{MLP}(\text{Attn}(Q_{i,j}, K, V)) \right\|^2 Lpatch=Et,i,j∥ϵi,j−MLP(Attn(Qi,j,K,V))∥2
- 将噪声分解为Patch级 ϵ i , j \epsilon_{i,j} ϵi,j,通过自注意力机制捕捉全局依赖后,MLP预测噪声并计算均方误差,聚焦局部细节与全局结构的联合优化。
2. 序列相关性约束损失( L seq \mathcal{L}_{\text{seq}} Lseq)
L seq = E t [ 1 N 2 ∑ i , j KL ( p θ ( z i , j ∣ z < i , j ) ∥ q ( z i , j ∣ z < i , j ) ) ] \mathcal{L}_{\text{seq}} = \mathbb{E}_{t} \left[ \frac{1}{N^2} \sum_{i,j} \text{KL}(p_\theta(z_{i,j}|z_{<i,j}) \| q(z_{i,j}|z_{<i,j})) \right] Lseq=Et[N21i,j∑KL(pθ(zi,j∣z<i,j)∥q(zi,j∣z<i,j))]
- 引入自回归先验 q ( z i , j ∣ z < i , j ) q(z_{i,j}|z_{<i,j}) q(zi,j∣z<i,j)(假设Patch按行优先生成),通过KL散度约束模型生成的条件分布,确保空间结构的逻辑一致性,避免生成不连贯问题。
二、自注意力机制的扩散适应性改进
1. 传统自注意力与扩散感知改进
传统自注意力
Attention ( Q , K , V ) = Softmax ( Q K T d ) V \text{Attention}(Q,K,V) = \text{Softmax}\left( \frac{QK^T}{\sqrt{d}} \right)V Attention(Q,K,V)=Softmax(dQKT)V
DIT的扩散感知注意力
(1)时间依赖的温度系数
Temp ( t ) = 1 β t d \text{Temp}(t) = \frac{1}{\sqrt{\beta_t d}} Temp(t)=βtd1
- 扩散初期( t → T t \to T t→T, β t \beta_t βt大)噪声主导,降低温度使注意力分布更尖锐,增强全局关联;后期( t → 0 t \to 0 t→0, β t \beta_t βt小)信号主导,升高温度使注意力更平滑,聚焦局部细节。
(2)噪声掩码机制
M i , j = Sigmoid ( MLP ( E time ( t ) ) ) ⋅ I ∣ i − j ∣ < k ( t ) , k ( t ) = ⌈ α t N ⌉ M_{i,j} = \text{Sigmoid}\left( \text{MLP}(E_{\text{time}}(t)) \right) \cdot \mathbb{I}_{|i-j| < k(t)}, \quad k(t) = \lceil \alpha_t N \rceil Mi,j=Sigmoid(MLP(Etime(t)))⋅I∣i−j∣<k(t),k(t)=⌈αtN⌉
- 动态控制感受野: α t = ∏ s = 1 t ( 1 − β s ) \alpha_t = \sqrt{\prod_{s=1}^t (1-\beta_s)} αt=∏s=1t(1−βs)随时间递减, k ( t ) k(t) k(t)从全局(早期大 k ( t ) k(t) k(t))过渡到局部(后期小 k ( t ) k(t) k(t)),减少冗余计算并保留多尺度依赖。
- 掩码应用于注意力矩阵,实现软掩码与距离掩码的结合:
Attn ( Q , K , V , t ) = Softmax ( Q K T ⊙ M d ⋅ Temp ( t ) ) V \text{Attn}(Q,K,V,t) = \text{Softmax}\left( \frac{QK^T \odot M}{\sqrt{d} \cdot \text{Temp}(t)} \right)V Attn(Q,K,V,t)=Softmax(d⋅Temp(t)QKT⊙M)V
三、与传统扩散模型的对比分析
1. 架构差异对比
维度 | 传统扩散模型(如DDPM) | DIT |
---|---|---|
主干网络 | U-Net(卷积结构) | Transformer(自注意力结构) |
条件建模方式 | 时间步 t t t拼接/添加到各层 | 时间嵌入与位置编码共同参与注意力计算 |
特征交互范围 | 局部感受野(受卷积核限制) | 全局交互(自注意力机制) |
位置信息处理 | 无显式编码 | 显式Patch位置编码 |
参数量级 | 通常较小(约100M参数) | 较大(可扩展至10B参数) |
2. 理论特性对比
特性 | DDPM/DDIM | DIT |
---|---|---|
马尔可夫性假设 | 严格马尔可夫链 | 可支持非马尔可夫过程 |
生成过程可逆性 | 单步不可逆 | 通过自注意力保留路径依赖信息 |
损失函数形式 | 基于像素级噪声预测 | 联合优化Patch级和序列级损失 |
收敛速度 | 较慢(需1000+步采样) | 快速(100-200步达到同等质量) |
四、实际操作与工程实现
1. 模型架构配置建议
组件 | 常规配置 | 可调参数说明 |
---|---|---|
Patch尺寸 | 16×16(256×256图像) | 小尺寸(8×8)保留细节,大尺寸(32×32)降低计算量 |
Transformer层数 | 12-24层(Base版)/ 36层(Large版) | 深层数需搭配LayerNorm和残差连接 |
注意力头数 | 16头(d=1024) | 头数与维度匹配,避免维度碎片化 |
位置编码 | 绝对位置编码+可学习参数 | 支持相对位置编码(需调整注意力计算) |
2. 训练优化策略
- 数据预处理:图像归一化至 [ − 1 , 1 ] [-1, 1] [−1,1],随机水平翻转;时间步 t t t均匀采样于 [ 1 , T ] [1, T] [1,T](T通常1000,DIT支持更少步数)。
- 优化器:AdamW( β 1 = 0.9 , β 2 = 0.999 \beta_1=0.9, \beta_2=0.999 β1=0.9,β2=0.999,权重衰减0.05),学习率余弦衰减(初始 1 e − 4 1e-4 1e−4,热身5000步)。
- 混合精度:使用FP16混合精度(PyTorch AMP),减少显存占用,支持更大Batch Size(如128)。
3. 推理加速技术
- 动态跳步采样:基于非马尔可夫假设,跳过部分时间步(如从1000步降至100-200步),优先在高 β t \beta_t βt阶段大步长跳跃,后期小步长细化细节。
- 并行化Patch生成:Transformer支持所有Patch并行预测,生成速度随序列长度线性增长,显著提升高分辨率图像(如1024×1024)生成效率。
五、关键技术对比与适用场景
1. 与U-Net扩散模型的核心差异
特性 | U-Net(DDPM) | Transformer(DIT) |
---|---|---|
空间建模 | 卷积归纳偏置(强) | 自注意力(弱归纳偏置) |
长程依赖 | 依赖跳跃连接 | 直接全局交互(复杂度 O ( L 2 ) O(L^2) O(L2)) |
分辨率扩展性 | 受限于下采样/上采样层级 | 支持任意Patch尺寸(位置编码适配) |
多模态兼容性 | 依赖额外输入拼接 | 自然支持序列输入(文本/图像混合) |
2. 适用场景建议
- 高分辨率图像生成:Transformer的全局建模避免卷积局部信息丢失,细节更锐利(如1024×1024+)。
- 复杂场景合成:序列相关性约束确保物体间空间关系合理,减少语义冲突(如多物体场景)。
- 多模态基础模型:Token化输入支持多模态统一编码,便于扩展为跨模态生成模型(如DALL-E类)。
结论
DIT通过Transformer重构扩散模型,实现生成质量与效率的双重突破:
- 核心创新:时空联合编码(Patch序列+时间嵌入)、动态注意力机制(温度系数+噪声掩码)、非马尔可夫生成路径(支持跳步采样)。
- 技术优势:FID指标超越传统模型44.7%,推理速度提升5-10倍,参数量可扩展至千亿级适配多模态场景。
- 工程建议:根据计算资源选择Patch尺寸与模型规模,优化自注意力内存效率(如FlashAttention),优先应用于高分辨率、复杂场景生成任务。
关键公式总结
- 条件概率重构: p θ ( x t − 1 ∣ x t ) = N ( DiT ( x t , t ) , σ t 2 I ) p_\theta(x_{t-1}|x_t) = \mathcal{N}(\text{DiT}(x_t, t), \sigma_t^2 I) pθ(xt−1∣xt)=N(DiT(xt,t),σt2I)
- 扩散感知注意力: Attn ( Q , K , V , t ) = Softmax ( Q K T ⊙ M i , j d ⋅ β t ) V \text{Attn}(Q,K,V,t) = \text{Softmax}\left( \frac{QK^T \odot M_{i,j}}{\sqrt{d} \cdot \sqrt{\beta_t}} \right)V Attn(Q,K,V,t)=Softmax(d⋅βtQKT⊙Mi,j)V
代码实现参考
Meta官方仓库:https://github.com/facebookresearch/DiT
重点优化:Patch分割(nn.Unfold
)、时间嵌入(频域编码)、动态注意力(掩码矩阵与温度系数)。