
Converting a custom PyTorch operator to TensorRT

Exporting the ONNX model from PyTorch

Suppose we have the following model:

import torch
import torch.nn as nn

class MYPLUGINImpl(torch.autograd.Function):
    @staticmethod
    def symbolic(g, x, p):
        return g.op("MYPLUGIN", x, p,
                    g.op("Constant", value_t=torch.tensor([3, 2, 1], dtype=torch.int32)),
                    attr1_s="this is a string attribute",
                    attr2_i=[1, 2],
                    attr3_f=222)

    @staticmethod
    def forward(ctx, x, a):
        return x + a

class MYPLUGIN(nn.Module):
    def __init__(self, n):
        super().__init__()
        self.param = nn.parameter.Parameter(torch.arange(n).float())

    def forward(self, x, a):
        return MYPLUGINImpl.apply(x, a)

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 1, 3, padding=1)
        self.conv.weight.data.fill_(1)
        self.conv.bias.data.fill_(0)
        self.myplugin = MYPLUGIN(3)

    def forward(self, x, a):
        x = self.conv(x)
        x = self.myplugin(x, a)
        return x

model = Model().eval()
input = torch.tensor([
    [
        [1, 1, 1],
        [1, 1, 1],
        [1, 1, 1],
    ],
], dtype=torch.float32).view(1, 1, 3, 3)
a = torch.tensor(2, dtype=torch.int32)
output = model(input, a)
print(f"inference output = {output}")

torch.onnx.export(
    model,
    (input, a),              # args passed to the model; must be a tuple
    "myplugin.onnx",         # output file path
    input_names=["image"],   # names for the input and output nodes, for easier inspection later
    output_names=["output"],
    opset_version=13,        # which opset the operators are exported with (corresponds to symbolic_opset13)
    operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK,
    # verbose=True,          # print detailed export information
    # Mark batch, height and width as dynamic dimensions (shown as -1 in the ONNX graph).
    # Usually only batch is made dynamic; avoid making the others dynamic.
    # dynamic_axes={
    #     "image": {0: "batch", 2: "height", 3: "width"},
    #     "output": {0: "batch", 2: "height", 3: "width"},
    # },
)

Program output:

inference output = 
tensor([[[[ 6.,  8.,  6.],
          [ 8., 11.,  8.],
          [ 6.,  8.,  6.]]]], grad_fn=<MYPLUGINImplBackward>)
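As a sanity check: convolving the 3×3 all-ones input with the all-ones 3×3 kernel (padding 1, bias 0) gives 4 at the corners, 6 at the edges, and 9 in the center; adding a = 2 yields the 6/8/11 values above.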

The structure of the exported ONNX model is shown below:
(figure: exported ONNX model structure)
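If a graph viewer such as Netron is not at hand, a minimal sketch with the onnx Python package (assuming it is installed) lists the custom node and the attributes that symbolic() attached:

import onnx

model = onnx.load("myplugin.onnx")
for node in model.graph.node:
    if node.op_type == "MYPLUGIN":
        print(node.input, node.output)
        for attr in node.attribute:
            print(attr.name)  # attr1, attr2, attr3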

Writing the TensorRT custom operator

This walkthrough uses TensorRT-10.6.0.26. Since TensorRT is only partially open source, first download the TensorRT-10.6.0.26 binary package from https://developer.nvidia.com/tensorrt/download/10x, then download the matching source code from https://github.com/NVIDIA/TensorRT/tree/v10.6.0.
Create a new myPlugin folder under TensorRT/plugin and add the following files:
myPlugin.h

#ifndef TRT_MYPLUGIN_H
#define TRT_MYPLUGIN_H

#include "NvInfer.h"
#include "NvInferPlugin.h"
#include "common/plugin.h"
#include <string>
#include <vector>

namespace nvinfer1
{
namespace plugin
{
class myPlugin : public nvinfer1::IPluginV2DynamicExt
{
public:
    // Constructor used at engine-build time: receives the layer name and attributes.
    myPlugin(const std::string name, const std::string attr1, float attr3);
    // Constructor used at inference time: receives the layer name and the serialized engine data.
    myPlugin(const std::string name, const void* data, size_t length);

    int getNbOutputs() const noexcept override;

    nvinfer1::DataType getOutputDataType(int32_t index,
        nvinfer1::DataType const* inputTypes, int32_t nbInputs) const noexcept override
    {
        return inputTypes[0];
    }

    nvinfer1::DimsExprs getOutputDimensions(int32_t outputIndex,
        const nvinfer1::DimsExprs* inputs, int32_t nbInputs,
        nvinfer1::IExprBuilder& exprBuilder) noexcept override;

    int initialize() noexcept override;
    void terminate() noexcept override;

    size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int32_t nbInputs,
        const nvinfer1::PluginTensorDesc* outputs, int32_t nbOutputs) const noexcept override
    {
        return 0;
    }

    int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc,
        const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override;

    size_t getSerializationSize() const noexcept override;
    void serialize(void* buffer) const noexcept override;

    void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int32_t nbInputs,
        const nvinfer1::DynamicPluginTensorDesc* out, int32_t nbOutputs) noexcept override;

    bool supportsFormatCombination(int32_t pos, const nvinfer1::PluginTensorDesc* inOut,
        int32_t nbInputs, int32_t nbOutputs) noexcept override;

    const char* getPluginType() const noexcept override;
    const char* getPluginVersion() const noexcept override;
    void destroy() noexcept override;
    nvinfer1::IPluginV2DynamicExt* clone() const noexcept override;
    void setPluginNamespace(nvinfer1::AsciiChar const* pluginNamespace) noexcept override;
    const char* getPluginNamespace() const noexcept override;

private:
    const std::string mLayerName;
    std::string mattr1;
    float mattr3;
    size_t mInputVolume;
    std::string mNamespace;
};

class myPluginCreator : public nvinfer1::IPluginCreator
{
public:
    myPluginCreator();

    const char* getPluginName() const noexcept override;
    const char* getPluginVersion() const noexcept override;
    const nvinfer1::PluginFieldCollection* getFieldNames() noexcept override;
    nvinfer1::IPluginV2* createPlugin(nvinfer1::AsciiChar const* name,
        nvinfer1::PluginFieldCollection const* fc) noexcept override;
    nvinfer1::IPluginV2* deserializePlugin(nvinfer1::AsciiChar const* name,
        void const* serialData, size_t serialLength) noexcept override;
    void setPluginNamespace(nvinfer1::AsciiChar const* pluginNamespace) noexcept override;
    const char* getPluginNamespace() const noexcept override;

private:
    static nvinfer1::PluginFieldCollection mfc;
    static std::vector<nvinfer1::PluginField> mPluginAttributes;
    std::string mNamespace;
};
} // namespace plugin
} // namespace nvinfer1

#endif // TRT_MYPLUGIN_H
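Taken together, the header declares both halves of the plugin contract: myPlugin implements the layer itself (shape and type inference, enqueue, serialization), while myPluginCreator tells TensorRT how to instantiate it, either from the parsed ONNX attributes (createPlugin) or from a serialized engine (deserializePlugin).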

myPlugin.cpp

#include "myPlugin.h"
#include <NvInfer.h>
#include <cstring>
#include <vector>
#include <cassert>using namespace nvinfer1;
using namespace nvinfer1::pluginInternal;
using nvinfer1::plugin::myPlugin;
using nvinfer1::plugin::myPluginCreator;namespace nvinfer1
{
namespace plugin
{
void myselu_inference(const float* x, const int* a, float* output, int n, cudaStream_t stream);// 静态类字段的初始化
nvinfer1::PluginFieldCollection myPluginCreator::mfc{};
std::vector<nvinfer1::PluginField> myPluginCreator::mPluginAttributes;// 用于序列化插件的Helper function
template <typename T>
void writeToBuffer(char*& buffer, const T& val) 
{          *reinterpret_cast<T*>(buffer) = val;buffer += sizeof(T);
}// 用于反序列化插件的Helper function
template <typename T>
T readFromBuffer(char const*& buffer) 
{          T val = *reinterpret_cast<const T*>(buffer);buffer += sizeof(T);return val;
}// 定义插件类MYPlugin
myPlugin::myPlugin(const std::string name, const std::string attr1, float attr3):mLayerName(name), mattr1(attr1), mattr3(attr3)
{std::cout<<"myPlugin"<<std::endl;
};myPlugin::myPlugin(const std::string name, const void* data, size_t length) :mLayerName(name)
{std::cout<<"myPlugin()"<<std::endl;       // Deserialize in the same order as serializationchar const* d = static_cast<char const*>(data);char const* a = d;int nstr = readFromBuffer<int>(d);mattr1 = std::string(d, d + nstr);d += nstr;mattr3 = readFromBuffer<float>(d);assert(d == (a + length));
};char const* myPlugin::getPluginType() const noexcept
{   std::cout<<"getPluginType"<<std::endl;         return "MYPLUGIN";
}char const* myPlugin::getPluginVersion() const noexcept
{std::cout<<"getPluginVersion"<<std::endl;             return "1";
}int myPlugin::getNbOutputs() const noexcept 
{std::cout<<"getNbOutputs"<<std::endl;          return 1;
}// 获取该层的输出维度是多少
nvinfer1::DimsExprs myPlugin::getOutputDimensions(int32_t outputIndex,const nvinfer1::DimsExprs* inputs, int32_t nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept 
{std::cout<<"getOutputDimensions"<<std::endl;         //不改变输入尺寸,所以输出尺寸将与输入尺寸相同return inputs[0];
}int myPlugin::initialize() noexcept
{std::cout<<"initialize"<<std::endl;          return 0;
}int myPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc,const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept
{std::cout<<"enqueue"<<std::endl;   void* output = outputs[0];size_t volume = 1;for (int i = 0; i < inputDesc->dims.nbDims; ++i){            volume *= inputDesc->dims.d[i];}mInputVolume = volume;myselu_inference(static_cast<const float*>(inputs[0]),static_cast<const int*>(inputs[1]),static_cast<float*>(output),mInputVolume,stream);return 0;
}size_t myPlugin::getSerializationSize() const noexcept
{std::cout<<"getSerializationSize"<<std::endl;               return sizeof(int) + mattr1.size() + sizeof(mattr3);
}// 该层的参数序列化储存为trtmodel文件
void myPlugin::serialize(void* buffer)  const noexcept
{std::cout<<"serialize"<<std::endl;                  char* d = static_cast<char*>(buffer);char const* a = d;int nstr = mattr1.size();writeToBuffer(d, nstr);memcpy(d, mattr1.data(), nstr);d += nstr;writeToBuffer(d, mattr3);assert(d == a + getSerializationSize());
}// 判断该插件所支持的数据格式和类型
bool myPlugin::supportsFormatCombination(int32_t pos, const nvinfer1::PluginTensorDesc* inOut, int32_t nbInputs,int32_t nbOutputs) noexcept
{std::cout<<"supportsFormatCombination"<<std::endl;     PLUGIN_ASSERT(pos < nbInputs + nbOutputs);if (pos == 0){return (inOut[pos].type == nvinfer1::DataType::kFLOAT) && (inOut[pos].format == nvinfer1::PluginFormat::kLINEAR);}else if (pos == 1){return (inOut[pos].type == nvinfer1::DataType::kINT32) && (inOut[pos].format == nvinfer1::PluginFormat::kLINEAR);}else if (pos == 2){return (inOut[pos].type == nvinfer1::DataType::kINT32) && (inOut[pos].format == nvinfer1::PluginFormat::kLINEAR);}return true;   
}void myPlugin::terminate() noexcept { }void myPlugin::destroy() noexcept
{            // This gets called when the network containing plugin is destroyeddelete this;
}// 配置插件格式:目前这个层所采用的数据格式和类型
void myPlugin::configurePlugin(const  nvinfer1::DynamicPluginTensorDesc* in, int32_t nbInputs,const  nvinfer1::DynamicPluginTensorDesc* out, int32_t nbOutputs) noexcept
{std::cout<<"configurePlugin"<<std::endl;           try {PLUGIN_ASSERT(nbInputs == 3 && nbOutputs == 1); // 确认3个输入和1个输出} catch (std::exception const& e) {caughtError(e);}
}// 克隆插件
nvinfer1::IPluginV2DynamicExt* myPlugin::clone() const noexcept
{std::cout<<"clone"<<std::endl;      auto plugin = new myPlugin(mLayerName, mattr1, mattr3);plugin->setPluginNamespace(mNamespace.c_str());return plugin;
}void myPlugin::setPluginNamespace(char const* libNamespace) noexcept
{std::cout<<"setPluginNamespace"<<std::endl;                 mNamespace = libNamespace;
}char const* myPlugin::getPluginNamespace() const noexcept
{std::cout<<"getPluginNamespace"<<std::endl;               return mNamespace.c_str();
}// 插件创建器
myPluginCreator::myPluginCreator()
{std::cout<<"myPluginCreator"<<std::endl;               // 描述myPlugin的必要PluginField参数mPluginAttributes.emplace_back(nvinfer1::PluginField("attr1", nullptr, nvinfer1::PluginFieldType::kCHAR));mPluginAttributes.emplace_back(nvinfer1::PluginField("attr3", nullptr, nvinfer1::PluginFieldType::kFLOAT32));// 收集PluginField的参数mfc.nbFields = mPluginAttributes.size();mfc.fields = mPluginAttributes.data();
}char const* myPluginCreator::getPluginName() const noexcept
{ std::cout<<"getPluginName"<<std::endl;             return "MYPLUGIN";
}char const* myPluginCreator::getPluginVersion() const noexcept
{std::cout<<"getPluginVersion"<<std::endl;                return "1";
}const nvinfer1::PluginFieldCollection* myPluginCreator::getFieldNames() noexcept
{std::cout<<"getFieldNames"<<std::endl;               return &mfc;
}// 创建plugin
nvinfer1::IPluginV2* myPluginCreator::createPlugin(nvinfer1::AsciiChar const* name,nvinfer1::PluginFieldCollection const* fc) noexcept
{std::cout<<"createPlugin"<<std::endl;               std::string attr1;float attr3;const nvinfer1::PluginField* fields = fc->fields;// Parse fields from PluginFieldCollectionfor (int i = 0; i < fc->nbFields; ++i){          if (strcmp(fields[i].name, "attr1")==0) {       assert(fields[i].type == nvinfer1::PluginFieldType::kCHAR);auto cp = static_cast<char const*>(fields[i].data);attr1 = std::string(cp, cp + fields[i].length);}else if (strcmp(fields[i].name, "attr3") == 0) {        assert(fields[i].type == nvinfer1::PluginFieldType::kFLOAT32);attr3 = *(static_cast<const float*>(fields[i].data));}}return new myPlugin(name, attr1, attr3);
}// 反序列化插件参数进行创建
nvinfer1::IPluginV2* myPluginCreator::deserializePlugin(nvinfer1::AsciiChar const* name,void const* serialData, size_t serialLength)noexcept
{   std::cout<<"deserializePlugin"<<std::endl;      // This object will be deleted when the network is destroyed, which will// call myPlugin::destroy()return new myPlugin(name, serialData, serialLength);
}void myPluginCreator::setPluginNamespace(char const* libNamespace) noexcept
{  std::cout<<"setPluginNamespace"<<std::endl;        mNamespace = libNamespace;
}char const* myPluginCreator::getPluginNamespace() const noexcept
{    std::cout<<"getPluginNamespace"<<std::endl;      return mNamespace.c_str();
}
}
}
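serialize(), getSerializationSize() and the deserializing constructor must agree on a single byte layout: a 4-byte int holding the length of attr1, the raw attr1 bytes, then a 4-byte float for attr3. As a sanity check, here is a minimal Python sketch of the same layout (hypothetical helper names; "<" assumes the little-endian layout the C++ code produces on x86):

import struct

def pack_plugin_blob(attr1: str, attr3: float) -> bytes:
    # int32 length | attr1 bytes | float32 attr3 -- mirrors myPlugin::serialize
    data = attr1.encode()
    return struct.pack("<i", len(data)) + data + struct.pack("<f", attr3)

def unpack_plugin_blob(buf: bytes):
    # mirrors the deserializing constructor
    (nstr,) = struct.unpack_from("<i", buf, 0)
    attr1 = buf[4:4 + nstr].decode()
    (attr3,) = struct.unpack_from("<f", buf, 4 + nstr)
    return attr1, attr3

blob = pack_plugin_blob("hello", 222.0)
assert unpack_plugin_blob(blob) == ("hello", 222.0)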

myPlugin.cu

#include "NvInfer.h"
#include <cuda_runtime.h>namespace nvinfer1
{
namespace plugin
{
static __global__ void myselu_kernel(const float* x, const int* a, float* output, int n)
{int position = threadIdx.x + blockDim.x*blockIdx.x;if (position >= n) return;output[position] = x[position] + a[0];
}void myselu_inference(const float* x, const int* a, float* output, int n, cudaStream_t stream)
{     const int nthreads = 512;int block_size = n > nthreads ? nthreads : n;int grid_size = (n + block_size - 1) / block_size;myselu_kernel<<<grid_size, block_size, 0, stream>>>(x, a, output, n);
}
}
}
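For the 1×1×3×3 test tensor n = 9, so the launch is a single block of 9 threads; for larger tensors the ceiling division (n + block_size - 1) / block_size guarantees grid_size * block_size >= n, and the position >= n guard in the kernel discards the excess threads.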

CMakeLists.txt

file(GLOB SRCS *.cpp)
set(PLUGIN_SOURCES ${PLUGIN_SOURCES} ${SRCS})
set(PLUGIN_SOURCES ${PLUGIN_SOURCES} PARENT_SCOPE)
file(GLOB CU_SRCS *.cu)
set(PLUGIN_CU_SOURCES ${PLUGIN_CU_SOURCES} ${CU_SRCS})
set(PLUGIN_CU_SOURCES ${PLUGIN_CU_SOURCES} PARENT_SCOPE)
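The PARENT_SCOPE assignments propagate these source lists up to TensorRT/plugin/CMakeLists.txt, so the new .cpp and .cu files are compiled into libnvinfer_plugin alongside the built-in plugins.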

Add the following at the top of TensorRT/plugin/inferPlugin.cpp:

#include "myPlugin/myPlugin.h"

and register the plugin creator inside the initLibNvInferPlugins function:

initializePlugin<nvinfer1::plugin::myPluginCreator>(logger, libNamespace);

In TensorRT/plugin/CMakeLists.txt, add the following entry to set(PLUGIN_LISTS ...):

myPlugin

In TensorRT/CMakeLists.txt, set TRT_LIB_DIR and TRT_OUT_DIR, then rebuild TensorRT.
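After rebuilding, make sure the freshly built libnvinfer_plugin is the one that gets loaded. A quick sanity check from Python (a sketch; get_plugin_creator is deprecated in TensorRT 10 but still available) asks the registry for the creator by the name and version the plugin reports:

import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
trt.init_libnvinfer_plugins(logger, "")  # calls initLibNvInferPlugins, registering MYPLUGIN
registry = trt.get_plugin_registry()
creator = registry.get_plugin_creator("MYPLUGIN", "1")
print("MYPLUGIN registered:", creator is not None)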

TensorRT inference test

Run the following command to convert the ONNX model into an engine:

TensorRT-10.6.0.26/bin/trtexec --onnx=myplugin.onnx --saveEngine=myplugin.engine
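trtexec links against the rebuilt libnvinfer_plugin, so the ONNX parser resolves the unknown MYPLUGIN node through the creator registered under the same name and version. The same conversion can also be scripted; a minimal sketch with the TensorRT 10 Python API, assuming the rebuilt plugin library is the one on the loader path:

import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
trt.init_libnvinfer_plugins(logger, "")  # register MYPLUGIN before parsing

builder = trt.Builder(logger)
network = builder.create_network(0)
parser = trt.OnnxParser(network, logger)
with open("myplugin.onnx", "rb") as f:
    assert parser.parse(f.read()), parser.get_error(0)

config = builder.create_builder_config()
engine_bytes = builder.build_serialized_network(network, config)
with open("myplugin.engine", "wb") as f:
    f.write(engine_bytes)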

Write the Python inference script:

import numpy as np
import tensorrt as trt
import common  # common.py from the TensorRT Python samples

logger = trt.Logger(trt.Logger.WARNING)
trt.init_libnvinfer_plugins(logger, "")
with open("myplugin.engine", "rb") as f, trt.Runtime(logger) as runtime:
    engine = runtime.deserialize_cuda_engine(f.read())
context = engine.create_execution_context()
inputs, outputs, bindings, stream = common.allocate_buffers(engine)

input = np.ones((3, 3))
a = 2
np.copyto(inputs[0].host, input.ravel())
np.copyto(inputs[1].host, a)
output = common.do_inference(context, engine=engine, bindings=bindings,
                             inputs=inputs, outputs=outputs, stream=stream)
print(output)
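If the plugin works, the printed result should match the PyTorch output above, as a flattened array: [6., 8., 6., 8., 11., 8., 6., 8., 6.].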

