当前位置: 首页 > news >正文

【diffusers 进阶(十四)】权重读取,查看 Lora 具体加在哪里和 Rank ‘秩’ 是多少?以 OminiControl 为例

系列文章目录

  • 【diffusers 极速入门(一)】pipeline 实际调用的是什么? call 方法!
  • 【diffusers 极速入门(二)】如何得到扩散去噪的中间结果?Pipeline callbacks 管道回调函数
  • 【diffusers极速入门(三)】生成的图像尺寸与 UNet 和 VAE 之间的关系
  • 【diffusers极速入门(四)】EMA 操作是什么?
  • 【diffusers极速入门(五)】扩散模型中的 Scheduler(noise_scheduler)的作用是什么?
  • 【diffusers极速入门(六)】缓存梯度和自动放缩学习率以及代码详解
  • 【diffusers极速入门(七)】Classifier-Free Guidance (CFG)直观理解以及对应代码
  • 【diffusers极速入门(八)】GPU 显存节省(减少内存使用)技巧总结
  • 【diffusers极速入门(九)】GPU 显存节省(减少内存使用)代码总结
  • 【diffusers极速入门(十)】Flux-pipe 推理,完美利用任何显存大小,GPU显存节省终极方案(附代码)
  • 【diffusers 进阶(十一)】Lora 具体是怎么加入模型的(推理代码篇上)OminiControl
  • 【diffusers 进阶(十二)】Lora 具体是怎么加入模型的(推理代码篇下)OminiControl
  • 【diffusers 进阶(十三)】AdaLayerNormZero 与 AdaLayerNormZeroSingle 代码详细分析

文章目录

  • 系列文章目录
  • 一、LoRA 具体加在哪里?
  • 二、读取 Rank 的具体方法


一、LoRA 具体加在哪里?

import safetensors
from safetensors.torch import load_file
import os
import sys

# 获取项目根目录的绝对路径
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(project_root)  # 将项目根目录添加到Python路径中

import torch
from diffusers.pipelines import FluxPipeline
from src.flux.condition import Condition
from PIL import Image
from src.flux.generate import generate, seed_everything

# 加载权重文件
weights_1 = load_file("/OminiControl/experimental/subject.safetensors")
weights_2 = load_file("/OminiControl/omini/subject_512.safetensors")
weights_3 = load_file("/OminiControl/omini/subject_1024_beta.safetensors")

save_path = '/OminiControl'

# 将所有键名存到 save_path
import json
with open(os.path.join(save_path, 'weights_keys.json'), 'w') as f:
    # 按照层级和类型组织权重键名
    organized_data = {}
    
    for weight_name, keys in [("subject", weights_1.keys()), 
                             ("subject_512", weights_2.keys()), 
                             ("subject_1024_beta", weights_3.keys())]:
        # 创建分类字典
        categories = {}
        for key in keys:
            # 提取主要组件(例如:transformer.single_transformer_blocks.0.attn)
            parts = key.split('.')
            if len(parts) >= 3:
                category = '.'.join(parts[:-2])  # 去掉最后两个部分(通常是 lora_A/B.weight)
                if category not in categories:
                    categories[category] = []
                categories[category].append(key)
        
        organized_data[weight_name] = categories
    
    json.dump(organized_data, f, indent=4, sort_keys=False)

打印出想要的权重后,可以清楚地看出 lora 加在哪些层上。

组件类型subjectsubject_512subject_1024_beta
transformer.transformer_blocks (MMDiT✓ 完整微调
(to_k, to_q, to_v, to_out.0, ff.net.2, norm1.linear)
✓ 包含(只有to_k, to_q, norm1.linear)✓ 包含(只有to_k, to_q, norm1.linear)
transformer.single_transformer_blocks(Single-DiT✓ 完整微调
(to_k, to_q, to_v, proj_mlp, proj_out, norm.linear)
✓ 部分微调
(只有to_k, to_q, norm.linear)
✓ 部分微调
(只有to_k, to_q, norm.linear)
transformer.x_embedder✓ 包含✗ 不包含✗ 不包含
分辨率/序列长度未指定
(可能是通用版本)
512
(针对中等分辨率)
1024 Beta
(针对高分辨率,实验性)
内存大小29.1 MB50.5 MB50.5 MB
Rank41616

二、读取 Rank 的具体方法

读取并打印 lora_A 的输出特征数(即 rank)或者 lora_B 的输入特征数即可。

import safetensors
from safetensors.torch import load_file
import os
import sys
import torch

# 获取项目根目录的绝对路径
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(project_root)  # 将项目根目录添加到Python路径中

# 加载权重文件
weights_1 = load_file("OminiControl/experimental/subject.safetensors")
weights_2 = load_file("OminiControl/omini/subject_512.safetensors")
weights_3 = load_file("OminiControl/omini/subject_1024_beta.safetensors")

# 函数:获取LoRA的rank值
def get_lora_rank(weights):
    # 查找第一个lora_A权重
    lora_a_key = None
    for key in weights.keys():
        if 'lora_A.weight' in key:
            lora_a_key = key
            break
    
    if lora_a_key is None:
        return "未找到lora_A权重"
    
    # 获取对应的lora_B权重键
    lora_b_key = lora_a_key.replace('lora_A.weight', 'lora_B.weight')
    
    if lora_b_key not in weights:
        return f"找到lora_A权重 {lora_a_key},但未找到对应的lora_B权重"
    
    # 获取权重张量的形状
    lora_a_shape = weights[lora_a_key].shape
    lora_b_shape = weights[lora_b_key].shape
    
    # LoRA的rank是lora_A的输出特征数或lora_B的输入特征数
    # 对于safetensors格式,我们需要检查张量的维度
    if len(lora_a_shape) == 2:  # 常规Linear层权重
        rank = lora_a_shape[0]  # lora_A的输出特征数
    else:  # 可能是卷积层或其他特殊结构
        print(f"警告:非标准LoRA权重形状: {lora_a_shape}")
        rank = lora_a_shape[-1]  # 尝试获取最后一个维度
    
    # 验证lora_B的输入特征数是否与lora_A的输出特征数相同
    if len(lora_b_shape) == 2 and lora_b_shape[1] != rank:
        return f"警告:lora_A输出特征数 ({rank}) 与lora_B输入特征数 ({lora_b_shape[1]}) 不匹配"
    
    return rank

# 检查每个权重文件的LoRA rank
print("subject.safetensors 的 LoRA rank:", get_lora_rank(weights_1))
print("subject_512.safetensors 的 LoRA rank:", get_lora_rank(weights_2))
print("subject_1024_beta.safetensors 的 LoRA rank:", get_lora_rank(weights_3))

# 可选:检查所有LoRA权重的rank是否一致
def check_all_ranks(weights, name):
    ranks = {}
    for key in weights.keys():
        if 'lora_A.weight' in key:
            lora_b_key = key.replace('lora_A.weight', 'lora_B.weight')
            if lora_b_key in weights:
                # 获取正确的rank值
                lora_a_shape = weights[key].shape
                if len(lora_a_shape) == 2:  # 常规Linear层权重
                    rank = lora_a_shape[0]  # lora_A的输出特征数
                else:  # 可能是卷积层或其他特殊结构
                    rank = lora_a_shape[-1]  # 尝试获取最后一个维度
                
                ranks[key] = rank
    
    # 检查是否所有rank都相同
    unique_ranks = set(ranks.values())
    if len(unique_ranks) == 1:
        print(f"{name} 中所有LoRA权重的rank都是 {next(iter(unique_ranks))}")
    else:
        print(f"{name} 中LoRA权重的rank不一致:")
        for rank in sorted(unique_ranks):
            count = list(ranks.values()).count(rank)
            print(f"  Rank {rank}: {count} 个权重")

print("\n检查每个文件中所有LoRA权重的rank:")
check_all_ranks(weights_1, "subject.safetensors")
check_all_ranks(weights_2, "subject_512.safetensors")
check_all_ranks(weights_3, "subject_1024_beta.safetensors")

# 打印一些示例权重的形状,以便调试
print("\n示例权重形状:")
for weights, name in [(weights_1, "subject"), (weights_2, "subject_512"), (weights_3, "subject_1024_beta")]:
    for key in list(weights.keys())[:2]:  # 只打印前两个键
        if 'lora_A.weight' in key:
            lora_b_key = key.replace('lora_A.weight', 'lora_B.weight')
            print(f"{name} - {key}: {weights[key].shape}")
            if lora_b_key in weights:
                print(f"{name} - {lora_b_key}: {weights[lora_b_key].shape}")
subject.safetensors 的 LoRA rank: 4
subject_512.safetensors 的 LoRA rank: 16
subject_1024_beta.safetensors 的 LoRA rank: 16

检查每个文件中所有LoRA权重的rank:
subject.safetensors 中所有LoRA权重的rank都是 4
subject_512.safetensors 中所有LoRA权重的rank都是 16
subject_1024_beta.safetensors 中所有LoRA权重的rank都是 16

示例权重形状:
subject - transformer.single_transformer_blocks.0.attn.to_k.lora_A.weight: torch.Size([4, 3072])
subject - transformer.single_transformer_blocks.0.attn.to_k.lora_B.weight: torch.Size([3072, 4])
subject_512 - transformer.single_transformer_blocks.0.attn.to_k.lora_A.weight: torch.Size([16, 3072])
subject_512 - transformer.single_transformer_blocks.0.attn.to_k.lora_B.weight: torch.Size([3072, 16])
subject_1024_beta - transformer.single_transformer_blocks.0.attn.to_k.lora_A.weight: torch.Size([16, 3072])
subject_1024_beta - transformer.single_transformer_blocks.0.attn.to_k.lora_B.weight: torch.Size([3072, 16])
http://www.dtcms.com/a/108237.html

相关文章:

  • Vue3+Vite+TypeScript+Element Plus开发-03.主页设计与router配置
  • 智能设备运行监控系统
  • intellij Idea 和 dataGrip下载和安装教程
  • 【Nature正刊2023】使用大型语言模型进行自主化学研究
  • 解决小程序video控件在真机和上线后黑屏不播放问题
  • 【ESP32-IDF 笔记】04-I2C配置
  • Scala基础知识5
  • react中hooks使用
  • 关于mysql 数据库中的 慢SQL 的详细分析,包括定义、原因、解决方法及表格总结
  • 【数字化转型,企业应用上云】---持续集成能力重塑企业软件交付新范式
  • 【node-forge】加解密(RSA),代替node-rsa
  • 洛谷题单3-P5721 【深基4.例6】数字直角三角形-python-流程图重构
  • 在Docker中快速部署Redis:从零开始到生产环境配置指南
  • stack栈的基本使用-c++
  • 23种设计模式-结构型模式-享元
  • 在未归一化的线性回归模型中,特征的尺度差异可能导致模型对特征重要性的误判
  • 墨笔 在线Markdown 编辑器
  • VAE讲解
  • PyTorch中卷积层torch.nn.Conv2d
  • Android 切换prefer APN后建立PDN的日志分析
  • ubuntu改用户权限
  • AI调研 | Omnisql模型家族调研与实测
  • ‌Windows 与 Linux网络命令速查表,含常用场景及参数说明
  • 使用高德api实现天气查询
  • 多电机显示并排序
  • WHAT - 如何理解中间件
  • WPF学习路线
  • 关于Gstreamer+MPP硬件加速推流问题:视频输入video0被占用
  • MYSQL实现获取某个经纬度区域内的数据
  • Cesium系列:从入门到实践,打造属于你的3D地球应用