Flux.1系列模型解析--Flux.1 Tools
文章目录
- 简介
- Fill
- Redux
- Depth/Canny
简介
Flux.1模型的基础能力已经很强,但是局部生成、控制生成等方面仍不足,bfl随进一步训练,开发了Flux.1 Tools系列模型,包含四个模型,具体情况如下。
- Fill:根据文本描述和二进制掩码编辑或扩展输入图像,即Inpainting和Outpainting,是一个基模型
- Redux:一个能对输入图片进行细微变化或调整的Adapter模型,可以和所有Flux.1基模型组合使用
- Depth:可接受条件图像的深度信息控制生成图片,有基模型或lora模型
- Canny:可接受条件图像的canny线条信息控制生成图片,有基模型或lora模型
Fill
与常规的painting模型相同,Flux.1 Fill dev模型基于Flux.1 dev全量微调而来,并且因为掩码图片mask的引入,flux backbone的in_channels参数与Flux.1 dev初始化时不同。Flux.1 Fill dev模型采样时的图片特征具体构建步骤是,先将条件图片和掩码图片(可认为通道数为1)的像素值归一化,然后将条件图片和掩码图片耦合得到条件图片,即确定选区,然后使用VAE的编码器从条件图片中提取条件特征向量,假设像素空间中图片的尺寸为 [H,W,3][H,W,3][H,W,3],那么条件特征相邻的尺寸为 [H8,W8,16][\frac{H}{8},\frac{W}{8},16][8H,8W,16];再对条件特征向量进行分块、离散为序列,即将其划分为 2∗22*22∗2的patch并拉长为序列,尺寸变为[H16,W16,64][\frac{H}{16},\frac{W}{16},64][16H,16W,64]。此外,掩码图片mask也应该包含在输入中,对于mask并没有进行编码进行特征提取,故为了保证其能与离散后的特征向量在最后一个维度拼接,直接将其尺寸转换为[H16,W16,256][\frac{H}{16},\frac{W}{16},256][16H,16W,256]。将条件特征向量和mask特征向量在最后一个维度拼接,得到的最终条件特征向量的尺寸为[H16,W16,320][\frac{H}{16},\frac{W}{16},320][16H,16W,320],推理是其还会和初始噪声在最后一个维度拼接,即最终输入的尺寸为[H16,W16,384][\frac{H}{16},\frac{W}{16},384][16H,16W,384],即Flux.1 Fill dev模型初始化时in_channels参数值为384,而Flux.1 dev模型的in_channels参数值为64。
注意,上述尺寸变化中均未考虑batch size。搞清楚了输入的构建方式,再就是正常的使用Flow matching范式进行全量微调训练。
Redux
Flux.1 Redux dev模型是一个Adapter模型,其本质是两个线形层,和google/siglip-so400m-patch14-384模型组合使用。构建输入时,先使用siglip的视觉编码模型从条件图片中提取图片特征,然后redux包含两个的两个线性层,以siglip提取的图片特征为输入,通过先升维、再降维,得到最终的图片条件向量。与其他条件输入不同,redux提取的条件向量不与代表图片的初始噪声拼接,而是和T5 Encoder从提示词中提取的文本编码特征在长度维度上进行拼接,故redux中升维线性层的输入维度与siglip视觉编码器输出维度相同,为1152;而降维线性层的输出维度与T5 Encoder输出维度相同,为4096。redux的特征提取类具体实现如以下代码所示,其与其他Flux.1基础模型配合使用,故采样时其他细节基本一致,其可以融入到更复杂的流程中,通过提示词和图像引导出重绘、风格转换、材质转换等各种功能。
class ReduxImageEncoder(nn.Module):siglip_model_name = "google/siglip-so400m-patch14-384"def __init__(self,device,redux_path: str, # redux模型路径redux_dim: int = 1152, # 与Siglip的视觉编码器输出维度相同txt_in_features: int = 4096, # 与t5的输出维度相同dtype=torch.bfloat16,) -> None:super().__init__()self.redux_dim = redux_dimself.device = device if isinstance(device, torch.device) else torch.device(device)self.dtype = dtypewith self.device: # 直接在指定设备上创建模块self.redux_up = nn.Linear(redux_dim, txt_in_features * 3, dtype=dtype)self.redux_down = nn.Linear(txt_in_features * 3, txt_in_features, dtype=dtype)sd = load_sft(redux_path, device=str(device))missing, unexpected = self.load_state_dict(sd, strict=False, assign=True) # 此处是对redux_up和redux_down的模块加载权重print_load_warning(missing, unexpected)self.siglip = SiglipVisionModel.from_pretrained(self.siglip_model_name).to(dtype=dtype)self.normalize = SiglipImageProcessor.from_pretrained(self.siglip_model_name)def __call__(self, x: Image.Image) -> torch.Tensor:imgs = self.normalize.preprocess(images=[x], do_resize=True, return_tensors="pt", do_convert_rgb=True) # 图像预处理_encoded_x = self.siglip(**imgs.to(device=self.device, dtype=self.dtype)).last_hidden_state # 图像编码projected_x = self.redux_down(nn.functional.silu(self.redux_up(_encoded_x))) # 投影;先升维,再激活,再降维return projected_x
Depth/Canny
与Fill和Redux不同,Depth/Canny有两个实现方式,除了对Flux.1 dev模型全量微调外,还基于全量微调后的模型提取了相应的lora模型,故最终有以下四个模型变体。非Lora模型与Flux.1 dev一样是基模型,而lora模型需要和基模型组合使用,以下的两个lora模型要和Flux.1 dev一起使用。
- FLUX.1-Depth-dev
- FLUX.1-Depth-dev-lora
- FLUX.1-Canny-dev
- FLUX.1-Canny-dev-lora
上述的Depth/Canny模型变体可以从参考图片中提取图片控制隐向量,采样时与初始噪声在最后一个维度拼接作为flux backbone的输入;为了初始噪声和参考图片的隐向量可以拼接,使用VAE的编码器提取图片特征前会将其尺寸调整与生成的目标尺寸相同,在拼接后最后的特征维度大小范围,故flux backbone在初始化时in_channels参数是正常使用时的两倍,即从64变大至128。
基于推理代码中控制图片的输入方式,FLUX.1-Depth-dev、FLUX.1-Canny-dev全量训练时也应如此,将原始图片和控制图片使用VAE编码器提取特征向量后在最后一个维度拼接,在其基础上添加噪声后进行Flow matching范式全量微调训练;全量微调的训练结束后,在Flux.1 dev的所有线性层添加lora模块,通过蒸馏全量微调的模型得到对应的lora模型。
对于全量微调的模型,模型架构方便的变动就是前面提到的flux backbone初始化时in_channels参数翻倍,而lora模块的添加更复杂一些,但实际实现也比较简单。如常规lora配置的方式相同,Flux.1 dev模型配置lora模块时只对架构中的线形层添加,flux backbone初始化时通过递归替换的方式将所有常规的线性层替换为自定义的带有lora模块的线性层LinearLora,该类包含常规线性层和对应的lora模块中的两个降维、升维线性层。LinearLora类计算时有两个分支,一条是原始线性层的计算,另一条分支则是先降后升两个线性层计算,然后通过一个缩放参数加权后与原始分支计算结束相加得到最终的输出结果,具体实现代码如下。
# Lora模块相关实现
def replace_linear_with_lora( # 将module中的所有Linear层替换为LinearLora层module: nn.Module,max_rank: int, # Lora最大的秩scale: float = 1.0, # Lora的缩放因子
) -> None:for name, child in module.named_children(): # 遍历所有子模块if isinstance(child, nn.Linear): # 如果当前模块为线性层Linearnew_lora = LinearLora(in_features=child.in_features, # 保持输入维度不变out_features=child.out_features, # 保持输出维度不变bias=child.bias, # 保持偏置不变rank=max_rank, # 设置Lora的秩scale=scale, # 设置Lora的缩放因子dtype=child.weight.dtype, # 保持数据类型不变device=child.weight.device, # 保持设备位置不变) # 初始化一个LinearLora层# 将原始Linear层中的权重复制给新初始化的LinearLora层new_lora.weight = child.weightnew_lora.bias = child.bias if child.bias is not None else Nonesetattr(module, name, new_lora) # 将新初始化的LinearLora层替换到module中else:replace_linear_with_lora(module=child,max_rank=max_rank,scale=scale,) # 递归替换子模块中的Linear层class LinearLora(nn.Linear):def __init__(self,in_features: int,out_features: int,bias: bool,rank: int,dtype: torch.dtype,device: torch.device,lora_bias: bool = True,scale: float = 1.0,*args,**kwargs,) -> None:super().__init__(in_features=in_features,out_features=out_features,bias=bias is not None,device=device,dtype=dtype,*args,**kwargs,)assert isinstance(scale, float), "scale must be a float"self.scale = scaleself.rank = rankself.lora_bias = lora_biasself.dtype = dtypeself.device = deviceif rank > (new_rank := min(self.out_features, self.in_features)): # 确保rank不超过输入、输出维度的最小值,就是要同时小于out_features和in_featuresself.rank = new_rank# 初始化Lora的A和B矩阵self.lora_A = nn.Linear(in_features=in_features,out_features=self.rank,bias=False,dtype=dtype,device=device,)self.lora_B = nn.Linear(in_features=self.rank,out_features=out_features,bias=self.lora_bias,dtype=dtype,device=device,)def set_scale(self, scale: float) -> None:assert isinstance(scale, float), "scalar value must be a float"self.scale = scaledef forward(self, input: torch.Tensor) -> torch.Tensor:'''输入 x│├─────────────────┐│ │原始线性层 LoRA路径│ ││ 降维(lora_A)│ ││ rank维度│ ││ 升维(lora_B)│ ││ 缩放(scale)│ │└────────┬────────┘合并│输出'''base_out = super().forward(input) # 先进行原始的线性变换_lora_out_B = self.lora_B(self.lora_A(input)) # 计算Lora的输出lora_update = _lora_out_B * self.scale # 计算Lora的更新量return base_out + lora_update# 带有Lora模块的FLux模型实现
class FluxLoraWrapper(Flux):def __init__(self,lora_rank: int = 128,lora_scale: float = 1.0,*args,**kwargs,) -> None:super().__init__(*args, **kwargs)self.lora_rank = lora_rankreplace_linear_with_lora(self,max_rank=lora_rank,scale=lora_scale,) # 将模型中的所有线性层替换为Lora线性层def set_lora_scale(self, scale: float) -> None:for module in self.modules():if isinstance(module, LinearLora):module.set_scale(scale=scale)