Series Index
- 【diffusers Quick Start (1)】What does a pipeline actually call? The `__call__` method!
- 【diffusers Quick Start (2)】How to get the intermediate results of diffusion denoising? Pipeline callbacks
- 【diffusers Quick Start (3)】The relationship between generated image size, the UNet, and the VAE
- 【diffusers Quick Start (4)】What is the EMA operation?
- 【diffusers Quick Start (5)】What does the Scheduler (noise_scheduler) do in a diffusion model?
- 【diffusers Quick Start (6)】Gradient caching and automatic learning-rate scaling, with a detailed code walkthrough
- 【diffusers Quick Start (7)】An intuitive take on Classifier-Free Guidance (CFG), with the corresponding code
- 【diffusers Quick Start (8)】A summary of GPU memory-saving (reduced memory usage) tricks
- 【diffusers Quick Start (9)】A summary of GPU memory-saving (reduced memory usage) code
- 【diffusers Quick Start (10)】Flux-pipe inference that works with any amount of GPU memory: the ultimate memory-saving recipe (with code)
- 【diffusers Advanced (11)】How exactly is LoRA added to the model (inference code, part 1): OminiControl
- 【diffusers Advanced (12)】How exactly is LoRA added to the model (inference code, part 2): OminiControl
- 【diffusers Advanced (13)】A detailed code analysis of AdaLayerNormZero and AdaLayerNormZeroSingle
Table of Contents
- Series Index
- 1. Where exactly is LoRA applied?
- 2. How to read the rank
1. Where exactly is LoRA applied?
```python
import safetensors
from safetensors.torch import load_file
import os
import sys

# Get the absolute path of the project root and add it to the Python path
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(project_root)

import torch
from diffusers.pipelines import FluxPipeline
from src.flux.condition import Condition
from PIL import Image
from src.flux.generate import generate, seed_everything

# Load the weight files
weights_1 = load_file("/OminiControl/experimental/subject.safetensors")
weights_2 = load_file("/OminiControl/omini/subject_512.safetensors")
weights_3 = load_file("/OminiControl/omini/subject_1024_beta.safetensors")

save_path = '/OminiControl'

# Dump all key names (grouped by module) to save_path
import json

with open(os.path.join(save_path, 'weights_keys.json'), 'w') as f:
    # Organize the weight keys by layer and type
    organized_data = {}
    for weight_name, keys in [
        ("subject", weights_1.keys()),
        ("subject_512", weights_2.keys()),
        ("subject_1024_beta", weights_3.keys()),
    ]:
        # Group keys by the module they belong to
        categories = {}
        for key in keys:
            # Extract the parent module (e.g. transformer.single_transformer_blocks.0.attn.to_k)
            parts = key.split('.')
            if len(parts) >= 3:
                category = '.'.join(parts[:-2])  # drop the last two parts (usually lora_A/B and weight)
                if category not in categories:
                    categories[category] = []
                categories[category].append(key)
        organized_data[weight_name] = categories
    json.dump(organized_data, f, indent=4, sort_keys=False)
```
After dumping the keys this way, it is easy to see exactly which layers the LoRA weights are attached to.
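As a quicker sanity check (a sketch not in the original script, reusing the `weights_1` dict loaded above), you can also list the unique parent modules directly instead of opening the JSON file; the raw keys look like `transformer.single_transformer_blocks.0.attn.to_k.lora_A.weight`:

```python
# Sketch: list the unique modules that carry LoRA weights, without writing the JSON file.
# Assumes weights_1 (the "subject" checkpoint) was loaded with load_file() as above.
modules = sorted({'.'.join(k.split('.')[:-2]) for k in weights_1.keys() if 'lora_' in k})
for m in modules[:10]:  # first few module paths, e.g. transformer.single_transformer_blocks.0.attn.to_k
    print(m)
```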
| Component | subject | subject_512 | subject_1024_beta |
|---|---|---|---|
| transformer.transformer_blocks (MMDiT) | ✓ full set (to_k, to_q, to_v, to_out.0, ff.net.2, norm1.linear) | ✓ included (only to_k, to_q, norm1.linear) | ✓ included (only to_k, to_q, norm1.linear) |
| transformer.single_transformer_blocks (Single-DiT) | ✓ full set (to_k, to_q, to_v, proj_mlp, proj_out, norm.linear) | ✓ partial (only to_k, to_q, norm.linear) | ✓ partial (only to_k, to_q, norm.linear) |
| transformer.x_embedder | ✓ included | ✗ not included | ✗ not included |
| Resolution / sequence length | unspecified (likely a general-purpose version) | 512 (for medium resolution) | 1024 beta (for high resolution, experimental) |
| File size | 29.1 MB | 50.5 MB | 50.5 MB |
| Rank | 4 | 16 | 16 |
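If you want to actually apply one of these checkpoints, a minimal loading sketch with diffusers' built-in LoRA loader might look like the following. This is not from the original post: OminiControl's own `src.flux.generate` pipeline attaches the LoRA through its own code path, and the base-model id and dtype below are assumptions.

```python
# Hedged sketch: loading one of the LoRA checkpoints with diffusers' standard LoRA loader.
import torch
from diffusers.pipelines import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",  # assumed base model, not stated in the post
    torch_dtype=torch.bfloat16,
)
pipe.load_lora_weights(
    "/OminiControl/omini",
    weight_name="subject_512.safetensors",
    adapter_name="subject_512",
)
```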
2. How to read the rank
Simply read and print the number of output features of `lora_A` (which is the rank), or equivalently the number of input features of `lora_B`.
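This works because of the LoRA decomposition itself: the weight update is `W' = W + lora_B @ lora_A`, so under PyTorch's `Linear` weight convention `lora_A.weight` has shape `[rank, in_features]` and `lora_B.weight` has shape `[out_features, rank]`. A minimal sketch (the shapes are taken from the `to_k` example printed below; everything else is illustrative):

```python
import torch

# LoRA update: W' = W + lora_B @ lora_A
rank, d_in, d_out = 4, 3072, 3072      # shapes from the rank-4 "subject" checkpoint
lora_A = torch.randn(rank, d_in)       # lora_A.weight: [rank, in_features]
lora_B = torch.zeros(d_out, rank)      # lora_B.weight: [out_features, rank]
delta_W = lora_B @ lora_A              # [d_out, d_in], same shape as the base weight
print(delta_W.shape, "rank =", lora_A.shape[0], "=", lora_B.shape[1])
```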
```python
import safetensors
from safetensors.torch import load_file
import os
import sys
import torch

# Get the absolute path of the project root and add it to the Python path
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(project_root)

# Load the weight files
weights_1 = load_file("OminiControl/experimental/subject.safetensors")
weights_2 = load_file("OminiControl/omini/subject_512.safetensors")
weights_3 = load_file("OminiControl/omini/subject_1024_beta.safetensors")


# Get the LoRA rank from a state dict
def get_lora_rank(weights):
    # Find the first lora_A weight
    lora_a_key = None
    for key in weights.keys():
        if 'lora_A.weight' in key:
            lora_a_key = key
            break
    if lora_a_key is None:
        return "no lora_A weight found"

    # The matching lora_B key
    lora_b_key = lora_a_key.replace('lora_A.weight', 'lora_B.weight')
    if lora_b_key not in weights:
        return f"found lora_A weight {lora_a_key} but no matching lora_B weight"

    # Tensor shapes
    lora_a_shape = weights[lora_a_key].shape
    lora_b_shape = weights[lora_b_key].shape

    # The LoRA rank is lora_A's number of output features (= lora_B's number of input features).
    # For safetensors weights we still check the tensor dimensionality first.
    if len(lora_a_shape) == 2:  # regular Linear weight
        rank = lora_a_shape[0]  # output features of lora_A
    else:  # possibly a conv layer or another special structure
        print(f"Warning: non-standard LoRA weight shape: {lora_a_shape}")
        rank = lora_a_shape[-1]  # fall back to the last dimension

    # Sanity check: lora_B's input features should equal lora_A's output features
    if len(lora_b_shape) == 2 and lora_b_shape[1] != rank:
        return f"Warning: lora_A output features ({rank}) do not match lora_B input features ({lora_b_shape[1]})"

    return rank


# Check the LoRA rank of each weight file
print("LoRA rank of subject.safetensors:", get_lora_rank(weights_1))
print("LoRA rank of subject_512.safetensors:", get_lora_rank(weights_2))
print("LoRA rank of subject_1024_beta.safetensors:", get_lora_rank(weights_3))


# Optional: check whether all LoRA weights in a file share the same rank
def check_all_ranks(weights, name):
    ranks = {}
    for key in weights.keys():
        if 'lora_A.weight' in key:
            lora_b_key = key.replace('lora_A.weight', 'lora_B.weight')
            if lora_b_key in weights:
                # Read the rank for this lora_A/lora_B pair
                lora_a_shape = weights[key].shape
                if len(lora_a_shape) == 2:  # regular Linear weight
                    rank = lora_a_shape[0]  # output features of lora_A
                else:  # possibly a conv layer or another special structure
                    rank = lora_a_shape[-1]  # fall back to the last dimension
                ranks[key] = rank

    # Are all ranks identical?
    unique_ranks = set(ranks.values())
    if len(unique_ranks) == 1:
        print(f"All LoRA weights in {name} have rank {next(iter(unique_ranks))}")
    else:
        print(f"LoRA ranks in {name} are not consistent:")
        for rank in sorted(unique_ranks):
            count = list(ranks.values()).count(rank)
            print(f"  rank {rank}: {count} weights")


print("\nChecking the rank of every LoRA weight in each file:")
check_all_ranks(weights_1, "subject.safetensors")
check_all_ranks(weights_2, "subject_512.safetensors")
check_all_ranks(weights_3, "subject_1024_beta.safetensors")

# Print a few example weight shapes for debugging
print("\nExample weight shapes:")
for weights, name in [(weights_1, "subject"), (weights_2, "subject_512"), (weights_3, "subject_1024_beta")]:
    for key in list(weights.keys())[:2]:  # only the first two keys
        if 'lora_A.weight' in key:
            lora_b_key = key.replace('lora_A.weight', 'lora_B.weight')
            print(f"{name} - {key}: {weights[key].shape}")
            if lora_b_key in weights:
                print(f"{name} - {lora_b_key}: {weights[lora_b_key].shape}")
```
Running this prints:

```
LoRA rank of subject.safetensors: 4
LoRA rank of subject_512.safetensors: 16
LoRA rank of subject_1024_beta.safetensors: 16

Checking the rank of every LoRA weight in each file:
All LoRA weights in subject.safetensors have rank 4
All LoRA weights in subject_512.safetensors have rank 16
All LoRA weights in subject_1024_beta.safetensors have rank 16

Example weight shapes:
subject - transformer.single_transformer_blocks.0.attn.to_k.lora_A.weight: torch.Size([4, 3072])
subject - transformer.single_transformer_blocks.0.attn.to_k.lora_B.weight: torch.Size([3072, 4])
subject_512 - transformer.single_transformer_blocks.0.attn.to_k.lora_A.weight: torch.Size([16, 3072])
subject_512 - transformer.single_transformer_blocks.0.attn.to_k.lora_B.weight: torch.Size([3072, 16])
subject_1024_beta - transformer.single_transformer_blocks.0.attn.to_k.lora_A.weight: torch.Size([16, 3072])
subject_1024_beta - transformer.single_transformer_blocks.0.attn.to_k.lora_B.weight: torch.Size([3072, 16])
```
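As a final cross-check against the file sizes listed in the table above, you can count the LoRA parameters directly; the on-disk size should be roughly the total tensor bytes plus a small header. This is a sketch not in the original post, reusing `weights_1`/`weights_2`/`weights_3` from above:

```python
# Hedged sketch: relate the table's file sizes to the parameter count of each checkpoint.
def lora_stats(weights):
    n_params = sum(t.numel() for t in weights.values())
    n_mb = sum(t.numel() * t.element_size() for t in weights.values()) / 1024 ** 2
    return n_params, n_mb

for name, w in [("subject", weights_1), ("subject_512", weights_2), ("subject_1024_beta", weights_3)]:
    n, mb = lora_stats(w)
    print(f"{name}: {n:,} parameters, ~{mb:.1f} MB of tensor data")
```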