Converting a PyTorch custom operator to TensorRT
Exporting the ONNX model from PyTorch
Suppose we have the following model:
import torch
import torch.nn as nn


class MYPLUGINImpl(torch.autograd.Function):
    @staticmethod
    def symbolic(g, x, p):
        return g.op("MYPLUGIN", x, p,
                    g.op("Constant", value_t=torch.tensor([3, 2, 1], dtype=torch.int32)),
                    attr1_s="this is a string attribute",
                    attr2_i=[1, 2],
                    attr3_f=222)

    @staticmethod
    def forward(ctx, x, a):
        return x + a


class MYPLUGIN(nn.Module):
    def __init__(self, n):
        super().__init__()
        self.param = nn.parameter.Parameter(torch.arange(n).float())

    def forward(self, x, a):
        return MYPLUGINImpl.apply(x, a)


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 1, 3, padding=1)
        self.conv.weight.data.fill_(1)
        self.conv.bias.data.fill_(0)
        self.myplugin = MYPLUGIN(3)

    def forward(self, x, a):
        x = self.conv(x)
        x = self.myplugin(x, a)
        return x


model = Model().eval()
input = torch.tensor([
    [[1, 1, 1],
     [1, 1, 1],
     [1, 1, 1]],
], dtype=torch.float32).view(1, 1, 3, 3)
a = torch.tensor(2, dtype=torch.int32)
output = model(input, a)
print(f"inference output = {output}")

torch.onnx.export(
    model,
    (input, a),         # args fed to the model; must be a tuple, hence the parentheses
    "myplugin.onnx",    # path of the saved file
    input_names=["image"],   # names for the input and output nodes, for easier inspection later
    output_names=["output"],
    opset_version=13,   # which opset the operators are exported with, i.e. symbolic_opset13
    operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK,
    # verbose=True,     # print detailed information
    # Marks the batch, height and width dimensions as dynamic (written as -1 in the ONNX graph).
    # Usually only batch is made dynamic; avoid dynamic sizes for the other dimensions.
    # dynamic_axes={
    #     "image": {0: "batch", 2: "height", 3: "width"},
    #     "output": {0: "batch", 2: "height", 3: "width"},
    # },
)
Program output:
inference output = tensor([[[[ 6.,  8.,  6.],
          [ 8., 11.,  8.],
          [ 6.,  8.,  6.]]]], grad_fn=<MYPLUGINImplBackward>)
This matches a hand calculation: with an all-ones 3x3 kernel, zero bias and padding 1, each output element is the sum of its padded 3x3 neighborhood of the all-ones input, and the plugin then adds a = 2.
The exported ONNX graph is: the image input passes through the Conv node into the custom MYPLUGIN node, which takes the Conv output, the second graph input a, and the Constant tensor [3, 2, 1] as its three inputs and carries the attr1/attr2/attr3 attributes; its output is the graph output.
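To confirm what actually got exported, a quick inspection with the onnx package (a minimal sketch, not part of the original workflow) prints each node and its attributes:

import onnx

m = onnx.load("myplugin.onnx")
for node in m.graph.node:
    # expect a Conv node followed by the custom MYPLUGIN node
    print(node.op_type, list(node.input), "->", list(node.output))
    for attr in node.attribute:
        print("    attribute:", attr.name)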
Writing the TensorRT custom operator
This walkthrough uses TensorRT-10.6.0.26. Since TensorRT is only partially open source, first download the TensorRT-10.6.0.26 library from https://developer.nvidia.com/tensorrt/download/10x, then download the source code from https://github.com/NVIDIA/TensorRT/tree/v10.6.0.
Create a myPlugin folder under TensorRT/plugin and add the following files:
myPlugin.h
#ifndef TRT_MYPLUGIN_H
#define TRT_MYPLUGIN_H

#include "NvInfer.h"
#include "NvInferPlugin.h"
#include "common/plugin.h"
#include <string>
#include <vector>

namespace nvinfer1
{
namespace plugin
{
class myPlugin : public nvinfer1::IPluginV2DynamicExt
{
public:
    // Build-time constructor: receives the layer name and the op attributes
    myPlugin(const std::string name, const std::string attr1, float attr3);
    // Inference-time constructor: receives the layer name and the deserialized engine data
    myPlugin(const std::string name, const void* data, size_t length);

    int getNbOutputs() const noexcept override;

    nvinfer1::DataType getOutputDataType(
        int32_t index, nvinfer1::DataType const* inputTypes, int32_t nbInputs) const noexcept override
    {
        return inputTypes[0];
    }

    nvinfer1::DimsExprs getOutputDimensions(int32_t outputIndex, const nvinfer1::DimsExprs* inputs,
        int32_t nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept override;

    int initialize() noexcept override;
    void terminate() noexcept override;

    size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int32_t nbInputs,
        const nvinfer1::PluginTensorDesc* outputs, int32_t nbOutputs) const noexcept override
    {
        return 0;
    }

    int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc,
        const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override;

    size_t getSerializationSize() const noexcept override;
    void serialize(void* buffer) const noexcept override;

    void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int32_t nbInputs,
        const nvinfer1::DynamicPluginTensorDesc* out, int32_t nbOutputs) noexcept override;

    bool supportsFormatCombination(int32_t pos, const nvinfer1::PluginTensorDesc* inOut,
        int32_t nbInputs, int32_t nbOutputs) noexcept override;

    const char* getPluginType() const noexcept override;
    const char* getPluginVersion() const noexcept override;
    void destroy() noexcept override;
    nvinfer1::IPluginV2DynamicExt* clone() const noexcept override;
    void setPluginNamespace(nvinfer1::AsciiChar const* pluginNamespace) noexcept override;
    const char* getPluginNamespace() const noexcept override;

private:
    const std::string mLayerName;
    std::string mattr1;
    float mattr3;
    size_t mInputVolume;
    std::string mNamespace;
};

class myPluginCreator : public nvinfer1::IPluginCreator
{
public:
    myPluginCreator();
    const char* getPluginName() const noexcept override;
    const char* getPluginVersion() const noexcept override;
    const nvinfer1::PluginFieldCollection* getFieldNames() noexcept override;
    nvinfer1::IPluginV2* createPlugin(
        nvinfer1::AsciiChar const* name, nvinfer1::PluginFieldCollection const* fc) noexcept override;
    nvinfer1::IPluginV2* deserializePlugin(
        nvinfer1::AsciiChar const* name, void const* serialData, size_t serialLength) noexcept override;
    void setPluginNamespace(nvinfer1::AsciiChar const* pluginNamespace) noexcept override;
    const char* getPluginNamespace() const noexcept override;

private:
    static nvinfer1::PluginFieldCollection mfc;
    static std::vector<nvinfer1::PluginField> mPluginAttributes;
    std::string mNamespace;
};
} // namespace plugin
} // namespace nvinfer1

#endif // TRT_MYPLUGIN_H
myPlugin.cpp
#include "myPlugin.h"
#include <NvInfer.h>
#include <cstring>
#include <vector>
#include <cassert>
#include <iostream>

using namespace nvinfer1;
using namespace nvinfer1::pluginInternal;
using nvinfer1::plugin::myPlugin;
using nvinfer1::plugin::myPluginCreator;

namespace nvinfer1
{
namespace plugin
{
// CUDA launcher implemented in myPlugin.cu
void myselu_inference(const float* x, const int* a, float* output, int n, cudaStream_t stream);

// Initialization of the creator's static fields
nvinfer1::PluginFieldCollection myPluginCreator::mfc{};
std::vector<nvinfer1::PluginField> myPluginCreator::mPluginAttributes;

// Helper function for serializing the plugin
template <typename T>
void writeToBuffer(char*& buffer, const T& val)
{
    *reinterpret_cast<T*>(buffer) = val;
    buffer += sizeof(T);
}

// Helper function for deserializing the plugin
template <typename T>
T readFromBuffer(char const*& buffer)
{
    T val = *reinterpret_cast<const T*>(buffer);
    buffer += sizeof(T);
    return val;
}

// The plugin class myPlugin
myPlugin::myPlugin(const std::string name, const std::string attr1, float attr3)
    : mLayerName(name), mattr1(attr1), mattr3(attr3)
{
    std::cout << "myPlugin" << std::endl;
}

myPlugin::myPlugin(const std::string name, const void* data, size_t length)
    : mLayerName(name)
{
    std::cout << "myPlugin()" << std::endl;
    // Deserialize in the same order as serialization
    char const* d = static_cast<char const*>(data);
    char const* a = d;
    int nstr = readFromBuffer<int>(d);
    mattr1 = std::string(d, d + nstr);
    d += nstr;
    mattr3 = readFromBuffer<float>(d);
    assert(d == (a + length));
}

char const* myPlugin::getPluginType() const noexcept
{
    std::cout << "getPluginType" << std::endl;
    return "MYPLUGIN";
}

char const* myPlugin::getPluginVersion() const noexcept
{
    std::cout << "getPluginVersion" << std::endl;
    return "1";
}

int myPlugin::getNbOutputs() const noexcept
{
    std::cout << "getNbOutputs" << std::endl;
    return 1;
}

// Compute the output dimensions of this layer
nvinfer1::DimsExprs myPlugin::getOutputDimensions(int32_t outputIndex, const nvinfer1::DimsExprs* inputs,
    int32_t nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept
{
    std::cout << "getOutputDimensions" << std::endl;
    // The op does not change the input shape, so the output shape equals the input shape
    return inputs[0];
}

int myPlugin::initialize() noexcept
{
    std::cout << "initialize" << std::endl;
    return 0;
}

int myPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc,
    const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept
{
    std::cout << "enqueue" << std::endl;
    void* output = outputs[0];
    size_t volume = 1;
    for (int i = 0; i < inputDesc->dims.nbDims; ++i)
    {
        volume *= inputDesc->dims.d[i];
    }
    mInputVolume = volume;
    myselu_inference(static_cast<const float*>(inputs[0]), static_cast<const int*>(inputs[1]),
        static_cast<float*>(output), mInputVolume, stream);
    return 0;
}

size_t myPlugin::getSerializationSize() const noexcept
{
    std::cout << "getSerializationSize" << std::endl;
    return sizeof(int) + mattr1.size() + sizeof(mattr3);
}

// Serialize this layer's parameters into the trtmodel file
void myPlugin::serialize(void* buffer) const noexcept
{
    std::cout << "serialize" << std::endl;
    char* d = static_cast<char*>(buffer);
    char const* a = d;
    int nstr = mattr1.size();
    writeToBuffer(d, nstr);
    memcpy(d, mattr1.data(), nstr);
    d += nstr;
    writeToBuffer(d, mattr3);
    assert(d == a + getSerializationSize());
}

// Which data formats and types this plugin supports
bool myPlugin::supportsFormatCombination(
    int32_t pos, const nvinfer1::PluginTensorDesc* inOut, int32_t nbInputs, int32_t nbOutputs) noexcept
{
    std::cout << "supportsFormatCombination" << std::endl;
    PLUGIN_ASSERT(pos < nbInputs + nbOutputs);
    if (pos == 0)
    {
        // input x: float32, linear layout
        return (inOut[pos].type == nvinfer1::DataType::kFLOAT) && (inOut[pos].format == nvinfer1::PluginFormat::kLINEAR);
    }
    else if (pos == 1)
    {
        // input a: int32, linear layout
        return (inOut[pos].type == nvinfer1::DataType::kINT32) && (inOut[pos].format == nvinfer1::PluginFormat::kLINEAR);
    }
    else if (pos == 2)
    {
        // the Constant input: int32, linear layout
        return (inOut[pos].type == nvinfer1::DataType::kINT32) && (inOut[pos].format == nvinfer1::PluginFormat::kLINEAR);
    }
    // remaining position (the output)
    return true;
}

void myPlugin::terminate() noexcept {}

void myPlugin::destroy() noexcept
{
    // This gets called when the network containing the plugin is destroyed
    delete this;
}

// Configure the plugin: the data formats and types this layer will run with
void myPlugin::configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int32_t nbInputs,
    const nvinfer1::DynamicPluginTensorDesc* out, int32_t nbOutputs) noexcept
{
    std::cout << "configurePlugin" << std::endl;
    try
    {
        PLUGIN_ASSERT(nbInputs == 3 && nbOutputs == 1); // expect 3 inputs and 1 output
    }
    catch (std::exception const& e)
    {
        caughtError(e);
    }
}

// Clone the plugin
nvinfer1::IPluginV2DynamicExt* myPlugin::clone() const noexcept
{
    std::cout << "clone" << std::endl;
    auto plugin = new myPlugin(mLayerName, mattr1, mattr3);
    plugin->setPluginNamespace(mNamespace.c_str());
    return plugin;
}

void myPlugin::setPluginNamespace(char const* libNamespace) noexcept
{
    std::cout << "setPluginNamespace" << std::endl;
    mNamespace = libNamespace;
}

char const* myPlugin::getPluginNamespace() const noexcept
{
    std::cout << "getPluginNamespace" << std::endl;
    return mNamespace.c_str();
}

// The plugin creator
myPluginCreator::myPluginCreator()
{
    std::cout << "myPluginCreator" << std::endl;
    // Describe the PluginField parameters that myPlugin requires
    mPluginAttributes.emplace_back(nvinfer1::PluginField("attr1", nullptr, nvinfer1::PluginFieldType::kCHAR));
    mPluginAttributes.emplace_back(nvinfer1::PluginField("attr3", nullptr, nvinfer1::PluginFieldType::kFLOAT32));
    // Collect the PluginField parameters
    mfc.nbFields = mPluginAttributes.size();
    mfc.fields = mPluginAttributes.data();
}

char const* myPluginCreator::getPluginName() const noexcept
{
    std::cout << "getPluginName" << std::endl;
    return "MYPLUGIN";
}

char const* myPluginCreator::getPluginVersion() const noexcept
{
    std::cout << "getPluginVersion" << std::endl;
    return "1";
}

const nvinfer1::PluginFieldCollection* myPluginCreator::getFieldNames() noexcept
{
    std::cout << "getFieldNames" << std::endl;
    return &mfc;
}

// Create the plugin
nvinfer1::IPluginV2* myPluginCreator::createPlugin(
    nvinfer1::AsciiChar const* name, nvinfer1::PluginFieldCollection const* fc) noexcept
{
    std::cout << "createPlugin" << std::endl;
    std::string attr1;
    float attr3 = 0.0f;
    const nvinfer1::PluginField* fields = fc->fields;
    // Parse fields from the PluginFieldCollection
    for (int i = 0; i < fc->nbFields; ++i)
    {
        if (strcmp(fields[i].name, "attr1") == 0)
        {
            assert(fields[i].type == nvinfer1::PluginFieldType::kCHAR);
            auto cp = static_cast<char const*>(fields[i].data);
            attr1 = std::string(cp, cp + fields[i].length);
        }
        else if (strcmp(fields[i].name, "attr3") == 0)
        {
            assert(fields[i].type == nvinfer1::PluginFieldType::kFLOAT32);
            attr3 = *(static_cast<const float*>(fields[i].data));
        }
    }
    return new myPlugin(name, attr1, attr3);
}

// Create the plugin by deserializing its parameters
nvinfer1::IPluginV2* myPluginCreator::deserializePlugin(
    nvinfer1::AsciiChar const* name, void const* serialData, size_t serialLength) noexcept
{
    std::cout << "deserializePlugin" << std::endl;
    // This object will be deleted when the network is destroyed, which will
    // call myPlugin::destroy()
    return new myPlugin(name, serialData, serialLength);
}

void myPluginCreator::setPluginNamespace(char const* libNamespace) noexcept
{
    std::cout << "setPluginNamespace" << std::endl;
    mNamespace = libNamespace;
}

char const* myPluginCreator::getPluginNamespace() const noexcept
{
    std::cout << "getPluginNamespace" << std::endl;
    return mNamespace.c_str();
}
} // namespace plugin
} // namespace nvinfer1
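getSerializationSize() and serialize() define the plugin's on-disk layout: a native int32 holding the string length, the raw attr1 bytes, then a float32 for attr3. A Python sketch of the same layout (illustrative only; it assumes native byte order, as the C++ writeToBuffer does):

import struct

def serialize_plugin(attr1: bytes, attr3: float) -> bytes:
    # int length + raw string bytes + float, in the same order as myPlugin::serialize
    return struct.pack("i", len(attr1)) + attr1 + struct.pack("f", attr3)

def deserialize_plugin(buf: bytes):
    # read back in the same order, as the deserializing constructor does
    (nstr,) = struct.unpack_from("i", buf, 0)
    attr1 = buf[4:4 + nstr]
    (attr3,) = struct.unpack_from("f", buf, 4 + nstr)
    return attr1, attr3

blob = serialize_plugin(b"this is a string attribute", 222.0)
assert deserialize_plugin(blob) == (b"this is a string attribute", 222.0)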
myPlugin.cu
#include "NvInfer.h"
#include <cuda_runtime.h>

namespace nvinfer1
{
namespace plugin
{
static __global__ void myselu_kernel(const float* x, const int* a, float* output, int n)
{
    int position = threadIdx.x + blockDim.x * blockIdx.x;
    if (position >= n)
        return;
    output[position] = x[position] + a[0];
}

void myselu_inference(const float* x, const int* a, float* output, int n, cudaStream_t stream)
{
    const int nthreads = 512;
    int block_size = n > nthreads ? nthreads : n;
    int grid_size = (n + block_size - 1) / block_size;
    myselu_kernel<<<grid_size, block_size, 0, stream>>>(x, a, output, n);
}
} // namespace plugin
} // namespace nvinfer1
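myselu_kernel simply adds the scalar a[0] to every flattened element of x. A NumPy reference to test the kernel against could look like this (a sketch, not part of the plugin build):

import numpy as np

def myselu_reference(x: np.ndarray, a: np.ndarray) -> np.ndarray:
    # mirrors myselu_kernel: out[i] = x[i] + a[0] for every element
    return x + a.ravel()[0]

# e.g. the 3x3 conv output from the PyTorch example, plus a = 2
x = np.array([[4., 6., 4.], [6., 9., 6.], [4., 6., 4.]], dtype=np.float32)
print(myselu_reference(x, np.array([2], dtype=np.int32)))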
CMakeLists.txt
# Collect this plugin's .cpp and .cu sources and append them to the
# PLUGIN_SOURCES / PLUGIN_CU_SOURCES lists of the parent plugin/CMakeLists.txt
file(GLOB SRCS *.cpp)
set(PLUGIN_SOURCES ${PLUGIN_SOURCES} ${SRCS})
set(PLUGIN_SOURCES ${PLUGIN_SOURCES} PARENT_SCOPE)
file(GLOB CU_SRCS *.cu)
set(PLUGIN_CU_SOURCES ${PLUGIN_CU_SOURCES} ${CU_SRCS})
set(PLUGIN_CU_SOURCES ${PLUGIN_CU_SOURCES} PARENT_SCOPE)
Add the following include at the top of TensorRT/plugin/inferPlugin.cpp:
#include "myPlugin/myPlugin.h"
and register the creator inside the initLibNvInferPlugins function:
initializePlugin<nvinfer1::plugin::myPluginCreator>(logger, libNamespace);
In TensorRT/plugin/CMakeLists.txt, add the new directory to the set(PLUGIN_LISTS ...) entry:
myPlugin
Set TRT_LIB_DIR and TRT_OUT_DIR in TensorRT/CMakeLists.txt, then rebuild TensorRT.
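After the rebuild, the new creator should appear in the plugin registry. A quick check from Python (a sketch; it assumes the rebuilt libnvinfer_plugin is the one picked up at runtime):

import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
# registers every creator compiled into libnvinfer_plugin, including myPluginCreator
trt.init_libnvinfer_plugins(logger, "")

registry = trt.get_plugin_registry()
print([c.name for c in registry.plugin_creator_list])  # "MYPLUGIN" should be listed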
TensorRT inference test
Run the following command to convert the ONNX model into an engine:
TensorRT-10.6.0.26/bin/trtexec --onnx=myplugin.onnx --saveEngine=myplugin.engine
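The same build can be done through the TensorRT Python API instead of trtexec (a minimal sketch, assuming the rebuilt Python bindings are installed):

import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
trt.init_libnvinfer_plugins(logger, "")  # make sure MYPLUGIN is registered before parsing

builder = trt.Builder(logger)
network = builder.create_network(0)  # explicit batch, the default in TensorRT 10
parser = trt.OnnxParser(network, logger)
with open("myplugin.onnx", "rb") as f:
    if not parser.parse(f.read()):
        for i in range(parser.num_errors):
            print(parser.get_error(i))
        raise SystemExit("failed to parse myplugin.onnx")

config = builder.create_builder_config()
engine_bytes = builder.build_serialized_network(network, config)
with open("myplugin.engine", "wb") as f:
    f.write(engine_bytes)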
Write the Python inference script:
import numpy as np
import tensorrt as trt
import common  # helper module (allocate_buffers/do_inference) from the TensorRT Python samples

logger = trt.Logger(trt.Logger.WARNING)
trt.init_libnvinfer_plugins(logger, "")

with open("myplugin.engine", "rb") as f, trt.Runtime(logger) as runtime:
    engine = runtime.deserialize_cuda_engine(f.read())

context = engine.create_execution_context()
inputs, outputs, bindings, stream = common.allocate_buffers(engine)

input = np.ones((3, 3))
a = 2
np.copyto(inputs[0].host, input.ravel())
np.copyto(inputs[1].host, a)

output = common.do_inference(context, engine=engine, bindings=bindings,
                             inputs=inputs, outputs=outputs, stream=stream)
print(output)
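The printed values should match the PyTorch output from the start of the article. As a hand check, this small NumPy sketch (independent of TensorRT) recomputes the expected result:

import numpy as np

x = np.ones((3, 3), dtype=np.float32)
padded = np.pad(x, 1)  # zero padding, as in nn.Conv2d(..., padding=1)
expected = np.empty_like(x)
for i in range(3):
    for j in range(3):
        # all-ones 3x3 kernel, zero bias
        expected[i, j] = padded[i:i + 3, j:j + 3].sum()
expected += 2  # the plugin adds a = 2
print(expected)  # [[ 6. 8. 6.] [ 8. 11. 8.] [ 6. 8. 6.]]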