[diffusers Advanced (14)] Reading the weights: where exactly is LoRA applied, and what is its rank? Using OminiControl as an example
Series Index
- [diffusers Quick Start (1)] What does a pipeline actually call? The `__call__` method!
- [diffusers Quick Start (2)] How to get intermediate results of the diffusion denoising process? Pipeline callbacks
- [diffusers Quick Start (3)] How the generated image size relates to the UNet and the VAE
- [diffusers Quick Start (4)] What is the EMA operation?
- [diffusers Quick Start (5)] What does the Scheduler (noise_scheduler) do in diffusion models?
- [diffusers Quick Start (6)] Gradient caching and automatic learning-rate scaling, with a detailed code walkthrough
- [diffusers Quick Start (7)] An intuitive take on Classifier-Free Guidance (CFG) and the corresponding code
- [diffusers Quick Start (8)] GPU memory-saving tips (reducing memory usage): a summary
- [diffusers Quick Start (9)] GPU memory-saving code (reducing memory usage): a summary
- [diffusers Quick Start (10)] Flux-pipe inference that fits any amount of GPU memory, the ultimate memory-saving recipe (with code)
- [diffusers Advanced (11)] How exactly is LoRA added to the model? (Inference code, part 1) OminiControl
- [diffusers Advanced (12)] How exactly is LoRA added to the model? (Inference code, part 2) OminiControl
- [diffusers Advanced (13)] A detailed code analysis of AdaLayerNormZero and AdaLayerNormZeroSingle
Table of Contents
- Series Index
- 1. Where Exactly Is LoRA Applied?
- 2. How to Read the Rank
1. Where Exactly Is LoRA Applied?

The script below loads the three OminiControl LoRA checkpoints with safetensors and dumps their key names, grouped by module, into a JSON file:
```python
import os
import json
from safetensors.torch import load_file

# Load the three LoRA weight files
weights_1 = load_file("/OminiControl/experimental/subject.safetensors")
weights_2 = load_file("/OminiControl/omini/subject_512.safetensors")
weights_3 = load_file("/OminiControl/omini/subject_1024_beta.safetensors")

save_path = '/OminiControl'

# Dump all key names to save_path, organized by layer and type
with open(os.path.join(save_path, 'weights_keys.json'), 'w') as f:
    organized_data = {}
    for weight_name, keys in [("subject", weights_1.keys()),
                              ("subject_512", weights_2.keys()),
                              ("subject_1024_beta", weights_3.keys())]:
        # Group keys by the module they belong to
        categories = {}
        for key in keys:
            # Extract the main component (e.g. transformer.single_transformer_blocks.0.attn)
            parts = key.split('.')
            if len(parts) >= 3:
                # Drop the last two parts (usually lora_A/lora_B and weight)
                category = '.'.join(parts[:-2])
                if category not in categories:
                    categories[category] = []
                categories[category].append(key)
        organized_data[weight_name] = categories
    json.dump(organized_data, f, indent=4, sort_keys=False)
```
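For reference, the resulting weights_keys.json is a nested dictionary: one entry per checkpoint, mapping each module path to the LoRA key names stored under it. Abridged, and using only key names that also appear in the example output in Section 2, it looks roughly like this:

```json
{
    "subject": {
        "transformer.single_transformer_blocks.0.attn.to_k": [
            "transformer.single_transformer_blocks.0.attn.to_k.lora_A.weight",
            "transformer.single_transformer_blocks.0.attn.to_k.lora_B.weight"
        ],
        ...
    },
    "subject_512": { ... },
    "subject_1024_beta": { ... }
}
```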
Once the keys are dumped, it is easy to see exactly which layers the LoRA weights are attached to. The three checkpoints compare as follows:
| Component | subject | subject_512 | subject_1024_beta |
| --- | --- | --- | --- |
| transformer.transformer_blocks (MMDiT) | ✓ full set (to_k, to_q, to_v, to_out.0, ff.net.2, norm1.linear) | ✓ included (only to_k, to_q, norm1.linear) | ✓ included (only to_k, to_q, norm1.linear) |
| transformer.single_transformer_blocks (Single-DiT) | ✓ full set (to_k, to_q, to_v, proj_mlp, proj_out, norm.linear) | ✓ partial (only to_k, to_q, norm.linear) | ✓ partial (only to_k, to_q, norm.linear) |
| transformer.x_embedder | ✓ included | ✗ not included | ✗ not included |
| Resolution / sequence length | unspecified (likely a general-purpose version) | 512 (medium resolution) | 1024 Beta (high resolution, experimental) |
| File size | 29.1 MB | 50.5 MB | 50.5 MB |
| Rank | 4 | 16 | 16 |
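If you would rather not read the JSON by hand, the "which submodules carry LoRA" comparison above can also be reproduced by collecting the set of module-name suffixes per block family. This is a minimal sketch, assuming the same file paths used earlier (adjust them to your local checkout):

```python
from collections import defaultdict
from safetensors.torch import load_file

def lora_targets(path):
    """Collect, per block family, the submodules that carry LoRA weights."""
    targets = defaultdict(set)
    for key in load_file(path).keys():
        if ".lora_A." not in key and ".lora_B." not in key:
            continue
        # e.g. transformer.single_transformer_blocks.0.attn.to_k.lora_A.weight
        prefix = key.split(".lora_")[0]
        parts = prefix.split(".")
        family = ".".join(parts[:2])                   # e.g. transformer.single_transformer_blocks
        submodule = ".".join(parts[3:]) or "(itself)"  # e.g. attn.to_k; x_embedder has no suffix
        targets[family].add(submodule)
    return targets

for name, path in [("subject", "/OminiControl/experimental/subject.safetensors"),
                   ("subject_512", "/OminiControl/omini/subject_512.safetensors"),
                   ("subject_1024_beta", "/OminiControl/omini/subject_1024_beta.safetensors")]:
    print(name)
    for family, subs in sorted(lora_targets(path).items()):
        print(f"  {family}: {sorted(subs)}")
```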
2. How to Read the Rank
The rank can be read straight from the weight shapes: print the output feature dimension of lora_A (which is the rank), or equivalently the input feature dimension of lora_B.
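For a Linear layer, LoRA stores the low-rank update as two factors: `lora_A.weight` with shape `[rank, in_features]` and `lora_B.weight` with shape `[out_features, rank]` (the update is ΔW = B·A), so the rank is `lora_A.weight.shape[0]`, or equivalently `lora_B.weight.shape[1]`. A quick sketch using one of the files loaded above (path assumed relative to your checkout):

```python
from safetensors.torch import load_file

weights = load_file("OminiControl/omini/subject_512.safetensors")
# Take any lora_A tensor; its first dimension is the LoRA rank
key = next(k for k in weights if k.endswith("lora_A.weight"))
print(key, "-> rank =", weights[key].shape[0])  # expected: 16 for subject_512
```

The full script below does the same thing with extra shape checks and a per-file consistency report.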
```python
from safetensors.torch import load_file

# Load the weight files
weights_1 = load_file("OminiControl/experimental/subject.safetensors")
weights_2 = load_file("OminiControl/omini/subject_512.safetensors")
weights_3 = load_file("OminiControl/omini/subject_1024_beta.safetensors")


def get_lora_rank(weights):
    """Return the LoRA rank read from the first lora_A / lora_B pair."""
    # Find the first lora_A weight
    lora_a_key = None
    for key in weights.keys():
        if 'lora_A.weight' in key:
            lora_a_key = key
            break
    if lora_a_key is None:
        return "no lora_A weight found"

    # The matching lora_B key
    lora_b_key = lora_a_key.replace('lora_A.weight', 'lora_B.weight')
    if lora_b_key not in weights:
        return f"found lora_A weight {lora_a_key}, but no matching lora_B weight"

    # Tensor shapes
    lora_a_shape = weights[lora_a_key].shape
    lora_b_shape = weights[lora_b_key].shape

    # The LoRA rank is lora_A's output dimension (== lora_B's input dimension)
    if len(lora_a_shape) == 2:  # regular Linear-layer weight: [rank, in_features]
        rank = lora_a_shape[0]
    else:  # conv layers or other special structures
        print(f"Warning: non-standard LoRA weight shape: {lora_a_shape}")
        rank = lora_a_shape[-1]  # fall back to the last dimension

    # Verify that lora_B's input dimension matches lora_A's output dimension
    if len(lora_b_shape) == 2 and lora_b_shape[1] != rank:
        return f"Warning: lora_A output dim ({rank}) does not match lora_B input dim ({lora_b_shape[1]})"
    return rank


# Check the LoRA rank of each weight file
print("LoRA rank of subject.safetensors:", get_lora_rank(weights_1))
print("LoRA rank of subject_512.safetensors:", get_lora_rank(weights_2))
print("LoRA rank of subject_1024_beta.safetensors:", get_lora_rank(weights_3))


# Optional: check whether all LoRA weights in a file share the same rank
def check_all_ranks(weights, name):
    ranks = {}
    for key in weights.keys():
        if 'lora_A.weight' in key:
            lora_b_key = key.replace('lora_A.weight', 'lora_B.weight')
            if lora_b_key in weights:
                lora_a_shape = weights[key].shape
                if len(lora_a_shape) == 2:  # regular Linear-layer weight
                    rank = lora_a_shape[0]
                else:  # conv layers or other special structures
                    rank = lora_a_shape[-1]
                ranks[key] = rank

    # Report whether all ranks agree
    unique_ranks = set(ranks.values())
    if len(unique_ranks) == 1:
        print(f"All LoRA weights in {name} have rank {next(iter(unique_ranks))}")
    else:
        print(f"LoRA ranks in {name} are inconsistent:")
        for rank in sorted(unique_ranks):
            count = list(ranks.values()).count(rank)
            print(f"  rank {rank}: {count} weights")


print("\nChecking the rank of all LoRA weights in each file:")
check_all_ranks(weights_1, "subject.safetensors")
check_all_ranks(weights_2, "subject_512.safetensors")
check_all_ranks(weights_3, "subject_1024_beta.safetensors")

# Print a few example weight shapes for reference
print("\nExample weight shapes:")
for weights, name in [(weights_1, "subject"), (weights_2, "subject_512"), (weights_3, "subject_1024_beta")]:
    for key in list(weights.keys())[:2]:  # only the first two keys
        if 'lora_A.weight' in key:
            lora_b_key = key.replace('lora_A.weight', 'lora_B.weight')
            print(f"{name} - {key}: {weights[key].shape}")
            if lora_b_key in weights:
                print(f"{name} - {lora_b_key}: {weights[lora_b_key].shape}")
```
Running the script prints:

```
LoRA rank of subject.safetensors: 4
LoRA rank of subject_512.safetensors: 16
LoRA rank of subject_1024_beta.safetensors: 16

Checking the rank of all LoRA weights in each file:
All LoRA weights in subject.safetensors have rank 4
All LoRA weights in subject_512.safetensors have rank 16
All LoRA weights in subject_1024_beta.safetensors have rank 16

Example weight shapes:
subject - transformer.single_transformer_blocks.0.attn.to_k.lora_A.weight: torch.Size([4, 3072])
subject - transformer.single_transformer_blocks.0.attn.to_k.lora_B.weight: torch.Size([3072, 4])
subject_512 - transformer.single_transformer_blocks.0.attn.to_k.lora_A.weight: torch.Size([16, 3072])
subject_512 - transformer.single_transformer_blocks.0.attn.to_k.lora_B.weight: torch.Size([3072, 16])
subject_1024_beta - transformer.single_transformer_blocks.0.attn.to_k.lora_A.weight: torch.Size([16, 3072])
subject_1024_beta - transformer.single_transformer_blocks.0.attn.to_k.lora_B.weight: torch.Size([3072, 16])
```
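As a final sanity check against the "File size" and "Rank" rows of the table in Section 1, you can also sum the LoRA parameters directly from the tensors. Using `numel()` and `element_size()` avoids any assumption about the stored dtype, and the totals should land in the same ballpark as the 29.1 MB / 50.5 MB file sizes (safetensors adds only a small metadata header). A minimal sketch that reuses the `weights_1/2/3` dictionaries loaded above:

```python
def lora_footprint(weights):
    """Total LoRA parameter count and the in-memory size implied by the tensors."""
    lora_tensors = [t for k, t in weights.items() if ".lora_" in k]
    n_params = sum(t.numel() for t in lora_tensors)
    n_mbytes = sum(t.numel() * t.element_size() for t in lora_tensors) / (1024 * 1024)
    return n_params, n_mbytes

for name, w in [("subject", weights_1),
                ("subject_512", weights_2),
                ("subject_1024_beta", weights_3)]:
    n, mb = lora_footprint(w)
    print(f"{name}: {n / 1e6:.2f}M LoRA parameters, ~{mb:.1f} MB")
```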