当前位置: 首页 > news >正文

机器学习模型在C++平台的部署

一、概述

  机器学习模型的训练通常在Python环境下完成,而现实生产环境的复杂性和多样性使得模型的部署成为一个值得关注的重点。不同应用场景下有不同适应的实现方式,这里主要介绍通过一种通用中间格式——ONNX(Open Neural Network Exchange),来实现机器学习模型在C++平台的部署。

二、步骤

  s1. Python环境中安装onnxruntime、skl2onnx工具模块;

  s2. Python环境中训练机器学习模型;

  s3. 将训练好的模型保存为.onnx格式的模型文件;

  s4. C++环境中安装Microsoft.ML.OnnxRuntime程序包;
(Visual Studio 2022中可通过项目->管理NuGet程序包完成快捷安装)

  S5. C++环境中加载模型文件,完成功能开发。

三、示例

  使用 Python 训练一个线性回归模型并将其导出为 ONNX 格式的文件,在C++环境下完成对模型的部署和推理。

1.Python训练和导出

(环境:Python 3.11,scikit-learn 1.6.1,onnxruntime 1.22.0,skl2onnx 1.19.1)

import numpy as np
import onnxruntime as ort
from sklearn.datasets import make_regression
from sklearn.linear_model import LinearRegression
from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import FloatTensorType# 生成示例数据
X, y = make_regression(n_samples=100, n_features=5, random_state=42)# 训练线性回归模型
model = LinearRegression()
model.fit(X, y)# 定义输入格式
initial_type = [('input', FloatTensorType([None, 5]))]# 转换模型为 ONNX 格式
onnx_model = convert_sklearn(model, initial_types=initial_type)# 保存 ONNX 模型
with open("linear_regression.onnx", "wb") as f:f.write(onnx_model.SerializeToString())print("\n模型已保存为: linear_regression.onnx\n")# 测试导出的模型
ort_session = ort.InferenceSession("linear_regression.onnx")
input_name = ort_session.get_inputs()[0].name
output_name = ort_session.get_outputs()[0].name# 创建一个测试样本
test_input = np.array([0.1, 0.2, 0.3, 0.4, 0.5]).reshape(1,5).astype(np.float32)# 运行推理
results = ort_session.run([output_name], {input_name: test_input})print(f"测试输入: {test_input}")
print(f"预测结果: {results[0]}")

在这里插入图片描述

2. C++ 部署和推理

(环境:C++ 14,Microsoft.ML.OnnxRuntime 1.22.0)

#include <iostream>
#include <vector>
#include <string>
#include <memory>
#include <onnxruntime_cxx_api.h>int main() {// 初始化环境Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "ONNXExample");// 初始化会话选项Ort::SessionOptions session_options;session_options.SetIntraOpNumThreads(1);session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL);// 加载模型std::wstring model_path = L"linear_regression.onnx";Ort::Session session(env, model_path.c_str(), session_options);// 获取输入信息Ort::AllocatorWithDefaultOptions allocator;size_t num_inputs = session.GetInputCount();size_t num_outputs = session.GetOutputCount();// 假设只有一个输入和一个输出if (num_inputs != 1 || num_outputs != 1) {std::cerr << "模型必须有且仅有一个输入和一个输出" << std::endl;return 1;}// 获取输入名称、类型和形状std::string input_name = session.GetInputNameAllocated(0, allocator).get();Ort::TypeInfo input_type_info = session.GetInputTypeInfo(0);auto input_tensor_info = input_type_info.GetTensorTypeAndShapeInfo();ONNXTensorElementDataType input_type = input_tensor_info.GetElementType();std::vector<int64_t> input_dims = input_tensor_info.GetShape();// 获取输出名称std::string output_name = session.GetOutputNameAllocated(0, allocator).get();// 创建输入数据std::vector<float> input_data = { 0.1f, 0.2f, 0.3f, 0.4f, 0.5f };size_t input_size = 5;// 创建输入张量std::vector<int64_t> input_shape = { 1, static_cast<int64_t>(input_size) };auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);Ort::Value input_tensor = Ort::Value::CreateTensor<float>(memory_info, input_data.data(),input_data.size(), input_shape.data(), 2);// 验证输入张量是否为张量if (!input_tensor.IsTensor()) {std::cerr << "创建的输入不是张量类型" << std::endl;return 1;}// 运行模型std::vector<const char*> input_names = { input_name.c_str() };std::vector<const char*> output_names = { output_name.c_str() };std::vector<Ort::Value> outputs = session.Run(Ort::RunOptions{ nullptr },input_names.data(),&input_tensor,1,output_names.data(),1);// 获取输出结果float* output_data = outputs[0].GetTensorMutableData<float>();Ort::TensorTypeAndShapeInfo output_info = outputs[0].GetTensorTypeAndShapeInfo();std::vector<int64_t> output_dims = output_info.GetShape();// 输出结果std::cout << "输入数据: ";for (float val : input_data) {std::cout << val << " ";}std::cout << std::endl;std::cout << "预测结果: ";for (size_t i = 0; i < output_info.GetElementCount(); ++i) {std::cout << output_data[i] << " ";}std::cout << std::endl;return 0;
}

在这里插入图片描述



End.

http://www.dtcms.com/a/271046.html

相关文章:

  • 基于 Redis 实现高并发滑动窗口限流:Java实战与深度解析
  • 开始读 PostgreSQL 16 Administration Cookbook
  • 深度学习 最简单的神经网络 线性回归网络
  • ArtifactsBench: 弥合LLM 代码生成评估中的视觉交互差距
  • 论文解析篇 | YOLOv12:以注意力机制为核心的实时目标检测算法
  • 腾讯云COS,阿里云OSS对象存储服务-删除操作的响应码204
  • 汽车智能化2.0引爆「万亿蛋糕」,谁在改写游戏规则?
  • 通用游戏前端架构设计思考
  • VSCode配置Cline插件调用MCP服务实现任务自动化
  • 旅游管理实训室建设的关键要点探讨
  • 向量空间 线性代数
  • 软件测试偏技术方向学习路线是怎样的?
  • 安装nvm管理node.js,详细安装使用教程和详细命令
  • Spring Boot微服务中集成gRPC实践经验分享
  • 【每日算法】专题六_模拟
  • 全球发展币GDEV:从中国出发,走向全球的数字发展合作蓝图
  • 2 STM32单片机-蜂鸣器驱动
  • 【vLLM 学习】Eagle
  • oracle ocp题库有多少道题,以及题库背诵技巧
  • Context Engineering:从Prompt Engineering到上下文工程的演进
  • 破局电机制造四大痛点:MES与AI视觉的协同智造实践
  • 基于SD-WAN的管件制造数字化产线系统集成方案
  • 中山排气歧管批量自动化智能化3D尺寸测量及cav检测分析
  • 什么是幂等
  • clickhouse 各个引擎适用的场景
  • 飞算 JavaAI 智能编程助手 - 重塑编程新模态
  • ClickHouse 时间范围查询:精准筛选「本月数据」
  • tinyxml2 开源库与 VS2010 结合使用
  • LaCo: Large Language Model Pruning via Layer Collapse
  • Spring Boot 扩展点深度解析:设计思想、实现细节与最佳实践