当前位置: 首页 > news >正文

stable diffusion 量化加速点

文章目录

    • 一、导出为dynamic shape
      • 1)函数讲解(函数导出、输出检查)
      • 2)代码展示
    • 二、导出为static shape
      • 1)函数讲解(略)
      • 2)代码展示
    • 三、序列化为FP32测速
      • 1)测速
      • 2)代码
    • 四、序列化为FP16测速
      • 1)测速
      • 2)代码同上
    • 五、发现并解决解决CLIP FP16溢出,并测速
      • 1)如何找到溢出的算子
      • 2)CLIP溢出算子解决方案
      • 3)其他FP16算子溢出的解决方案
    • 六、cuda-graph代码优化并测速
    • 七、图片迭代次数优化PD、合并GroupNorm算子制作plugin,UNet和ControlNet拼batch测试
      • 1)迭代次数优化
      • 2)合并GroupNorm算子
      • 3)UNet和ControlNet拼batch
    • 八、根据smooth-quant算法优化INT8量化,对比测速PD
      • 1)smooth-quant算法原理
      • 2)smooth-quant算法代码
      • 3)测速PD损失

一、导出为dynamic shape

1)函数讲解(函数导出、输出检查)

①torch.onnx.export

    torch.onnx.export(
        clip_model,
        (tokens),
        onnx_path,
        verbose=True,
        opset_version=18,
        do_constant_folding=True,
        input_names=input_names,
        output_names=output_names,
        dynamic_axes=dynamic_axes,
    )
(1)export_params:默认为true,表示导出的 ONNX 模型文件会包含模型的所有参数(如权重、偏置等)。而当设置为 False 时,导出的 ONNX 模型文件仅包含模型的计算图结构,不包含模型的参数。这意味着导出的 ONNX 文件会小很多,因为它没有存储大量的参数数据
(2)verbose:为true表示,将会输出大量打印日志信息
(3)do_constant_folding:一般为true,是一个布尔类型的参数,其作用是控制在导出 ONNX 模型时是否进行常量折叠优化从而提高推理性能。为TRUE开启常量折叠优化。在导出 ONNX 模型时,会对图中所有仅包含常量输入的操作进行预先计算,并用计算结果替换这些操作,以此简化计算图,减少模型的计算量和复杂度。
(4)input_names和output_names:输入、输出参数
(5)dynamic_axes:是一个字典,其键为输入或输出张量的名称,值也是一个字典,用于指定该张量中哪些维度是动态的。内层字典的键是维度索引(从 0 开始),值是一个字符串,用于标识这个动态维度,通常在 ONNX 运行时会使用这个标识来指定具体的维度大小
(6)opset_version:指定optset的版本

输入参数举例:
    dynamic_axes = {
   
        "x": {
   0: "batch_size"},
        "hint": {
   0: "batch_size"},
        "timesteps": {
   0: "batch_size"},
        "context": {
   0: "batch_size", 1: "sequence_length"},
        "output": {
   0: "batch_size", 1: "hint_height", 2: "hint_width"}
    }
    
	dynamic_axes = {
   "input_ids": {
   1: "S"}, "last_hidden_state": {
   1: "S"}}
    
        dynamic_axes = {
   
        "x": {
   0: "latent"},
    }

②误差检查

#onnx_path onnx文件目录
#input_dicts  输入参数
#torch_outputs  模型输出结果
def onnxruntime_check(onnx_path, input_dicts, torch_outputs):
    onnx_model = onnx.load(onnx_path)
    # onnx.checker.check_model(onnx_model)
    sess = rt.InferenceSession(onnx_path)
    # outputs = self.get_output_names()
    # latent input
    # data = np.zeros((4, 77), dtype=np.int32)
    result = sess.run(None, input_dicts)
    cnt = 0
    for i in range(0, len(torch_outputs)):
        ret = np.allclose(result[i], torch_outputs[i].detach().numpy(), rtol=1e-03, atol=1e-05, equal_nan=False)
        cnt = cnt +1
        if ret is False:
            #print(f"onnxruntime_check {i} ret:{ret}  result[i]:{result[i]}  torch_outputs[i]:{torch_outputs[i].detach().numpy()} ")
            print("Error onnxruntime_check")
            # import pdb; pdb.set_trace()
        #print("cnt:", cnt)

2)代码展示

  • 代码
import numpy as np
from pytorch_fid import fid_score
from pytorch_fid.inception import InceptionV3
import cv2
import datetime
from share import *
import config

import cv2
import einops
import gradio as gr
import numpy as np
import torch
import random
import os

from pytorch_lightning import seed_everything
from annotator.util import resize_image, HWC3
from annotator.canny import CannyDetector
from cldm.model import create_model, load_state_dict
from cldm.ddim_hacked import DDIMSampler
from onnx import shape_inference
import onnx_graphsurgeon as gs
import onnx
import onnxruntime as rt

def optimize(onnx_path, opt_onnx_path):
    from onnxsim import simplify
    model = onnx.load(onnx_path)
    graph = gs.import_onnx(model)
    print(f"{
     onnx_path} simplify start !")
    # self.info("init", graph)
    model_simp, check = simplify(model)
    # self.info("opt", gs.import_onnx(model_simp))
    onnx.save(model_simp, opt_onnx_path, save_as_external_data=True)
    assert check, "Simplified ONNX model could not be validated"
    print(f"{
     onnx_path} simplify done !")

def onnxruntime_check(onnx_path, input_dicts, torch_outputs):
    onnx_model = onnx.load(onnx_path)
    # onnx.checker.check_model(onnx_model)
    sess = rt.InferenceSession(onnx_path)
    # outputs = self.get_output_names()
    # latent input
    # data = np.zeros((4, 77), dtype=np.int32)
    result = sess.run(None, input_dicts)
    cnt = 0
    for i in range(0, len(torch_outputs)):
        ret = np.allclose(result[i], torch_outputs[i].detach().numpy(), rtol=1e-03, atol=1e-05, equal_nan=False)
        cnt = cnt +1
        if ret is False:
            #print(f"onnxruntime_check {i} ret:{ret}  result[i]:{result[i]}  torch_outputs[i]:{torch_outputs[i].detach().numpy()} ")
            print("Error onnxruntime_check")
            # import pdb; pdb.set_trace()
        #print("cnt:", cnt)
    


class hackathon():
    def initialize(self):
        self.apply_canny = CannyDetector()
        self.model = create_model('./models/cldm_v15.yaml').cpu()
        self.model.load_state_dict(load_state_dict('./models/control_sd15_canny.pth', location='cpu'))
        # self.model.load_state_dict(load_state_dict('/home/player/ControlNet/models/control_sd15_canny.pth', location='cuda'))
        self.model = self.model.cpu()
        self.model.eval()
        self.ddim_sampler = DDIMSampler(self.model)

hk = hackathon()
hk.initialize()

def export_clip_model():
    clip_model = hk.model.cond_stage_model

    import types

    def forward(self, tokens):
        outputs = self.transformer(
            input_ids=tokens, output_hidden_states=self.layer == "hidden"
        )
        if self.layer == "last":
            z = outputs.last_hidden_state
        elif self.layer == "pooled":
            z = outputs.pooler_output[:, None, :]
        else:
            z = outputs.hidden_states[self.layer_idx]
        return z

    clip_model.forward = types.MethodType(forward, clip_model)

    onnx_path = "./onnx/CLIP.onnx"

    tokens = torch.zeros(1, 77, dtype=torch.int32)
    input_names = ["input_ids"]
    output_names = ["last_hidden_state"]
    dynamic_axes = {
   "input_ids": {
   1: "S"}, "last_hidden_state": {
   1: "S"}}

    torch.onnx.export(
        clip_model,
        (tokens),
        onnx_path,
        verbose=True,
        opset_version=18,
        do_constant_folding=True,
        input_names=input_names,
        output_names=output_names,
        dynamic_axes=dynamic_axes,
    )
    print("======================= CLIP model export onnx done!")

    # verify onnx model
    output = clip_model(tokens)
    input_dicts = {
   "input_ids": tokens.numpy()}
    onnxruntime_check(onnx_path, input_dicts, [output])
    print("======================= CLIP onnx model verify done!")

    # opt_onnx_path = "./onnx/CLIP.opt.onnx"
    # optimize(onnx_path, opt_onnx_path)

def export_control_net_model():
    control_net_model = hk.model.control_model
    onnx_path = "./onnx/control_net_model.onnx"

    def get_shape(B=1,S=64):
        return [(B, 4, 32, 48),(B, 3, 256, 384),tuple([B])

相关文章:

  • 2025-04-06 Unity Editor 2 —— GUILayout
  • MySQL【sql之DML】
  • mac安装低版本node
  • 使用注解开发springMVC
  • 华东师范​地面机器人融合空中无人机视角的具身导航!KiteRunner:语言驱动的户外环境合作式局部-全局导航策略
  • 结构化数据库和非结构化数据库的区别是什么
  • 轨迹速度聚类 实战 速度平滑
  • 大模型(二)神经网络
  • Autosar应用层开发基础——Arxml制作
  • LeetCode --- 443周赛
  • 08、Docker学习,常用安装:ClickHouse
  • leetcode122-买卖股票的最佳时机II
  • 通过ssh config让远程服务器通过本地代理访问受限网络
  • 公司内网部署离线deepseek本地模型实战
  • 快 速 幂
  • MySQL请求处理全流程深度解析:从SQL语句到数据返回
  • 队列(C/C++)
  • 25/4/6 算法笔记<仿真O2DES>基础知识学习
  • CasaOS小主机本地安装1Panel运维面板结合内网穿透移动端远程运维
  • 【网络安全】大学信息安全技术 期末考试复习题
  • 怎样做分类网站/简单的网页设计作品
  • 免费做app和网站的平台有哪些/西安seo和网络推广
  • 亚马逊国际站官网/谷歌网站推广
  • 教育发展基金会网站建设/百度指数热度榜
  • 怎样创建自己的网站/广州百度网站推广
  • 手机网站无法访问的解决方法/steam交易链接在哪复制