当前位置: 首页 > news >正文

从PyTorch到ONNX:模型部署性能提升

在深度学习模型部署过程中,推理性能优化是一个关键环节。ONNX(Open Neural Network Exchange)作为一个开放的神经网络交换格式,能够在不同框架之间实现模型的无缝转换,并通过优化的运行时环境显著提升推理速度。

环境准备

首先安装必要的依赖包:

# 安装PyTorch生态
pip install torch torchvision torchaudio# 安装ONNX相关包
pip install onnx onnxruntime-gpu  # GPU版本
# pip install onnxruntime          # CPU版本

ONNX常用方法汇总

模型导出 (PyTorch → ONNX)

import torch
import torch.onnx# 加载PyTorch模型
model = torch.load('model.pth')
model.eval()# 创建虚拟输入
dummy_input = torch.randn(1, 3, 224, 224)# 导出ONNX
torch.onnx.export(model,                      # 模型dummy_input,               # 虚拟输入'model.onnx',              # 输出路径input_names=['input'],     # 输入名称output_names=['output'],   # 输出名称dynamic_axes={'input': {0: 'batch_size'}},  # 动态维度opset_version=17
)

模型加载与推理

import onnxruntime as ort
import numpy as np# 加载ONNX模型
session = ort.InferenceSession('model.onnx', providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])# 获取模型信息
print("输入:", [input.name for input in session.get_inputs()])
print("输出:", [output.name for output in session.get_outputs()])
print("执行提供商:", session.get_providers())# 准备输入数据
input_data = np.random.randn(1, 3, 224, 224).astype(np.float32)
input_dict = {session.get_inputs()[0].name: input_data}# 推理
outputs = session.run(None, input_dict)
result = outputs[0]

模型信息查看

import onnx# 加载ONNX模型
model = onnx.load('model.onnx')# 检查模型
onnx.checker.check_model(model)# 打印模型信息
print("Graph inputs:", [input.name for input in model.graph.input])
print("Graph outputs:", [output.name for output in model.graph.output])

性能优化设置

# 会话选项配置
sess_options = ort.SessionOptions()
sess_options.intra_op_num_threads = 4        # 线程数
sess_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALLsession = ort.InferenceSession('model.onnx', sess_options)

完整实现代码

import torch
import torch.nn as nn
import torch.nn.functional as F
import time
import numpy as np
import onnxruntime as ort
from torchvision.models import resnet50, ResNet50_Weights# ================== 基础配置 ==================
device = 'cuda' if torch.cuda.is_available() else 'cpu'
C, H, W = 3, 224, 224  # 输入图像尺寸
PTH_PATH = "resnet50.pth"
ONNX_PATH = "resnet50.onnx"print(f"使用设备: {device}")
print(f"输入尺寸: {C}x{H}x{W}")# ================== 模型下载与导出 ==================
def download_and_export():"""下载预训练模型并导出为ONNX格式"""print("\n开始下载预训练模型...")# 加载预训练的ResNet50模型pt_model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1).to(device)pt_model.eval()  # 设置为评估模式# 保存PyTorch模型权重torch.save(pt_model.state_dict(), PTH_PATH)print(f"PyTorch模型已保存至: {PTH_PATH}")# 创建虚拟输入用于导出dummy_input = torch.randn(1, C, H, W, device=device)# 导出为ONNX格式print("正在导出ONNX模型...")torch.onnx.export(pt_model,                    # 要导出的模型dummy_input,                 # 虚拟输入张量ONNX_PATH,                   # 输出文件路径input_names=['input'],       # 输入节点名称output_names=['output'],     # 输出节点名称dynamic_axes={               # 动态轴配置"input": {0: "batch_size"},   # batch维度可动态变化"output": {0: "batch_size"},},opset_version=17,            # ONNX算子集版本do_constant_folding=True,    # 常量折叠优化dynamo=False                 # 禁用TorchDynamo)print(f"ONNX模型已导出至: {ONNX_PATH}")# ================== 模型加载 ==================
def load_models():"""加载PyTorch和ONNX模型"""print("\n正在加载模型...")# 加载PyTorch模型model_pt = resnet50(weights=None).to(device)  # 不重新下载权重model_pt.load_state_dict(torch.load(PTH_PATH, map_location=device))model_pt.eval()# 加载ONNX模型providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]sess = ort.InferenceSession(ONNX_PATH, providers=providers)print(f"PyTorch模型已加载")print(f"ONNX模型已加载,使用提供商: {sess.get_providers()}")return model_pt, sess# ================== 性能测试函数 ==================
def benchmark_pytorch(model, data, num_batches=50):"""测试PyTorch模型推理性能"""print(f"\n开始PyTorch性能测试 (批次数: {num_batches})")with torch.no_grad():# 预热阶段for _ in range(10):_ = model(data)# 同步GPU操作if device == 'cuda':torch.cuda.synchronize()# 正式计时start = time.time()for _ in range(num_batches):_ = model(data)if device == 'cuda':torch.cuda.synchronize()end = time.time()avg_time = (end - start) / num_batchesprint(f"PyTorch平均推理时间: {avg_time:.6f}秒")return avg_timedef benchmark_onnx(sess, data, num_batches=50):"""测试ONNX模型推理性能"""print(f"\n开始ONNX性能测试 (批次数: {num_batches})")# 准备ONNX输入格式onnx_input = {sess.get_inputs()[0].name: data.cpu().numpy()}# 预热阶段for _ in range(10):_ = sess.run(None, onnx_input)# 正式计时start = time.time()for _ in range(num_batches):_ = sess.run(None, onnx_input)end = time.time()avg_time = (end - start) / num_batchesprint(f"ONNX平均推理时间: {avg_time:.6f}秒")return avg_time# ================== 精度验证 ==================
def verify_correctness(model_pt, sess):"""验证PyTorch和ONNX模型输出一致性"""print("\n开始精度验证...")# 创建测试输入input_tensor = torch.randn(1, C, H, W, device=device)# PyTorch推理with torch.no_grad():pytorch_out = model_pt(input_tensor).detach().cpu().numpy()# ONNX推理onnx_input = {sess.get_inputs()[0].name: input_tensor.cpu().numpy()}onnx_out = sess.run(None, onnx_input)[0]# 比较结果is_close = np.allclose(pytorch_out, onnx_out, atol=1e-3)max_diff = np.max(np.abs(pytorch_out - onnx_out))print(f"输出一致性检查: {'通过' if is_close else '失败'}")print(f" 最大差异: {max_diff:.6f}")return is_close, max_diff# ================== 主函数 ==================
def main():"""主执行函数"""print("ONNX模型导出与性能对比教程")print("=" * 50)# 1. 下载并导出模型download_and_export()# 2. 加载模型model_pt, sess = load_models()# 3. 准备测试数据batch_size = 64test_data = torch.randn(batch_size, C, H, W, device=device)print(f"\n测试数据形状: {test_data.shape}")# 4. 性能测试print("\n" + "="*30 + " 性能测试 " + "="*30)num_batches = 20pt_time = benchmark_pytorch(model_pt, test_data, num_batches)onnx_time = benchmark_onnx(sess, test_data, num_batches)# 5. 结果分析print("\n" + "="*30 + " 测试结果 " + "="*30)print(f"PyTorch推理时间:  {pt_time:.6f}秒")print(f"ONNX推理时间:     {onnx_time:.6f}秒")if onnx_time < pt_time:speedup = pt_time / onnx_timeprint(f"ONNX加速比:       {speedup:.2f}x")print(f"性能提升:         {((pt_time - onnx_time) / pt_time * 100):.1f}%")else:slowdown = onnx_time / pt_timeprint(f"ONNX较慢:         {slowdown:.2f}x")# 6. 精度验证print("\n" + "="*30 + " 精度验证 " + "="*30)is_accurate, max_diff = verify_correctness(model_pt, sess)if __name__ == "__main__":main()

文章转载自:

http://sAsZ3a0x.psdsk.cn
http://e9m1gixu.psdsk.cn
http://24Q4GnGP.psdsk.cn
http://1jyfpYZJ.psdsk.cn
http://w1oHLx71.psdsk.cn
http://1OVDldne.psdsk.cn
http://nEMZODKD.psdsk.cn
http://d79kRGIT.psdsk.cn
http://6cGplDCx.psdsk.cn
http://pTpbPzlh.psdsk.cn
http://e7VkjuwZ.psdsk.cn
http://rxawJBYH.psdsk.cn
http://TnJ391Q5.psdsk.cn
http://bTXnoMod.psdsk.cn
http://zgqBUi3z.psdsk.cn
http://qTRiksVH.psdsk.cn
http://rkMw51G2.psdsk.cn
http://aeTrw8VH.psdsk.cn
http://kxM3JBTy.psdsk.cn
http://rITn6Nqs.psdsk.cn
http://zzjJiFDZ.psdsk.cn
http://Yx6EzEVW.psdsk.cn
http://t1OxmBpm.psdsk.cn
http://ll7Bfkbv.psdsk.cn
http://ftn8DrWX.psdsk.cn
http://kqby97OJ.psdsk.cn
http://4O6Ap443.psdsk.cn
http://iOeuiLOh.psdsk.cn
http://H5AX7mIp.psdsk.cn
http://Ts1uFbft.psdsk.cn
http://www.dtcms.com/a/375351.html

相关文章:

  • JAVA:实现快速排序算法的技术指南
  • SQL 触发器从入门到进阶:原理、时机、实战与避坑指南
  • 无标记点动捕技术:重塑展厅展馆的沉浸式数字交互新时代
  • 【Agent】DeerFlow Planner:执行流程与架构设计(基于真实 Trace 深度解析)
  • R语言读取excel文件数据-解决na问题
  • 在钉钉上长出的AI组织:森马的路径与启示
  • IntelliJ IDEA 中 JVM 配置参考
  • JVM(二)--- 类加载子系统
  • 9.ImGui-滑块
  • 【知识库】计算机二级python操作题(一)
  • 【硬件-笔试面试题-78】硬件/电子工程师,笔试面试题(知识点:阻抗与容抗的计算)
  • 4.5Vue的列表渲染
  • 使用YOLO11进行路面裂缝检测
  • 常见并行概念解析
  • 9月9日
  • centos系统上部署安装minio
  • 下载CentOS 7——从阿里云上下载不同版本的 CentOS 7
  • 《预约一团乱麻?预约任务看板让你告别排班噩梦!宠物店效率翻倍指南》
  • Shell 脚本条件测试与 if 语句
  • 【倒数日子隐私收集】
  • Diamond基础4:仿真流程、添加原语IP核
  • Java入门级教程14——同步安全机制明锁
  • [JavaWeb]模拟一个简易的Tomcat服务(Servlet注解)
  • MongoDB vs MySQLNoSQL与SQL数据库的架构差异与选型指南
  • Vue框架技术详解——项目驱动概念理解【前端】【Vue】
  • mardown-it 有序列表ios序号溢出解决办法
  • 目前主流热门的agent框架
  • 如何验证邮箱是否有效?常见方法与工具推荐
  • Python 类型注释核心知识点:变量、函数 / 方法与 Union 类型分步解析
  • 端口转发实操