Flux.1系列模型解析--Flux.1
文章目录
- 简介
- 文本编码器
- 旋转位置编码
- Flux.1 backbone
- 双流模块
- 单流模块
- VAE
- 生图+采样
简介
Flux.1模型有三个版本,分别是pro、dev和schnell,三个模型性能依次递减,但生图效率依次提高。dev和schnell基于pro模型蒸馏而来,pro模型只能通过api访问,而dev、shcnell模型可获取具体权重,bfl并没有对Flux.1系列模型架构进行过多展示,只表明基于多模态和并行扩散 Transformer 模块的混合架构,参数扩展到了12B;通过基于流匹配范式训练,且引入旋转位置编码和并行注意力层来提高模型性能并提升硬件效率。
虽然bfl没有进一步公布详细的技术文档,但其在github上开源了推理代码,可以基于推理代码梳理出整个模型架构,图1就是reddit论坛上社区开发者发布的Flux.1模型架构图。Flux.1模型基于DiT架构,与LLMs相同使用RoPE来表征图片位置信息,先使用双流块、再使用单流块实现图像隐空间和文本编码空间的对齐,最终舍弃文本tokens,对图像tokens进行解码得到图片。图1要从下向上看,后续将针对其中的主要模块或概念结合推理代码进行说明。
文本编码器
如图1所示,文本提示词会经过T5 Encoder和CLIP两个文本编码器提取文本特征,官方推理代码具体实现如下所示,实现极其简洁,基于transformers库提供的接口直接将两个文本编码器封装在一个类中,使用时根据version
参数自动识别、初始化对应实例。
from torch import Tensor, nn
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizerclass HFEmbedder(nn.Module): # 可分别初始化clip和t5的文本编码器类def __init__(self, version: str, max_length: int, **hf_kwargs):super().__init__()self.is_clip = version.startswith("openai") # 判断是clip还是t5self.max_length = max_lengthself.output_key = "pooler_output" if self.is_clip else "last_hidden_state"if self.is_clip:self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(version, max_length=max_length) # 初始化clip的tokenizerself.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(version, **hf_kwargs) # 初始化clip的文本编码器else:self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(version, max_length=max_length) # 初始化t5的tokenizerself.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(version, **hf_kwargs) # 初始化t5的文本编码器self.hf_module = self.hf_module.eval().requires_grad_(False) # 设置为eval模式,并禁用梯度计算def forward(self, text: list[str]) -> Tensor:batch_encoding = self.tokenizer(text, # 输入文本列表truncation=True, # 允许截断max_length=self.max_length, # 最大长度return_length=False, # 不返回长度return_overflowing_tokens=False, # 不返回溢出tokenpadding="max_length", # 填充到最大长度return_tensors="pt", # 返回pytorch张量)outputs = self.hf_module(input_ids=batch_encoding["input_ids"].to(self.hf_module.device),attention_mask=None, # 不使用注意力掩码output_hidden_states=False, # 不输出隐藏状态)return outputs[self.output_key].bfloat16() # 返回输出
旋转位置编码
Flux.1模型中使用的是二维旋转位置编码,其会同时对文本和图像进行处理。Flux.1模型中完整的旋转基本由三部分组成,分别是。先分别基于文本和图片位置索引张量txt_ids
和img_ids
构建二维位置编码张量的EmbedND
模块、对查询张量、键张量应用旋转位置编码的apply_rope
函数和对带有旋转位置编码信息的张量进行注意力计算的attention
函数。具体实现如下:
# src/flux/modules/layers.py
class EmbedND(nn.Module):"""N维位置编码模块参数:dim (int): 位置编码维度, 通常为64或128theta (int): RoPE旋转角度参数, 通常为10000axes_dim (list[int]): 每个轴的编码维度, 如[32,32]表示2D位置编码,每个维度32"""def __init__(self, dim: int = 64, theta: int = 10000, axes_dim: list[int] = [32, 32]):super().__init__()self.dim = dim # 位置编码总维度,等于axes_dim之和self.theta = theta # RoPE旋转参数self.axes_dim = axes_dim # 每个轴的编码维度,如[32,32]表示2D位置编码,每个维度32def forward(self, ids: Tensor) -> Tensor:"""前向传播参数:ids: shape为[batch_size, seq_len, n_axes]的位置索引张量,此处的seq_len是所有轴的编码维度之和返回:shape为[batch_size, 1, dim, 2, 2]的位置编码张量"""n_axes = ids.shape[-1] # 获取轴数,如2表示2D位置# 对每个轴应用RoPE编码并拼接emb = torch.cat([rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],dim=-3,) # 每个轴的编码在-3维度上拼接,因为最后两个维度是旋转矩阵# 增加head维度return emb.unsqueeze(1) # [B, 1, D, 2, 2],D为总编码维度,2x2为RoPE的旋转矩阵# src/flux/math.py
import torch
from einops import rearrange
from torch import Tensordef attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:"""注意力机制q: query张量 [batch, heads, seq_len, head_dim]k: key张量 [batch, heads, seq_len, head_dim]v: value张量 [batch, heads, seq_len, head_dim]pe: 位置编码张量 [batch, 1, dim, 2, 2]"""q, k = apply_rope(q, k, pe) # 将预计算的rope旋转矩阵应用于q,kx = torch.nn.functional.scaled_dot_product_attention(q, k, v) # 计算注意力x = rearrange(x, "B H L D -> B L (H D)") # 将多头注意力组合回整体return x # [batch, seq_len, heads*head_dim]def rope(pos: Tensor, dim: int, theta: int) -> Tensor:assert dim % 2 == 0# 计算每个位置的频率缩放因子;先生成序列 [0, 2, 4, ..., dim-2],然后除以 dim,得到得到 [0, 2/dim, 4/dim, ..., (dim-2)/dim]scale = torch.arange(0, dim, 2, dtype=pos.dtype, device=pos.device) / dimomega = 1.0 / (theta**scale) # 计算最终的角频率 ω_i = 1/θ^(2i/dim)out = torch.einsum("...n,d->...nd", pos, omega) # Einstein 求和约定计算位置和频率的外积,shape: [batch, seq_len, dim//2]out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1) # 构建旋转矩阵,shape: [batch, seq_len, dim//2, 4]out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2) # 重排列成矩阵形式,将最后一个维度4拆分成2*2,shape: [batch, seq_len, dim//2, 2, 2]return out.float()def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:# 输入的q、k张量的最后一维拆分为两个维度,相当于构建复述形式;[batch, heads, seq_len, head_dim] --> [batch, heads, seq_len, head_dim//2, 1, 2],新增加的维度1是为了广播计算添加xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)# 进行旋转变换;freqs_cis[..., 0]、freqs_cis[..., 1]是一个行数为2的列向量,xq_[..., 0]、xq_[..., 1]、xk_[..., 0]、xk_[..., 1]是一个列数为2的行向量xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk) # 将结果重排列回原来的形状
Flux.1 backbone
Flux.1模型的backbone实现如下所示,可与图1对比。img对应图1中的latent、txt
对应经过T5 Encoder提取的文本嵌入、y
对应经过CLIP提取的文本嵌入。基于timesteps
、guidance
初始化的编码特征会以及clip文本嵌入三者相加为vec
,会在整个迭代预测过程中作为调制向量,用于计算对应的调制项。在初始化旋转位置编码pe
后,先以img
、txt
、vec
、pe
为输入经过多个双流模块的计算;然后将img
、txt
拼接为一个张量,再经过多个单流模块的计算;最后只截取图片序列,将其输出层归一化、线性层等模块输出最终的latent。
class Flux(nn.Module):"""Transformer model for flow matching on sequences."""def __init__(self, params: FluxParams):super().__init__()self.params = paramsself.in_channels = params.in_channelsself.out_channels = params.out_channelsif params.hidden_size % params.num_heads != 0:raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}") # 隐藏层维度必须能被头数整除pe_dim = params.hidden_size // params.num_heads # 位置编码的维度数与单个自注意力头的维度数相同if sum(params.axes_dim) != pe_dim: # 各个轴的维度之和应该等于位置编码的维度数raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")self.hidden_size = params.hidden_sizeself.num_heads = params.num_headsself.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim) # 多维旋转位置编码self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)self.guidance_in = (MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity())self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)self.double_blocks = nn.ModuleList([DoubleStreamBlock(self.hidden_size,self.num_heads,mlp_ratio=params.mlp_ratio,qkv_bias=params.qkv_bias,)for _ in range(params.depth)]) # 双流注意力模块堆self.single_blocks = nn.ModuleList([SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)for _ in range(params.depth_single_blocks)]) # 单流注意力模块堆self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)def forward(self,img: Tensor, # 重排后的图像张量img_ids: Tensor,txt: Tensor, # t5文本嵌入txt_ids: Tensor,timesteps: Tensor,y: Tensor, # vec # clip文本嵌入guidance: Tensor | None = None,) -> Tensor:if img.ndim != 3 or txt.ndim != 3:raise ValueError("Input img and txt tensors must have 3 dimensions.")# running on sequences imgimg = self.img_in(img)vec = self.time_in(timestep_embedding(timesteps, 256)) # 时间编码if self.params.guidance_embed:if guidance is None:raise ValueError("Didn't get guidance strength for guidance distilled model.")vec = vec + self.guidance_in(timestep_embedding(guidance, 256)) # 叠加引导编码vec = vec + self.vector_in(y) # 至此,时间编码、引导编码、clip文本嵌入都融合到vec中txt = self.txt_in(txt)ids = torch.cat((txt_ids, img_ids), dim=1) # 文本位置ids和图像位置ids拼接pe = self.pe_embedder(ids) # 旋转位置编码for block in self.double_blocks:img, txt = block(img=img, txt=txt, vec=vec, pe=pe) # 双流模块img = torch.cat((txt, img), 1) # 文本隐向量和图像隐向量拼接称单一向量for block in self.single_blocks:img = block(img, vec=vec, pe=pe) # 单流模块img = img[:, txt.shape[1] :, ...] # 只使用后半段的图片ids序列img = self.final_layer(img, vec) # (B, img_seq_len, out_channels)return img
双流模块
该模块称为双流的原因就是其内部为图像、文本特征采用单独的模块进行计算。先使用输入的vec
为图片和文本分别预测两个调制模块,然后均先分别使用第一个调值模块的scale
和shift
分量分别处理图片张量、文本张量,再分别应用对应的注意力模块处理q、k、v张量并将图像、文本的q、k、v张量对应拼接,得到最终参与注意力计算的q、k、v张量。注意力计算后,再从结果中拆分出文本、图像注意力输出,再分别使用第二个调制模块配合对应的层归一化、mlp层和原始输入的img
、txt
得到最终的输出img
、txt
向量。
class DoubleStreamBlock(nn.Module):def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False):super().__init__()mlp_hidden_dim = int(hidden_size * mlp_ratio)self.num_heads = num_headsself.hidden_size = hidden_sizeself.img_mod = Modulation(hidden_size, double=True)self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)self.img_mlp = nn.Sequential(nn.Linear(hidden_size, mlp_hidden_dim, bias=True),nn.GELU(approximate="tanh"),nn.Linear(mlp_hidden_dim, hidden_size, bias=True),)self.txt_mod = Modulation(hidden_size, double=True)self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)self.txt_mlp = nn.Sequential(nn.Linear(hidden_size, mlp_hidden_dim, bias=True),nn.GELU(approximate="tanh"),nn.Linear(mlp_hidden_dim, hidden_size, bias=True),)def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]:"""参数:img: 图像张量 [batch, seq_len, hidden_size], 图片latenttxt: 文本张量 [batch, seq_len, hidden_size],prompt经过T5 encoder后的嵌入vec: 调制向量 [batch, hidden_size], 是prompt经过clip、时间经过位置编码,guidance经过位置编码后的拼接张量pe: 位置编码张量 [batch, 1, dim, 2, 2]返回:img: 处理后的图像张量 [batch, seq_len, hidden_size]txt: 处理后的文本张量 [batch, seq_len, hidden_size]"""# 分别预测图片和文本的调制项img_mod1, img_mod2 = self.img_mod(vec)txt_mod1, txt_mod2 = self.txt_mod(vec)# prepare image for attentionimg_modulated = self.img_norm1(img) # 归一化img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift # 调制img_qkv = self.img_attn.qkv(img_modulated) # 单独使用图片自注意力模块中的qkv子模块从拼接在一起的qkvimg_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) # 拆分img_q, img_k = self.img_attn.norm(img_q, img_k, img_v) # 单独使用图片自注意力模块中的归一化模块# prepare txt for attentiontxt_modulated = self.txt_norm1(txt)txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shifttxt_qkv = self.txt_attn.qkv(txt_modulated)txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)# run actual attentionq = torch.cat((txt_q, img_q), dim=2) # 文本和图片的q拼接k = torch.cat((txt_k, img_k), dim=2) # 文本和图片的k拼接v = torch.cat((txt_v, img_v), dim=2) # 文本和图片的v拼接attn = attention(q, k, v, pe=pe) # 注意力计算 txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :] # 拆分出文本和图片的注意力结果# calculate the img blocksimg = img + img_mod1.gate * self.img_attn.proj(img_attn) # 先单独使用图片自注意力模块中的投影层转换,再乘上图片调制项的门控系数,最后加上图片调制项的偏移量img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift) # 与图片调制项中的第二组调制参数组合# calculate the txt blockstxt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)return img, txt
单流模块
单流模块的输入是文本和图像拼接之后的向量,故只使用一个模块组进行计算。整个计算流程与双流模块基本一致,但有一点不同是单流模块中不像常规的transformer block,先进行注意力计算,然后执行mlp层,而是在构建q、k、v张量时就并行预测了mlp的输出,然后对注意力输出进行正则化处理时直接和mlp层内容拼接,也是并行处理。
class SingleStreamBlock(nn.Module):"""A DiT block with parallel linear layers as described inhttps://arxiv.org/abs/2302.05442 and adapted modulation interface."""def __init__(self,hidden_size: int,num_heads: int,mlp_ratio: float = 4.0,qk_scale: float | None = None,):super().__init__()self.hidden_dim = hidden_sizeself.num_heads = num_headshead_dim = hidden_size // num_headsself.scale = qk_scale or head_dim**-0.5self.mlp_hidden_dim = int(hidden_size * mlp_ratio)# qkv and mlp_inself.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim) # 并行线性层,qkv转换时同时将mlp输入也预测处理# proj and mlp_out self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size) # 并行线性层,最后转换注意力计算时,直接和mlp的输出拼接一并处理self.norm = QKNorm(head_dim)self.hidden_size = hidden_sizeself.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)self.mlp_act = nn.GELU(approximate="tanh")self.modulation = Modulation(hidden_size, double=False)def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:"""参数:x: 输入张量 [batch, seq_len, hidden_size],经过双流注意力模块后输出的image latent和prompt latent拼接后的张量vec: 调制向量 [batch, hidden_size], 是prompt经过clip、时间经过位置编码,guidance经过位置编码后的拼接张量pe: 位置编码张量 [batch, 1, dim, 2, 2]返回:x: 处理后的张量 [batch, seq_len, hidden_size]"""mod, _ = self.modulation(vec) # 调制,只预测一组调制参数x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift # 调制qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1) # 拆分q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)q, k = self.norm(q, k, v)# compute attentionattn = attention(q, k, v, pe=pe)# compute activation in mlp stream, cat again and run second linear layeroutput = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))return x + mod.gate * output
VAE
Flux.1模型中使用到的VAE模型保持了常规的VAE架构,主要由编码器、解码器和重参数化采样层组成;不同点是其中的编码器和解码器均是Unet结构,并在Unet结构的中间层均加入了一个自注意力模块。其他没有太多可说的,具体可参考推理代码中的定义。
生图+采样
与常规的扩散模型相似,流匹配范式采样的原始数据也是纯噪声数据x
,再基于x
、提示词prompts
构建输入,主要是调整图像shape和构建对应的多维绝对位置ids序列、文本嵌入和对应的文本绝对位置ids序列。在使用预定义的调度器构建采样时间步长后,就能开始去噪采样,flow matching范式是直接以线性插值的方式迭代更新。flux backbone预测的是离散序列形式的图像隐向量,先将其解包回图像隐空间,再使用vae的编码器将其解码回像素空间得到最终的生成图片。此过程涉及的细节角度,具体可参考以下代码中的注释,更多细节可进一步参考原始推理代码。想进一步了解Flow matching的读者也可参考笔者之前的文章从扩散模型开始的生成模型范式演变–FM(1)、从扩散模型开始的生成模型范式演变–FM(2)。
# 初始化噪声
def get_noise(num_samples: int,height: int,width: int,device: torch.device,dtype: torch.dtype,seed: int,
): # 采样图片隐空间尺寸的噪声return torch.randn( # 从标准正态分布中采样随机噪声num_samples,16,# allow for packing2 * math.ceil(height / 16),2 * math.ceil(width / 16),dtype=dtype,generator=torch.Generator(device="cpu").manual_seed(seed),).to(device)# 常规的数据准备函数
def prepare(t5: HFEmbedder, clip: HFEmbedder, img: Tensor, prompt: str | list[str]) -> dict[str, Tensor]:bs, c, h, w = img.shape # 此处的img的shape与经过vae编码后的隐向量shape相同if bs == 1 and not isinstance(prompt, str):bs = len(prompt)# 图像重排和批次扩展;将图像隐向量在平面维度上分割为2*2的patch,再经过展平实现长度为H/2*W/2的patches序列,即完成了图像隐向量离散序列化,每个patch的维度是C*2*2img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) # [B, C, H, W] -> [B, H/2*W/2, C*2*2]if img.shape[0] == 1 and bs > 1:img = repeat(img, "1 ... -> bs ...", bs=bs)# 生成图像三维位置idsimg_ids = torch.zeros(h // 2, w // 2, 3) # 因为将图像隐向量分割为2*2的patch,以空间角度位置编码的角度来看最后一个维度应该为2,此处为3的原因是后续会和文本位置ids拼接,在最前面添加一个区域模态的维度img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None] # 行索引,[0, 1, 2, ..., H/2-1]img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :] # 列索引,[0, 1, 2, ..., W/2-1]# 将三维位置ids拉平,再补齐batchimg_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) # [H/2, W/2, 3] -> [B, H/2*W/2, 3]if isinstance(prompt, str):prompt = [prompt]txt = t5(prompt) # t5 encoder编码后的文本嵌入if txt.shape[0] == 1 and bs > 1:txt = repeat(txt, "1 ... -> bs ...", bs=bs)txt_ids = torch.zeros(bs, txt.shape[1], 3) # 为了能与图像的位置编码拼接,最后一个维度也是3vec = clip(prompt) # clip encoder编码后的文本嵌入if vec.shape[0] == 1 and bs > 1:vec = repeat(vec, "1 ... -> bs ...", bs=bs)return {"img": img, # 重排后的图像张量"img_ids": img_ids.to(img.device), # 图像多维位置ids"txt": txt.to(img.device), # t5文本嵌入"txt_ids": txt_ids.to(img.device), # 文本位置ids"vec": vec.to(img.device), # clip文本向量}# 构建经过两个点(x1,y1)和(x2,y2)的线性函数
def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
) -> Callable[[float], float]:m = (y2 - y1) / (x2 - x1)b = y1 - m * x1return lambda x: m * x + bdef get_schedule(num_steps: int,image_seq_len: int,base_shift: float = 0.5,max_shift: float = 1.15,shift: bool = True,
) -> list[float]:# extra step for zerotimesteps = torch.linspace(1, 0, num_steps + 1) # 从1到0的num_steps+1个等差数列# shifting the schedule to favor high timesteps for higher signal imagesif shift:# estimate mu based on linear estimation between two pointsmu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)timesteps = time_shift(mu, 1.0, timesteps)return timesteps.tolist()# 去噪
def denoise(model: Flux,# model inputimg: Tensor,img_ids: Tensor,txt: Tensor,txt_ids: Tensor,vec: Tensor,# sampling parameterstimesteps: list[float],guidance: float = 4.0,# extra img tokens (channel-wise)img_cond: Tensor | None = None,# extra img tokens (sequence-wise)img_cond_seq: Tensor | None = None,img_cond_seq_ids: Tensor | None = None,
):# this is ignored for schnellguidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) # 创建一个长度为batch size的一维张量,所有值都是guidancefor t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) # 创建一个长度为batch size的一维张量,所有值都是t_currimg_input = imgimg_input_ids = img_idsif img_cond is not None:img_input = torch.cat((img, img_cond), dim=-1)if img_cond_seq is not None:assert (img_cond_seq_ids is not None), "You need to provide either both or neither of the sequence conditioning"img_input = torch.cat((img_input, img_cond_seq), dim=1)img_input_ids = torch.cat((img_input_ids, img_cond_seq_ids), dim=1)pred = model(img=img_input,img_ids=img_input_ids,txt=txt,txt_ids=txt_ids,y=vec,timesteps=t_vec,guidance=guidance_vec,) # 使用flux backbone预测flow matching范式中当前时间步的移动速度if img_input_ids is not None:pred = pred[:, : img.shape[1]] # 只使用前半段的图片ids序列img = img + (t_prev - t_curr) * pred # flow matching范式更新时就直接以进行线性插值,即新的图像值就是当前预测值和上一步图像值的插值return img