当前位置: 首页 > news >正文

group_points自定义tensorrt算子编写

导出onnx模型

import numpy as np
import torch

import pointnet2_utils  # from https://github.com/erikwijmans/Pointnet2_PyTorch


class CustomModel(torch.nn.Module):
    """Wrap ``pointnet2_utils.grouping_operation`` so it can be exported to ONNX.

    NOTE(review): exporting presumably relies on the GroupingOperation autograd
    Function defining a ``symbolic`` that emits a ``group_points`` node (the op
    name the TensorRT plugin registers) -- TODO confirm in pointnet2_utils.
    """

    def forward(self, features, idx):
        # features: (B, C, N) float32, idx: (B, M, K) int32
        # -> (B, C, M, K), here torch.Size([1, 3, 2048, 64])
        return pointnet2_utils.grouping_operation(features, idx)


NUM_POINTS = 20000
NUM_CENTERS = 2048
NUM_SAMPLES = 64

model = CustomModel().cuda()
features = torch.randn(1, 3, NUM_POINTS).cuda()
# Indices must lie in [0, NUM_POINTS): casting torch.randn to int32 would give
# negative indices, and the kernel would read outside the feature buffer.
idx = torch.randint(
    0, NUM_POINTS, (1, NUM_CENTERS, NUM_SAMPLES), dtype=torch.int32
).cuda()

# Save the exact inputs so the TensorRT engine can be fed the same data later.
np.savetxt("features.txt", features.reshape(3, NUM_POINTS).detach().cpu().numpy())
np.savetxt(
    "idx.txt",
    idx.reshape(NUM_CENTERS, NUM_SAMPLES).detach().cpu().numpy(),
    fmt="%d",
)
torch.onnx.export(model, (features, idx), "grouping_operation.onnx", opset_version=13)

其中pointnet2_utils来自https://github.com/erikwijmans/Pointnet2_PyTorch。
导出的onnx模型结构如下(原文此处为两张模型结构截图,图片未能保留):

编写tensorrt插件

采用TensorRT-10.6.0.26。由于TensorRT只有部分开源(插件、解析器等部分开源,核心库以预编译形式发布),首先在https://developer.nvidia.com/tensorrt/download/10x下载TensorRT-10.6.0.26的预编译库,然后在https://github.com/NVIDIA/TensorRT/tree/v10.6.0下载对应版本的源代码。
在TensorRT/plugin下新建groupPoints文件夹,添加下面文件:
groupPoints.h

#ifndef TRT_GROUP_POINTS_PLUGIN_H
#define TRT_GROUP_POINTS_PLUGIN_H

#include "NvInfer.h"
#include "NvInferPlugin.h"
#include "common/plugin.h"
#include "common/cuda_utils.h"

#include <cstring>
#include <string>
#include <vector>

namespace nvinfer1
{
namespace plugin
{

// Host-side launcher for the grouping kernel (defined in groupPoints.cu).
// points: (b, c, n) float32, idx: (b, npoints, nsample) int32, out: (b, c, npoints, nsample) float32.
void group_points_kernel_wrapper(int b, int c, int n, int npoints, int nsample, const float* points, const int* idx,
    float* out, cudaStream_t stream);

// TensorRT plugin implementing Pointnet2's grouping_operation:
//   out[b][c][m][k] = features[b][c][idx[b][m][k]]
// The plugin has no attributes and serializes no state.
class GroupPoints : public nvinfer1::IPluginV2DynamicExt
{
public:
    GroupPoints();
    GroupPoints(void const* data, size_t length);
    ~GroupPoints() override;

    // Plugin identity
    char const* getPluginType() const noexcept override;
    char const* getPluginVersion() const noexcept override;
    int getNbOutputs() const noexcept override;

    // Output shape computation
    nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs,
        nvinfer1::IExprBuilder& exprBuilder) noexcept override;

    // Initialization and teardown
    int initialize() noexcept override;
    void terminate() noexcept override;

    // Execution
    int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc,
        const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override;
    size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs,
        const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const noexcept override;

    // Data type and format support
    DataType getOutputDataType(int index, DataType const* inputTypes, int nbInputs) const noexcept override;
    bool supportsFormatCombination(
        int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, int nbOutputs) noexcept override;

    // Configuration
    void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs,
        const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) noexcept override;

    // Serialization
    size_t getSerializationSize() const noexcept override;
    void serialize(void* buffer) const noexcept override;

    // Object management
    nvinfer1::IPluginV2DynamicExt* clone() const noexcept override;
    void destroy() noexcept override;
    void setPluginNamespace(char const* libNamespace) noexcept override;
    char const* getPluginNamespace() const noexcept override;

private:
    // Namespace set through setPluginNamespace(); the plugin carries no other state
    // (all shapes are read from the tensor descriptors at enqueue time).
    std::string mPluginNamespace;
};

// Registry entry that creates / deserializes GroupPoints instances.
class GroupPointsCreator : public nvinfer1::IPluginCreator
{
public:
    GroupPointsCreator();
    ~GroupPointsCreator() override = default;

    char const* getPluginName() const noexcept override;
    char const* getPluginVersion() const noexcept override;
    PluginFieldCollection const* getFieldNames() noexcept override;

    IPluginV2* createPlugin(char const* name, PluginFieldCollection const* fc) noexcept override;
    IPluginV2* deserializePlugin(char const* name, void const* serialData, size_t serialLength) noexcept override;

    void setPluginNamespace(nvinfer1::AsciiChar const* pluginNamespace) noexcept override;
    const char* getPluginNamespace() const noexcept override;

private:
    static PluginFieldCollection mFC;
    static std::vector<PluginField> mPluginAttributes;
    std::string mNamespace;
};

} // namespace plugin
} // namespace nvinfer1

#endif // TRT_GROUP_POINTS_PLUGIN_H

groupPoints.cpp

/** SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.* SPDX-License-Identifier: Apache-2.0*/
#include "groupPoints.h"
#include "common/dimsHelpers.h"

#include <cuda_runtime_api.h>

using namespace nvinfer1;
using namespace nvinfer1::pluginInternal;
using nvinfer1::plugin::GroupPoints;
using nvinfer1::plugin::GroupPointsCreator;

// Plugin implementation
// GroupPoints has no attributes, so construction and destruction are trivial.
GroupPoints::GroupPoints() = default;

// Deserialization constructor. getSerializationSize() reports 0 bytes, so the
// payload carries nothing and is ignored.
GroupPoints::GroupPoints(void const* /*data*/, size_t /*length*/) {}

GroupPoints::~GroupPoints() = default;
// Plugin identity. Type and version must agree with GroupPointsCreator::getPluginName()
// and getPluginVersion(), since TensorRT looks the creator up by this pair.
char const* GroupPoints::getPluginType() const noexcept
{
    return "group_points";
}

char const* GroupPoints::getPluginVersion() const noexcept
{
    return "1";
}

int GroupPoints::getNbOutputs() const noexcept
{
    return 1;
}

// Output shape (B, C, M, K) built from features (B, C, N) and idx (B, M, K).
// The input dimension expressions are forwarded as-is so that dynamic (runtime)
// dimensions also work; getConstantValue() is only meaningful for build-time constants.
nvinfer1::DimsExprs GroupPoints::getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs* inputs,
    int nbInputs, nvinfer1::IExprBuilder& /*exprBuilder*/) noexcept
{
    // Exactly one output computed from (features, idx).
    PLUGIN_ASSERT(outputIndex == 0 && nbInputs == 2);

    nvinfer1::DimsExprs outputDims;
    outputDims.nbDims = 4;
    outputDims.d[0] = inputs[0].d[0]; // B: batch size of features
    outputDims.d[1] = inputs[0].d[1]; // C: feature channels
    outputDims.d[2] = inputs[1].d[1]; // M: number of group centers (idx dim 1)
    outputDims.d[3] = inputs[1].d[2]; // K: samples per center (idx dim 2)
    return outputDims;
}
// initialize()/terminate(): GroupPoints owns no device memory or handles,
// so there is nothing to acquire or release.
int GroupPoints::initialize() noexcept
{
    int const status = STATUS_SUCCESS;
    return status;
}

void GroupPoints::terminate() noexcept
{
    // nothing to release
}
// Run the grouping kernel on `stream`.
//   inputs[0]  features (B, C, N)    float32
//   inputs[1]  idx      (B, M, K)    int32
//   outputs[0] out      (B, C, M, K) float32
// Returns STATUS_SUCCESS, or -1 when a descriptor check fails or the kernel launch reports an error.
int GroupPoints::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc,
    const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept
{
    try
    {
        // PLUGIN_VALIDATE throws (caught below and turned into an error status);
        // PLUGIN_ASSERT would report and abort the whole process instead.
        PLUGIN_VALIDATE(inputDesc[0].dims.nbDims == 3);
        PLUGIN_VALIDATE(inputDesc[1].dims.nbDims == 3);
        PLUGIN_VALIDATE(outputDesc[0].dims.nbDims == 4);
        PLUGIN_VALIDATE(inputDesc[0].type == nvinfer1::DataType::kFLOAT);
        PLUGIN_VALIDATE(inputDesc[1].type == nvinfer1::DataType::kINT32);
        PLUGIN_VALIDATE(outputDesc[0].type == nvinfer1::DataType::kFLOAT);
        // features and idx must share the batch dimension.
        PLUGIN_VALIDATE(inputDesc[0].dims.d[0] == inputDesc[1].dims.d[0]);

        // Descriptor dims are 64-bit; the kernel wrapper takes 32-bit sizes.
        auto const b = static_cast<int>(inputDesc[0].dims.d[0]);       // batch size
        auto const c = static_cast<int>(inputDesc[0].dims.d[1]);       // feature channels
        auto const n = static_cast<int>(inputDesc[0].dims.d[2]);       // number of source points
        auto const npoints = static_cast<int>(inputDesc[1].dims.d[1]); // number of group centers (M)
        auto const nsample = static_cast<int>(inputDesc[1].dims.d[2]); // samples per center (K)

        auto const* points = static_cast<const float*>(inputs[0]);
        auto const* idx = static_cast<const int*>(inputs[1]);
        auto* out = static_cast<float*>(outputs[0]);

        group_points_kernel_wrapper(b, c, n, npoints, nsample, points, idx, out, stream);

        // Surface launch failures to TensorRT instead of reporting success unconditionally.
        PLUGIN_VALIDATE(cudaGetLastError() == cudaSuccess);
        return STATUS_SUCCESS;
    }
    catch (std::exception const& e)
    {
        caughtError(e);
    }
    return -1;
}
// The kernel writes straight into the output tensor, so no scratch workspace is requested.
size_t GroupPoints::getWorkspaceSize(const nvinfer1::PluginTensorDesc* /*inputs*/, int /*nbInputs*/,
    const nvinfer1::PluginTensorDesc* /*outputs*/, int /*nbOutputs*/) const noexcept
{
    size_t const workspaceBytes = 0;
    return workspaceBytes;
}

// The single output holds gathered features, so it is float32 (not INT32).
DataType GroupPoints::getOutputDataType(int index, DataType const* inputTypes, int nbInputs) const noexcept
{
    PLUGIN_ASSERT(index == 0 && nbInputs == 2);
    DataType const outputType = DataType::kFLOAT;
    return outputType;
}
// Accepted I/O formats, all linear layout:
//   pos 0: features float32, pos 1: idx int32, pos 2: output float32.
bool GroupPoints::supportsFormatCombination(
    int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, int nbOutputs) noexcept
{
    // Two inputs plus one output.
    PLUGIN_ASSERT(pos < nbInputs + nbOutputs);

    nvinfer1::PluginTensorDesc const& desc = inOut[pos];
    if (desc.format != nvinfer1::PluginFormat::kLINEAR)
    {
        return false;
    }
    switch (pos)
    {
    case 0: // features
    case 2: // output
        return desc.type == nvinfer1::DataType::kFLOAT;
    case 1: // idx
        return desc.type == nvinfer1::DataType::kINT32;
    default:
        return false;
    }
}

// Shapes are read from the descriptors in enqueue(), so configuration only checks the I/O count.
void GroupPoints::configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs,
    const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) noexcept
{
    try
    {
        PLUGIN_ASSERT(nbInputs == 2 && nbOutputs == 1);
    }
    catch (std::exception const& e)
    {
        caughtError(e);
    }
}

// Stateless plugin: the serialized form is empty.
size_t GroupPoints::getSerializationSize() const noexcept
{
    size_t const serializedBytes = 0;
    return serializedBytes;
}

void GroupPoints::serialize(void* /*buffer*/) const noexcept
{
    // nothing to write
}
// Create an independent copy for TensorRT. The plugin has no attributes, but the
// namespace must be carried over so the clone reports the same namespace as this
// instance (deserializePlugin likewise sets it on new instances).
nvinfer1::IPluginV2DynamicExt* GroupPoints::clone() const noexcept
{
    try
    {
        auto* plugin = new GroupPoints();
        plugin->setPluginNamespace(mPluginNamespace.c_str());
        return plugin;
    }
    catch (std::exception const& e)
    {
        caughtError(e);
    }
    return nullptr;
}

// Release this instance; objects are heap-allocated by clone()/the creator.
void GroupPoints::destroy() noexcept
{
    delete this;
}
void GroupPoints::setPluginNamespace(char const* pluginNamespace) noexcept 
{ //std::cout<<"setPluginNamespace"<<std::endl;try{mPluginNamespace = pluginNamespace;}catch (std::exception const& e){caughtError(e);}
}char const* GroupPoints::getPluginNamespace() const noexcept 
{ //std::cout<<"getPluginNamespace"<<std::endl;return mPluginNamespace.c_str(); 
}// 插件创建器实现
PluginFieldCollection GroupPointsCreator::mFC{};
std::vector<PluginField> GroupPointsCreator::mPluginAttributes;GroupPointsCreator::GroupPointsCreator() 
{//std::cout<<"GroupPointsCreator"<<std::endl;mPluginAttributes.clear();mFC.nbFields = mPluginAttributes.size();mFC.fields = mPluginAttributes.data();
}char const* GroupPointsCreator::getPluginName() const noexcept 
{ //std::cout<<"getPluginName"<<std::endl;return "group_points"; 
}char const* GroupPointsCreator::getPluginVersion() const noexcept 
{ //std::cout<<"getPluginVersion"<<std::endl;return "1"; 
}PluginFieldCollection const* GroupPointsCreator::getFieldNames() noexcept 
{ //std::cout<<"getFieldNames"<<std::endl;return &mFC; 
}IPluginV2* GroupPointsCreator::createPlugin(char const* name, PluginFieldCollection const* fc) noexcept 
{//std::cout<<"createPlugin"<<std::endl;try {return new GroupPoints();}catch (std::exception const& e) {caughtError(e);}return nullptr;
}IPluginV2* GroupPointsCreator::deserializePlugin(char const* name, void const* serialData, size_t serialLength) noexcept 
{//std::cout<<"deserializePlugin"<<std::endl;try{// This object will be deleted when the network is destroyed, which will// call Concat::destroy()IPluginV2Ext* plugin = new GroupPoints();plugin->setPluginNamespace(mNamespace.c_str());return plugin;}catch (std::exception const& e){caughtError(e);}return nullptr;
}void GroupPointsCreator::setPluginNamespace(char const* libNamespace) noexcept
{  //std::cout<<"setPluginNamespace"<<std::endl;        mNamespace = libNamespace;
}char const* GroupPointsCreator::getPluginNamespace() const noexcept
{    //std::cout<<"getPluginNamespace"<<std::endl;      return mNamespace.c_str();
}

groupPoints.cu

#include <stdio.h>
#include <stdlib.h>

#include <algorithm>

#include "NvInfer.h"
#include "groupPoints.h"
#include <cuda_runtime.h>

namespace nvinfer1
{
namespace plugin
{
namespace
{
// Upper bound on threads per block for the grouping kernel.
constexpr int kMaxThreadsPerBlock = 512;

// Largest power of two <= workSize, clamped to [1, kMaxThreadsPerBlock].
int optNumThreads(int workSize)
{
    int threads = 1;
    while (threads * 2 <= workSize && threads * 2 <= kMaxThreadsPerBlock)
    {
        threads *= 2;
    }
    return threads;
}

// 2-D block shape: x threads follow `x` (group centers), y threads follow `y` (channels),
// with x * y <= kMaxThreadsPerBlock.
dim3 optBlockConfig(int x, int y)
{
    int const xThreads = optNumThreads(x);
    int const yThreads = std::max(std::min(optNumThreads(y), kMaxThreadsPerBlock / xThreads), 1);
    return dim3(xThreads, yThreads, 1);
}
} // namespace

// input: points(b, c, n) idx(b, npoints, nsample)
// output: out(b, c, npoints, nsample), out[.][l][j][k] = points[.][l][idx[.][j][k]]
// One block per batch element (blockIdx.x); the block's threads stride over the
// c * npoints (channel, center) pairs. Indices outside [0, n) yield 0 instead of an
// out-of-bounds read.
__global__ void group_points_kernel(int b, int c, int n, int npoints, int nsample,
    const float* __restrict__ points, const int* __restrict__ idx, float* __restrict__ out)
{
    int const batch_index = blockIdx.x;
    points += batch_index * n * c;
    idx += batch_index * npoints * nsample;
    out += batch_index * npoints * nsample * c;

    int const index = threadIdx.y * blockDim.x + threadIdx.x;
    int const stride = blockDim.y * blockDim.x;
    for (int i = index; i < c * npoints; i += stride)
    {
        int const l = i / npoints; // channel
        int const j = i % npoints; // group center
        for (int k = 0; k < nsample; ++k)
        {
            int const ii = idx[j * nsample + k];
            out[(l * npoints + j) * nsample + k] = (ii >= 0 && ii < n) ? points[l * n + ii] : 0.0f;
        }
    }
}

// Launch group_points_kernel on `stream`. Errors are logged but not cleared
// (cudaPeekAtLastError), so the caller can still pick them up via cudaGetLastError();
// a plugin must not terminate the host process on failure.
void group_points_kernel_wrapper(int b, int c, int n, int npoints, int nsample, const float* points, const int* idx,
    float* out, cudaStream_t stream)
{
    // An empty output has nothing to write, and a 0-block grid is an invalid launch configuration.
    if (b <= 0 || c <= 0 || npoints <= 0 || nsample <= 0)
    {
        return;
    }
    group_points_kernel<<<b, optBlockConfig(npoints, c), 0, stream>>>(b, c, n, npoints, nsample, points, idx, out);

    cudaError_t const err = cudaPeekAtLastError();
    if (err != cudaSuccess)
    {
        fprintf(stderr, "group_points_kernel launch failed: %s\n", cudaGetErrorString(err));
    }
}

} // namespace plugin
} // namespace nvinfer1

CMakeLists.txt

# Collect this plugin's host (.cpp) and device (.cu) sources and export the
# extended PLUGIN_SOURCES / PLUGIN_CU_SOURCES lists to the parent scope
# (presumably plugin/CMakeLists.txt, which adds this directory).
file(GLOB SRCS *.cpp)
file(GLOB CU_SRCS *.cu)

set(PLUGIN_SOURCES ${PLUGIN_SOURCES} ${SRCS})
set(PLUGIN_CU_SOURCES ${PLUGIN_CU_SOURCES} ${CU_SRCS})

set(PLUGIN_SOURCES ${PLUGIN_SOURCES} PARENT_SCOPE)
set(PLUGIN_CU_SOURCES ${PLUGIN_CU_SOURCES} PARENT_SCOPE)

在TensorRT/plugin/inferPlugin.cpp的开头添加

#include "groupPoints/groupPoints.h"

并在initLibNvInferPlugins函数中添加

initializePlugin<nvinfer1::plugin::GroupPointsCreator>(logger, libNamespace);

在TensorRT/plugin/CMakeLists.txt的set(PLUGIN_LISTS ...)列表中添加

groupPoints

在TensorRT/CMakeLists.txt中设置TRT_LIB_DIR(指向下载的TensorRT-10.6.0.26/lib目录)和TRT_OUT_DIR(编译输出目录),然后重新编译TensorRT开源部分,得到包含group_points插件的libnvinfer_plugin。

tensorrt推理测试

运行下面的命令把onnx转为engine模型(注意把重新编译得到的libnvinfer_plugin.so所在目录放在LD_LIBRARY_PATH的最前面,否则trtexec加载的是原版插件库,找不到group_points插件):

TensorRT-10.6.0.26/bin/trtexec --onnx=grouping_operation.onnx --saveEngine=grouping_operation.engine

编写python推理脚本:

import numpy as np
import tensorrt as trt

import common  # common.py from the TensorRT Python samples (assumed on sys.path)

logger = trt.Logger(trt.Logger.WARNING)
# Register the plugins of libnvinfer_plugin (which must be the rebuilt one
# containing group_points) before deserializing the engine.
trt.init_libnvinfer_plugins(logger, "")

# Keep the runtime alive for as long as the engine is in use.
runtime = trt.Runtime(logger)
with open("grouping_operation.engine", "rb") as f:
    engine = runtime.deserialize_cuda_engine(f.read())
if engine is None:
    raise RuntimeError(
        "Failed to deserialize grouping_operation.engine; check that the rebuilt "
        "libnvinfer_plugin with the group_points plugin is being loaded"
    )

context = engine.create_execution_context()
inputs, outputs, bindings, stream = common.allocate_buffers(engine)

# Same data that was fed to the PyTorch model during export.
features = np.loadtxt("features.txt").reshape(1, 3, 20000).astype(np.float32)
idx = np.loadtxt("idx.txt").reshape(1, 2048, 64).astype(np.int32)
np.copyto(inputs[0].host, features.ravel())
np.copyto(inputs[1].host, idx.ravel())

output = common.do_inference(
    context,
    engine=engine,
    bindings=bindings,
    inputs=inputs,
    outputs=outputs,
    stream=stream,
)
# do_inference returns flat host buffers; restore the (B, C, M, K) layout.
result = output[0].reshape(1, 3, 2048, 64)
print(result)

common.free_buffers(inputs, outputs, stream)
http://www.dtcms.com/a/398313.html

相关文章:

  • 20250925问答课题-多标签分类模型
  • 唯品会库存API集成问题与技术方案解析
  • Python开发一个系统
  • 02-教务管理系统(选课管理系统)
  • 从入门到精通:逆向工程完全工具指南与桌面环境搭建
  • 注册网站做推广衡阳网站搜索引擎优化
  • 从零开始学Flink:数据转换的艺术
  • 公司做网站的流程wordpress 放大镜插件
  • 《系统与软件工程 功能规模测量 NESMA方法》(GBT 42588-2023)标准解读
  • React Testing完全指南:Jest、React Testing Library实战
  • python+springboot+django/flask的医院食堂订餐系统 菜单发布 在线订餐 餐品管理与订单统计系统
  • 半导体制造常见检测之拉曼光谱
  • Python 第七节 循环语句for和while使用详解及注意事项
  • 怎么把svg做网站背景谷歌关键词挖掘工具
  • Vue3中的computed属性
  • 7. 临时变量的常量性
  • SNK施努卡有色冶炼自动化解决方案
  • SpringCloud项目阶段七:延迟任务技术选项对比以及接入redis实现延迟队列添加/取消/消费等任务
  • 建站特别慢wordpress网站项目总体设计模板
  • 驱动开发,为什么需要映射?
  • 网站栏目模版确定网站推广目标
  • AI产品经理项目实战:BERT语义分析识别重复信息
  • 亚远景-ISO 42001:为汽车AI安全设定新标杆
  • 电路方案分析(二十四)汽车高压互锁参考设计
  • 深圳网站快速备案手机app播放器
  • CSS精灵技术
  • 数据库导论#1
  • Web应用接入支付功能的准备工作和开发规范
  • 专业做logo的网站wordpress安装模板
  • 8 shiro的web整合