Flux.1系列模型解析--Kontext
论文链接:FLUX.1 Kontext: Flow Matching for In-Context Image Generation and Editing in Latent Space
文章目录
- 简介
- 具体实现
- 图片tokens序列构建
- 训练目标
- 对抗性扩散蒸馏采样
- 训练过程
简介
Flux.1 Kontext是将图像生成和编辑任务统一的生成流匹配模型,其通过整合文本和图像输入中的语义上下文,生成全新的输出视图。Flux.1 Kontext采用简单的序列拼接,在单一的统一架构中同时处理局部编辑和生成式上下文任务;其在多轮操作中对物体和字符有较强的保留能力,相较于其他SOTA编辑模型生成速度也更快;为了验证上述改进,论文提出了KontextBench,包含1026个图像-提示词对数据,覆盖局部编辑、全局编辑、字符参考、风格参考和文本编辑五类任务。评估结果表明,FLUX.1 Kontext 在单轮质量和多轮一致性方面均表现卓越。
Flux.1 Kontext是一个简单的流匹配模型,仅通过在上下文和指令拼接的tokens序列上基于速度预测目标训练而成,有以下优点:
- 角色一致性:在角色保留方面表现出色,在多轮迭代编辑中仍能保持角色的一致性
- 生成速度:无论是文生图还是图生图,推理硬件性能足够时,生成 1024×1024 图像均能达到 3-5 秒的速度
- 迭代式应用:快速的推理速度和稳健的一致性,使用户能够通过多次连续编辑来优化图像,且视觉偏差极小
具体实现
Flux.1 Kontext模型的目标是训练一个可以基于文本提示词和一系列参考图片的联合特征生成图片的模型,更正式的表达是通过训练逼近以下条件分布:
KaTeX parse error: \tag works only in display equations
其中 xxx为目标图像, yyy为上下文图像或空集 ϕ\phiϕ, ccc为自然语言形式的指令。与传统的文生图不同,公式(1)这个目标学习图像之间通过条件 ccc所建立的关系,从而使同一个网络能够实现,当 y≠ϕy \neq \phiy=ϕ时可执行图像编辑任务,当 y=ϕy = \phiy=ϕ时执行全新的图片生成任务。
记目标输出图片的集合为 X\mathcal{X}X,公式(1)中的输入 x∈Xx \in \mathcal{X}x∈X,可选的上下文图像 y∈X∪{ϕ}y \in \mathcal{X} \cup \{\phi\}y∈X∪{ϕ}, c∈Cc \in Cc∈C为文本提示。对条件分布 pθ(x∣y,c)p_{\theta}(x|y,c)pθ(x∣y,c)建模,使得同一网络在不同条件 yyy时可进行上下文编辑、局部编辑、自由文生图任务。训练时以Flux.1的文生图模型初始化模型,基于收集、整理的数百万个 (x∣y,c)(x|y,c)(x∣y,c)进行后训练优化;实际操作并非在像素空间中进行,而是对离散化的图片tokens序列进行操作。
图片tokens序列构建
使用VAE模型的编码器将上下文图片转换为隐空间向量,然后按2*2为一个patch将图片离散为toknes序列,此才为上文提到的 yyy,将其拼接到目标图片的tokens序列 xxx后面后,作为backbone的输入。此处的拼接是直接在长度维度上进行,优势是支持不同的输入/输出分辨率和宽高比,并且可以轻松将上下文图片扩展为多张图片,即 y1,y2,...,yNy_1,y_2,...,y_Ny1,y2,...,yN。开发人员也尝试在最后的特征维度上拼接,但实验结果表明这个方案效果并不好。
前文中提到过,Flux.1系列模型构建的是三维的时空位置编码,具体而言,初始化的绝对位置由三元组 u=(t,h,w)u=(t,h,w)u=(t,h,w)表示,对于目标图片的tokens序列位置表示为 ux=(0,h,w)u_x=(0,h,w)ux=(0,h,w),而上下文图片的tokens序列位置表示为 uyi=(i,h,w),i=1,...,Nu_{y_i}=(i,h,w),\quad i=1,...,Nuyi=(i,h,w),i=1,...,N。
训练目标
训练目标就是正常基于最有传输路径的流匹配损失函数,如下所示:
Lθ=Et∼p(t),x,y,c[∣∣vθ(zt,t,y,c)−(ϵ−x)∣∣22](2)\mathcal{L}_{\theta}=\mathbb{E}_{t \sim p(t),x,y,c}[||v_{\theta}(z_t,t,y,c)-(\epsilon-x)||^2_2] \tag2Lθ=Et∼p(t),x,y,c[∣∣vθ(zt,t,y,c)−(ϵ−x)∣∣22](2)
其中 ztz_tzt是目标 xxx与纯噪声 ϵ∼N(0,1)\epsilon \sim \mathcal{N}(0,1)ϵ∼N(0,1)之间的线性插值,即 zt=(1−t)x+tϵz_t=(1-t)x+t\epsilonzt=(1−t)x+tϵ;时间步 ttt的采样分布遵循对数正态偏移调度,训练过程中基于图片的分辨率改变内部系数。当进行纯文生图时会省略所有上下文特征 yyy,保留模型的文生图能力。
对抗性扩散蒸馏采样
流匹配模型训练结束后,往往会通过数值求解的方法进行采样,往往需要较多的迭代采样步数来得到较好的结果,导致生图效率较慢、部署成本高。此外,当前的常用cfg引导方式会引入视觉伪影,如生成过饱和的图片。为了解决上述问题,采用隐空间对抗性扩散蒸馏(LADD),通过对抗性训练减少采样步数的同时提高生图质量。
训练过程
从纯文生图模型出发,基于公式(2)进行联合全量微调,使其同时适配图生图和文生图任务。虽然当前训练范式支持多张上下文图片,但目前只进行了单张上下文图片训练。首先训练出Flux.1 Kontext pro模型,然后通过引导机制将其能力蒸馏至12B的模型中,即Flux.1 Kontext dev模型。为优化Flux.1 Kontext dev在编辑任务上的性能,专门进一步进行图生图训练,没有进一步进行文生图训练。训练过程中融入了安全训练措施,包括基于分类器的过滤和对抗性训练,以防止生成私密图像或儿童性虐待内容。
将FSDP2配合混合精度使用,即所有聚集(all-gather)操作以 bfloat16 精度执行,而梯度计算gradient reduce-scatter则使用 float32 精度,以提高数值稳定性;使用选择性激活降低最大VRAM使用。为提高吞吐量,采用了Flash Attention 3以及对各个Transformer块的区域编译技术。