FLUX1. 代码解读
学习材料
https://www.toutiao.com/article/7494605618809029135/?wid=1762309932775
FluxTransformer2DModel模块
FLUX核心模型是FluxTransformer2D模块,本质上是一个扩散Transformer。它接收编码后的图像潜空间(通常是64通道的latent特征)以及文本嵌入,输出去噪后的图像latent。FluxTransormer2DModel内部堆叠了多层Transormer块,其中前几层是双流Dual-Stream块(同时更新图像和文本流,类似论文中的MMDiT),后续层是单流Single-Stream块(仅更新图像流,类似DiT)。以下展示了模型初始化主要组件:
class FluxTransformer2DModel(nn.Layer): def __init__(self, ..., num_layers=19, num_single_layers=38, ...): super().__init__() self.inner_dim = num_attention_heads * attention_head_dim self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=(16, 56, 56)) self.time_text_embed = CombinedTimestepTextProjEmbeddings( embedding_dim=self.inner_dim, pooled_projection_dim=768 ) # 文本上下文嵌入线性投射:将文本序列embedding降维到inner_dim self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim) # 图像latent嵌入线性层:将输入图像通道数投射到inner_dim self.x_embedder = nn.Linear(in_channels, self.inner_dim) # 双流 Transformer 块列表 self.transformer_blocks = nn.LayerList([ FluxTransformerBlock(dim=self.inner_dim, num_attention_heads=num_attention_heads, attention_head_dim=attention_head_dim) for _ in range(num_layers) ]) # 单流 Transformer 块列表 self.single_transformer_blocks = nn.LayerList([ FluxSingleTransformerBlock(dim=self.inner_dim, num_attention_heads=num_attention_heads, attention_head_dim=attention_head_dim) for _ in range(num_single_layers) ]) # 输出层归一化和线性投射,将inner_dim投射回图像patch的像素维度 self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, ...) self.proj_out = nn.Linear(self.inner_dim, patch_size*patch_size*out_channels)
以上代码展示了模型构造的关键部分:
self.inner_dim 是 Transformer 隐藏维度(例如3072),通常由注意力头数×每头维度计算。
FluxPosEmbed 实例用于生成旋转位置嵌入(Rotary Positional Embedding),适用于二维图像 patch 网格的位置编码。
time_text_embed 是时间步嵌入与文本全局嵌入的融合模块。
context_embedder 是一个线性层,用于将文本编码器输出的上下文序列(如 T5文本 encoder 输出,尺寸 joint_attention_dim=4096)投射到 Transformer 内部维度。换言之,文本每个 Token 的高维 embedding 将降维为与图像 latent 相同的维度,以便进入 Transformer 的注意力计算。
x_embedder 是一个线性层,将输入的图像 latent 通道数(in_channels,默认64)转为 inner_dim。FLUX 模型直接在 VAE 生成的 latent 上应用 Transformer,因此首先用全连接把 latent 特征映射到 Transformer 的隐藏维度。
transformer_blocks 列表包含 num_layers 个 FluxTransformerBlock,即双流 Transformer 块。默认19层,用于同时处理图像和文本两个流。
single_transformer_blocks 列表包含 num_single_layers 个
FluxSingleTransformerBlock,即单流 Transformer 块。默认38层,用于仅处理图像流。
最后,通过 norm_out(持续型 AdaLN 归一化)和 proj_out 输出线性层,将 Transformer 输出变换回原始 latent 形状。
下面我们深入 FluxTransformerBlock(双流块)和
FluxSingleTransformerBlock(单流块)的实现细节,看看双流和单流块在注意力和前馈层上的差异。
FluxTransformerBlock模块
双流块承担跨模态融合的任务,每层同时更新图像隐藏状态和文本隐藏状态。FluxTransformerBlock 内部包含两套并行的子层:一套针对图像 hidden_states,另一套针对文本 encoder_hidden_states。每套都包括 AdaLayerNorm(带时序嵌入调制的层归一化)、多头注意力、和前馈网络(FeedForward),但注意力层是共享的,实现图像-文本的交互。其构造如下:
class FluxTransformerBlock(nn.Layer): def __init__(self, dim, num_attention_heads, attention_head_dim, qk_norm="rms_norm", eps=1e-6): super().__init__() # AdaLayerNorm,用Zero初始化偏置和增益,用于图像流 self.norm1 = AdaLayerNormZero(dim) # AdaLayerNorm,用于文本流 self.norm1_context = AdaLayerNormZero(dim) # 多头注意力层(图像-文本双流交互) self.attn = Attention( query_dim=dim, cross_attention_dim=None, added_kv_proj_dim=dim, dim_head=attention_head_dim, heads=num_attention_heads, out_dim=dim, context_pre_only=False, bias=True, processor=FluxAttnProcessor2_0(), qk_norm=qk_norm, eps=eps, ) # 图像流的LayerNorm + FeedForward self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") # 文本流的LayerNorm + FeedForward self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
双流块初始化时,设置了两套归一化和前馈,但共享一个注意力层:
AdaLayerNormZero 是一种可调制的层归一化,初始时 scale 和 bias 为0,但在前向过程中将利用时间步嵌入提供的参数对归一化后的张量施加缩放和平移,以及门控系数。这里分别对图像(norm1)和文本(norm1_context)各用一个 AdaLN。
Attention 层参数设置比较关键:query_dim=dim且added_kv_proj_dim=dim。这意味着查询 Q 来自图像 hidden_states,键 K/值 V 除了处理 Q 自身(图像)的投影外,还会额外投影处理一个维度为 dim 的输入——也就是我们传入的文本 encoder_hidden_states。这相当于在一次注意力计算中同时让图像特征关注自身(自注意力)和文本特征(交叉注意力)。因为 cross_attention_dim=None 且用了 added_kv_proj_dim=dim,所以实现上采用单一 Attention 来处理联合的 KV。context_pre_only=False 表明并非纯先验 context 模式,而是正常交互。总的来说,attn 层实现了一个“双流融合注意力”:Queries 是图像 token,Keys/Values 是图像 token 和文本 token 的混合。
norm2/norm2_context 和 ff/ff_context 则是标准 Transformer 的后半部分(LayerNorm+前馈 MLP),分别应用于图像流和文本流。注意它们没有 Ada 前缀,表示这些层不直接受时间嵌入调制。
FluxTransformerBlock 的 forward 方法更能体现双流机制。关键步骤如下:
def forward( self, hidden_states: paddle.Tensor, encoder_hidden_states: paddle.Tensor, temb: paddle.Tensor, image_rotary_emb=None, joint_attention_kwargs=None,
): # AdaLayerNormZero:归一化 & 提取门控/偏移参数 norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(encoder_hidden_states, emb=temb)joint_attention_kwargs = joint_attention_kwargs or {} # 多头注意力:将规范化后的图像&文本特征输入注意力层 attention_outputs = self.attn( hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states, image_rotary_emb=image_rotary_emb, **joint_attention_kwargs, ) # Attention输出可能包含2或3个张量 if len(attention_outputs) == 2: attn_output, context_attn_output = attention_outputs elif len(attention_outputs) == 3: attn_output, context_attn_output, ip_attn_output = attention_outputs # 将注意力输出应用Gate并加入残差 attn_output = gate_msa.unsqueeze(1) * attn_output hidden_states = hidden_states + attn_output # 前馈层 (图像流) norm_hidden_states = self.norm2(hidden_states) norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] ff_output = self.ff(norm_hidden_states) ff_output = gate_mlp.unsqueeze(1) * ff_output hidden_states = hidden_states + ff_output if len(attention_outputs) == 3: hidden_states = hidden_states + ip_attn_output # 前馈层 (图像流) norm_hidden_states = self.norm2(hidden_states) norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] ff_output = self.ff(norm_hidden_states) ff_output = gate_mlp.unsqueeze(1) * ff_output hidden_states = hidden_states + ff_output if len(attention_outputs) == 3: hidden_states = hidden_states + ip_attn_output
这段逻辑实现了图像-文本双流 Transformer 层的前向计算,流程可总结如下:
AdaLayerNorm 调制:使用当前扩散时间步的嵌入,对图像和文本两个输入分别做 AdaLN 归一化。AdaLayerNormZero 不仅返回归一化后的张量,还提取出若干调制参数:对于每个流,产生用于注意力输出的门控系数 gate_msa、用于 MLP 输入的仿射变换参数 scale_mlp 和 shift_mlp、以及MLP输出的门控系数 gate_mlp。这些参数形状一般是[batch,dim]或[batch,],后续用来调整该层的输出幅度。
联合注意力计算:将归一化后的图像特征 norm_hidden_states 作为 Query,文本特征
norm_encoder_hidden_states 作为附加的 KV 输入,喂给 self.attn。image_rotary_emb 则是由 FluxPosEmbed 生成的旋转位置嵌入参数(正弦/余弦基),用于在 Attention 内部对 Q,K 应用位置编码,特别适用于图像 patch 的2D 位置信息。
残差连接(注意力层):对注意力结果施加 AdaLN 提供的门控系数,然后加回各自的残差通道。这样图像特征经过自注意力并融合了文本信息,文本特征也在交互中得到更新(捕获与图像的关联)。
前馈层(图像):对更新后的 hidden_states 再做一层标准 LayerNorm,然后利用 AdaLN 的 scale 和 shift 参数对其按元素缩放和平移。随后通过 FeedForward MLP(两层全连接+激活,已经在构造时定义)变换得到 ff_output。再乘以 gate_mlp(AdaLN 提供的 MLP 门控系数)进行缩放,最后残差连接加回到 hidden_states。
前馈层(文本):类似地,对文本 encoder_hidden_states 应用 LayerNorm、AdaLN 缩放偏移,再经过文本流自己的 ff_context MLP,乘以 c_gate_mlp 后加回。这样文本特征也通过 MLP 得到更新。
返回结果:最终,该层输出更新后的 encoder_hidden_states(文本)和 hidden_states(图像)。这两个将作为下一层 FluxTransformerBlock 的输入,实现逐层交替强化图像和文本的表示。注意在 FP16情况下对结果裁剪,以避免数值溢出。
双流 Transformer 块每层都让图像和文本特征互相融合:图像 latent 通过交叉注意力“看”文本 embedding,文本 embedding 也被图像特征影响更新。这类似 Stable Diffusion3的双流 Transformer 设计。Flux 将若干这样的层堆叠,使得高层的图像特征已深度融合文本语义。
为啥要rotary?
FluxSingleTransformerBlock模块
经过多层双流块后,FluxTransformer2DModel 后半部分采用单流 Transformer 块,此时文本特征不再更新,仅作为条件参与图像 Transformer。
FluxSingleTransformerBlock 与双流块的区别在于:它只维护图像一个流,并且将注意力和前馈合并为一个更紧凑的结构。下面是 FluxSingleTransformerBlock 的主要实现:
class FluxSingleTransformerBlock(nn.Layer): def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0): super().__init__() self.mlp_hidden_dim = int(dim * mlp_ratio) # 单流AdaLN(只返回一个门控参数) self.norm = AdaLayerNormZeroSingle(dim) # 简化的MLP:先线性扩张,再激活 self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim) self.act_mlp = nn.GELU(approximate="tanh") # 输出线性:将 [attn输出 + mlp输出] 合并回 dim self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim) # 注意力层:不引入额外context,pre_only=True用于优化 processor = FluxAttnProcessor2_0() self.attn = Attention( query_dim=dim, cross_attention_dim=None, dim_head=attention_head_dim, heads=num_attention_heads, out_dim=dim, bias=True, processor=processor, qk_norm="rms_norm", eps=1e-6, pre_only=True ) def forward( self, hidden_states: paddle.Tensor, temb: paddle.Tensor, image_rotary_emb=None, joint_attention_kwargs=None, ): residual = hidden_states # AdaLayerNormSingle:归一化 + 提取单一门控参数 norm_hidden_states, gate = self.norm(hidden_states, emb=temb) mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states)) joint_attention_kwargs = joint_attention_kwargs or {} attn_output = self.attn( hidden_states=norm_hidden_states, image_rotary_emb=image_rotary_emb, **joint_attention_kwargs, ) # 将注意力输出和MLP输出拼接,然后通过线性层融合 hidden_states = paddle.concat([attn_output, mlp_hidden_states], axis=2) gate = gate.unsqueeze(1) hidden_states = gate * self.proj_out(hidden_states) hidden_states = residual + hidden_states # FP16剪裁 if hidden_states.dtype == paddle.float16: hidden_states = hidden_states.clip(-65504, 65504) return hidden_states
单流块仍然利用 AdaLN(但 Simple 版)对特征进行归一化和调制,但内部流程相较双流块有几点不同:
没有文本 context 输入:Attention 层的 cross_attention_dim=None 且不使用 added_kv_proj_dim,因此这个注意力就是标准自注意力,仅作用于图像 latent 序列自身。也就是说,从单流块开始,模型不再更新文本 embedding,文本提供的条件已经在双流块融合完毕。
AdaLayerNormZeroSingle 返回的不是五个参数,而是 norm 后的 hidden 和一个门控系数 gate。Single 版的 AdaLN 对结构进行了简化,因为此时我们不再需要分别对注意力和 MLP 输出进行不同门控(它直接把二者 concat 后一起门控)。
前馈合并简化:Single 块中,将注意力输出和 MLP 输出拼接后一起投射,然后通过 proj_out 线性层将拼接后的向量映射回 dim 长度,再乘以 AdaLN 提供的 gate 系数,最后加上残差。
这样的设计实质上等效于 Transformer 中的并行 FFN 和 Attention 路径,只是这里不是先后顺序叠加,而是并联后融合。
单流块执行标准 Transformer 对图像 latent 的自注意力和前馈,但由于文本信息已嵌入,它不显式处理文本。Single 块采用更紧凑的并行融合方式,最终继续细化图像 latent。
Prompt Embedding 融合机制
FluxPipeline 需要将用户输入的文本提示经过两个文本编码器(CLIP 和 T5)得到 embedding,然后送入 FluxTransformer2DModel。该过程涉及两个方面:
文本序列Embedding(prompt_embeds):由 T5编码器输出,包含文本每个 token 的上下文表示([batch,seq_len,4096]),用于双流块中的文本流 encoder_hidden_states。在进入 FluxTransformer2DModel 前,这个高维序列会通过 self.context_embedder 降维到 inner_dim(如3072)。
文本全局Embedding(pooled_prompt_embeds):由 CLIP 文本模型输出,比如取[EOS]token 的隐藏态作为整体语义表示([batch,768])。这相当于一句话的语义句向量,供模型全局调控使用。
FLUX 将扩散时间步和上述全局文本 embedding 融合为一个向量,用于调制 Transformer 层。这由
CombinedTimestepTextProjEmbeddings 模块完成。其代码表示如下
class CombinedTimestepTextProjEmbeddings(nn.Layer): def __init__(self, embedding_dim, pooled_projection_dim): super().__init__() self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu") def forward(self, timestep, pooled_projection): timesteps_proj = self.time_proj(timestep) timesteps_emb = self.timestep_embedder(timesteps_proj.cast(dtype=pooled_projection.dtype)) # (N, D) pooled_projections = self.text_embedder(pooled_projection) conditioning = timesteps_emb + pooled_projections return conditioning
CombinedTimestepTextProjEmbeddings 将两种不同来源的 embedding 简单逐元素相加。这样产生的输出向量既包含当前扩散步骤的信息,也包含了与提示文本内容相关的全局语义。这个 temb 将在 Transformer 每层的 AdaLayerNorm 中使用,从而影响模型中不同层的归一化和门控参数,实现条件控制。
在 Prompt Embedding 生成流程中,当用户提供 prompt 文本时,FluxPipeline 会用 CLIP 的 tokenizer 和文本编码器编码一次,用 T5的 tokenizer 和编码器编码一次,代码表示如下
def encode_prompt( self, prompt: Union[str, List[str]], prompt_2: Union[str, List[str]], num_images_per_prompt: int = 1, prompt_embeds: Optional[paddle.Tensor] = None, pooled_prompt_embeds: Optional[paddle.Tensor] = None, max_sequence_length: int = 512, lora_scale: Optional[float] = None,
): prompt = [prompt] if isinstance(prompt, str) else prompt if prompt_embeds is None: prompt_2 = prompt_2 or prompt prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 # We only use the pooled prompt output from the CLIPTextModel pooled_prompt_embeds = self._get_clip_prompt_embeds( prompt=prompt, num_images_per_prompt=num_images_per_prompt, ) prompt_embeds = self._get_t5_prompt_embeds( prompt=prompt_2, num_images_per_prompt=num_images_per_prompt, max_sequence_length=max_sequence_length, ) dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype text_ids = paddle.zeros([prompt_embeds.shape[1], 3]).astype(dtype=dtype) return prompt_embeds, pooled_prompt_embeds, text_ids
包含两个过程:
CLIP 编码:得到 pooled_prompt_embeds,即 CLIP 文本模型的池化输出。
T5编码:得到 prompt_embeds,即 T5编码器最后一层隐状态序列(长度等于文本 token 数)。
