当前位置: 首页 > news >正文

PyTorch模型转ONNX例子

 参考:(optional) Exporting a Model from PyTorch to ONNX and Running it using ONNX Runtime — PyTorch Tutorials 2.6.0+cu124 documentation

import numpy as np
import torch.utils.model_zoo as model_zoo
import torch.onnx
import torch.nn as nn
import torch.nn.init as init
import onnx
import onnxruntime
import time
import os
from PIL import Image
import torchvision.transforms as transforms

class SuperResolutionNet(nn.Module):
    def __init__(self, upscale_factor, inplace=False):
        super(SuperResolutionNet, self).__init__()

        self.relu = nn.ReLU(inplace=inplace)
        self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2))
        self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1))
        self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1))
        self.conv4 = nn.Conv2d(32, upscale_factor ** 2, (3, 3), (1, 1), (1, 1))
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor)

        self._initialize_weights()

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.relu(self.conv3(x))
        x = self.pixel_shuffle(self.conv4(x))
        return x

    def _initialize_weights(self):
        init.orthogonal_(self.conv1.weight, init.calculate_gain('relu'))
        init.orthogonal_(self.conv2.weight, init.calculate_gain('relu'))
        init.orthogonal_(self.conv3.weight, init.calculate_gain('relu'))
        init.orthogonal_(self.conv4.weight)


def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

def evaluation_accuracy(x, torch_model, ort_session):
    torch_out = torch_model(x)
    ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}
    ort_outs = ort_session.run(None, ort_inputs)
    np.testing.assert_allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05)
    print("Exported model has been tested with ONNXRuntime, and the result looks good!")

def evaluation_speed(x, torch_model, ort_session):
    start = time.time()
    torch_out = torch_model(x)
    end = time.time()
    print(f"Inference of Pytorch model used {end - start} seconds")
    ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}
    start = time.time()
    ort_outs = ort_session.run(None, ort_inputs)
    end = time.time()
    print(f"Inference of ONNX model used {end - start} seconds")

def evaluation_result(ort_session):
    img = Image.open("cat.jpg")
    resize = transforms.Resize([224, 224])
    img = resize(img)
    img_ycbcr = img.convert('YCbCr')
    img_y, img_cb, img_cr = img_ycbcr.split()
    to_tensor = transforms.ToTensor()
    img_y = to_tensor(img_y)
    img_y.unsqueeze_(0)

    ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(img_y)}
    ort_outs = ort_session.run(None, ort_inputs)
    img_out_y = ort_outs[0]

    img_out_y = Image.fromarray(np.uint8((img_out_y[0] * 255.0).clip(0, 255)[0]), mode='L')
    final_img = Image.merge(
        "YCbCr", [
            img_out_y,
            img_cb.resize(img_out_y.size, Image.BICUBIC),
            img_cr.resize(img_out_y.size, Image.BICUBIC),
        ]).convert("RGB")
    final_img.save("cat_superres_with_ort.jpg")

    img = transforms.Resize([img_out_y.size[0], img_out_y.size[1]])(img)
    img.save("cat_resized.jpg")


if __name__ == '__main__':
    torch_model = SuperResolutionNet(upscale_factor=3)
    model_url = 'https://s3.amazonaws.com/pytorch/test_data/export/superres_epoch100-44c6958e.pth'
    batch_size = 64
    map_location = lambda storage, loc: storage
    if torch.cuda.is_available():
        map_location = None
    torch_model.load_state_dict(model_zoo.load_url(model_url, map_location=map_location))

    torch_model.eval()

    x = torch.randn(batch_size, 1, 224, 224, requires_grad=True)
    if not os.path.exists( "super_resolution.onnx"):
        torch.onnx.export(torch_model,             # model being run
                        x,                         # model input (or a tuple for multiple inputs)
                        "super_resolution.onnx",   # where to save the model (can be a file or file-like object)
                        export_params=True,        # store the trained parameter weights inside the model file
                        opset_version=10,          # the ONNX version to export the model to
                        do_constant_folding=True,  # whether to execute constant folding for optimization
                        input_names = ['input'],   # the model's input names
                        output_names = ['output'], # the model's output names
                        dynamic_axes={'input' : {0 : 'batch_size'},    # variable length axes
                                        'output' : {0 : 'batch_size'}})

    onnx_model = onnx.load("super_resolution.onnx")
    onnx.checker.check_model(onnx_model)

    ort_session = onnxruntime.InferenceSession("super_resolution.onnx", providers=["CPUExecutionProvider"])
    evaluation_accuracy(x, torch_model, ort_session)
    evaluation_speed(x, torch_model, ort_session)
    evaluation_result(ort_session)

http://www.dtcms.com/a/80348.html

相关文章:

  • 深入探究 JVM 堆的垃圾回收机制(一)— 判活
  • python3 -m http.sever 8080加载不了解决办法
  • 6个常见的Python设计模式及应用场景
  • Python实战:开发经典猜拳游戏(石头剪刀布)
  • MySQL事务全解析:从概念到实战
  • 【CXX-Qt】2.1.1 为 WebAssembly 构建
  • 汽车免拆诊断案例 | 2024 款路虎发现运动版车无法正常识别智能钥匙
  • Java EE 进阶:MyBatis
  • 【NLP】 11. 神经网络,线性模型,非线性模型,激活函数,感知器优化,正则化学习方法
  • SpringBoot配置文件加载优先级
  • 最大公约数(GCD)和最小公倍数(LCM)专题练习 ——基于罗勇军老师的《蓝桥杯算法入门C/C++》
  • 蓝桥杯2023年第十四届省赛真题-接龙数列
  • Linux后门程序工作原理的详细解释,以及相应的防御措施
  • c语言数据结构 双循环链表设计(完整代码)
  • Ubuntu版免翻墙搭建BatteryHistorian
  • freeswitch(开启抓包信息)
  • 观察RenderDoc截帧UE时“Event”代表什么
  • ssh 多重验证的好处:降低密钥长度,动态密码
  • 分布式任务调度
  • 事件响应计划:网络弹性的关键
  • C++ :try 语句块和异常处理
  • IDEA批量替换项目下所有文件中的特定内容
  • Python Cookbook-4.7 在行列表中完成对列的删除和排序
  • 主流加固方案深度剖析(梆梆/腾讯/阿里)
  • 《数据库原理教程》—— 第三章 关系数据模型 笔记
  • 解释 RESTful API,以及如何使用它构建 web 应用程序
  • Linux驱动开发实战(七):pinctrl引脚管理入门结合数据手册分析
  • Powershell WSL导出导入ubuntu22.04.5子系统
  • 1.5.5 掌握Scala内建控制结构 - 异常处理
  • 编写脚本在Linux下启动、停止SpringBoot工程