TensorRT笔记(1):自定义MNIST数据集推理类
该示例是TensorRT官方认证的TensorRT版Hello World,原始代码可见官方GitHub:
TensorRT/samples/sampleOnnxMNIST at main · NVIDIA/TensorRT
阅读之前,可以先看cuda编程笔记(21)-- TensorRT-CSDN博客,了解TensorRT最基础的执行流程。本例就是对该流程的封装。在前文已经介绍过的api,本文就不再介绍了。
注意:示例代码里用到了许多common文件夹里的官方定义的类,在介绍时会一并介绍。自己如果想要运行也要注意这些头文件和cpp文件,以免编译不过。
头文件和全局变量
#include "argsParser.h"
#include "buffers.h"
#include "common.h"
#include "logger.h"
#include "parserOnnxConfig.h"#include "NvInfer.h"
#include <cuda_runtime_api.h>#include <cstdlib>
#include <fstream>
#include <iostream>
#include <sstream>
using namespace nvinfer1;
//其实就是unique_ptr
using samplesCommon::SampleUniquePtr;
//这个字符串是给日志用的
const std::string gSampleName = "TensorRT.sample_onnx_mnist";
类的声明
class SampleOnnxMNIST{
public:SampleOnnxMNIST(const samplesCommon::OnnxSampleParams¶ms):mParams(params),mRuntime(nullptr),mEngine(nullptr){}//!//! \brief 构建部分//!bool build();//!//! \brief 推理部分//!bool infer();
private:samplesCommon::OnnxSampleParams mParams;//参数nvinfer1::Dims mInputDims,mOutputDims;//输入输出的形状int mNumber{0};//分类结果std::shared_ptr<nvinfer1::IRuntime> mRuntime;//TensorRT 的运行时环境,负责把序列化好的 engine 反序列化成 ICudaEnginestd::shared_ptr<nvinfer1::ICudaEngine> mEngine;//表示一个 TensorRT 编译好的网络模型(不可修改)。它保存了网络结构、权重、输入输出 Tensor 的信息。//!//! \brief 将一个onnx解析成TennsorRT的网络(network)//!bool constructNetwork(SampleUniquePtr<nvinfer1::IBuilder> &builder,SampleUniquePtr<nvinfer1::INetworkDefinition>&network,SampleUniquePtr<nvinfer1::IBuilderConfig> &config,SampleUniquePtr<nvonnxparser::IParser> &parser,SampleUniquePtr<nvinfer1::ITimingCache>&timingCache);//!//! \brief Reads the input and stores the result in a managed buffer//!bool processInput(const samplesCommon::BufferManager&buffers);//!//! \brief Classifies digits and verify result//!bool verifyOutput(const samplesCommon::BufferManager&buffers);
};
samplesCommon::OnnxSampleParams
该类定义在
#include "argsParser.h"
struct OnnxSampleParams : public SampleParams
{std::string onnxFileName; //!< Filename of ONNX file of a network
};
明显,该类是SampleParams的子类,仅仅额外多了一个onnx文件的名称的成员变量
struct SampleParams
{int32_t batchSize{1}; //指定推理时的 batch sizeint32_t dlaCore{-1}; //是否使用 NVIDIA DLA (Deep Learning Accelerator)-1 表示不使用 DLA,仅使用 GPU。bool int8{false}; //是否允许 TensorRT 使用 INT8 量化推理模式bool fp16{false}; //是否启用 半精度 FP16 模式bool bf16{false}; //是否启用 BF16 模式(Brain Floating Point 16)std::vector<std::string> dataDirs; //存放样例数据的目录。用于存放输入图片、模型文件、cache 文件等std::vector<std::string> inputTensorNames;//指定模型的 输入 Tensor 名称,必须与你导出 ONNX 时定义的名称一致std::vector<std::string> outputTensorNames;//模型的 输出 Tensor 名称,必须与 ONNX 模型输出名称对应std::string timingCacheFile; //指定 TensorRT 的 Timing Cache 文件路径
};
nvinfer1::Dims
它是 TensorRT 中用于描述 Tensor 的维度信息(shape) 的类。
例如在 PyTorch/NumPy 中我们常写:
x.shape == (1, 3, 224, 224)
而在 TensorRT 中,这个 shape 就会用一个 nvinfer1::Dims 对象来表示。
class Dims64
{
public://! The maximum rank (number of dimensions) supported for a tensor.static constexpr int32_t MAX_DIMS{8};//! The rank (number of dimensions).//维度个数int32_t nbDims;//! The extent of each dimension.//各个维度的大小int64_t d[MAX_DIMS];
};//!
//! Alias for Dims64.
//!
using Dims = Dims64;
假设你要表示一个输入张量形状为 (1, 3, 224, 224)(即 batch=1, 3通道, 224x224 图像):
nvinfer1::Dims inputDims;
inputDims.nbDims = 4;
inputDims.d[0] = 1; // N
inputDims.d[1] = 3; // C
inputDims.d[2] = 224; // H
inputDims.d[3] = 224; // W
在自定义network的时候,有可能需要传入
TensorRT 里其实有几种 Dims 派生类,针对不同情况:
| 类型 | 说明 | 常用场景 |
|---|---|---|
nvinfer1::Dims | 通用维度结构体 | 通用接口 |
nvinfer1::Dims2 | 固定 2 维(d[0], d[1]) | 2D 层,如 addFullyConnected |
nvinfer1::Dims3 | 固定 3 维 | 卷积核尺寸等 |
nvinfer1::Dims4 | 固定 4 维 | 常见 NCHW 输入 |
nvinfer1::DimsHW | Height×Width | 卷积/池化核大小或步长 |
nvinfer1::DimsCHW | 通道、高度、宽度 | 图像类输入 |
Dims 只描述形状,不描述类型(float/int8等)。
类型信息由 nvinfer1::DataType 表示:
enum class DataType { kFLOAT, kHALF, kINT8, kBOOL };
constructNetwork
bool SampleOnnxMNIST::constructNetwork(SampleUniquePtr<nvinfer1::IBuilder>& builder,SampleUniquePtr<nvinfer1::INetworkDefinition>& network, SampleUniquePtr<nvinfer1::IBuilderConfig>& config,SampleUniquePtr<nvonnxparser::IParser>& parser, SampleUniquePtr<nvinfer1::ITimingCache>& timingCache){auto parsed=parser->parseFromFile(locateFile(mParams.onnxFileName, mParams.dataDirs).c_str(),static_cast<int>(sample::gLogger.getReportableSeverity())); if(!parsed){return false;}if(mParams.fp16){config->setFlag(BuilderFlag::kBF16);}if(mParams.bf16){config->setFlag(BuilderFlag::kBF16);}if(mParams.int8){config->setFlag(BuilderFlag::kINT8);network->getInput(0)->setDynamicRange(-1.0F,1.0F);constexpr float KTENSOR_DYNAMIC_RANGE=4.0F;samplesCommon::setAllDynamicRanges(network.get(),KTENSOR_DYNAMIC_RANGE,KTENSOR_DYNAMIC_RANGE);}if(mParams.timingCacheFile.size()){timingCache=samplesCommon::buildTimingCacheFromFile(sample::gLogger.getTRTLogger(),*config,mParams.timingCacheFile,std::cerr);}samplesCommon::enableDLA(builder.get(),config.get(),mParams.dlaCore);
}
sample::Logger
TensorRT 要求用户必须提供一个 nvinfer1::ILogger 类型的对象;所以官方在样例里就给我们实现了一个子类
class Logger : public nvinfer1::ILogger
{
public:explicit Logger(Severity severity = Severity::kWARNING): mReportableSeverity(severity){}//!//! \enum TestResult//! \brief Represents the state of a given test//!enum class TestResult{kRUNNING, //!< The test is runningkPASSED, //!< The test passedkFAILED, //!< The test failedkWAIVED //!< The test was waived};//获取原生 TensorRT Logger//方便把当前对象直接传进 TensorRT 的 APInvinfer1::ILogger& getTRTLogger() noexcept{return *this;}//!//! \brief Implementation of the nvinfer1::ILogger::log() virtual method//!//! Note samples should not be calling this function directly; it will eventually go away once we eliminate the//! inheritance from nvinfer1::ILogger//!void log(Severity severity, const char* msg) noexcept override{LogStreamConsumer(mReportableSeverity, severity) << "[TRT] " << std::string(msg) << std::endl;}//!//! \brief Method for controlling the verbosity of logging output//!//! \param severity The logger will only emit messages that have severity of this level or higher.//!//允许在运行时更改输出级别。void setReportableSeverity(Severity severity) noexcept{mReportableSeverity = severity;}//!//! \brief Opaque handle that holds logging information for a particular test//!//! This object is an opaque handle to information used by the Logger to print test results.//! The sample must call Logger::defineTest() in order to obtain a TestAtom that can be used//! with Logger::reportTest{Start,End}().//!class TestAtom{public:TestAtom(TestAtom&&) = default;private:friend class Logger;TestAtom(bool started, const std::string& name, const std::string& cmdline): mStarted(started), mName(name), mCmdline(cmdline){}bool mStarted;std::string mName;std::string mCmdline;};//!//! \brief Define a test for logging//!//! \param[in] name The name of the test. This should be a string starting with//! "TensorRT" and containing dot-separated strings containing//! the characters [A-Za-z0-9_].//! For example, "TensorRT.sample_googlenet"//! \param[in] cmdline The command line used to reproduce the test////! \return a TestAtom that can be used in Logger::reportTest{Start,End}().//!static TestAtom defineTest(const std::string& name, const std::string& cmdline){return TestAtom(false, name, cmdline);}//!//! \brief A convenience overloaded version of defineTest() that accepts an array of command-line arguments//! as input//!//! \param[in] name The name of the test//! \param[in] argc The number of command-line arguments//! \param[in] argv The array of command-line arguments (given as C strings)//!//! \return a TestAtom that can be used in Logger::reportTest{Start,End}().//!static TestAtom defineTest(const std::string& name, int32_t argc, char const* const* argv){// Append TensorRT version as infoconst std::string vname = name + " [TensorRT v" + std::to_string(NV_TENSORRT_VERSION) + "]";auto cmdline = genCmdlineString(argc, argv);return defineTest(vname, cmdline);}//!//! \brief Report that a test has started.//!//! \pre reportTestStart() has not been called yet for the given testAtom//!//! \param[in] testAtom The handle to the test that has started//!static void reportTestStart(TestAtom& testAtom){reportTestResult(testAtom, TestResult::kRUNNING);assert(!testAtom.mStarted);testAtom.mStarted = true;}//!//! \brief Report that a test has ended.//!//! \pre reportTestStart() has been called for the given testAtom//!//! \param[in] testAtom The handle to the test that has ended//! \param[in] result The result of the test. Should be one of TestResult::kPASSED,//! TestResult::kFAILED, TestResult::kWAIVED//!static void reportTestEnd(TestAtom const& testAtom, TestResult result){assert(result != TestResult::kRUNNING);assert(testAtom.mStarted);reportTestResult(testAtom, result);}static int32_t reportPass(TestAtom const& testAtom){reportTestEnd(testAtom, TestResult::kPASSED);return EXIT_SUCCESS;}static int32_t reportFail(TestAtom const& testAtom){reportTestEnd(testAtom, TestResult::kFAILED);return EXIT_FAILURE;}static int32_t reportWaive(TestAtom const& testAtom){reportTestEnd(testAtom, TestResult::kWAIVED);return EXIT_SUCCESS;}static int32_t reportTest(TestAtom const& testAtom, bool pass){return pass ? reportPass(testAtom) : reportFail(testAtom);}Severity getReportableSeverity() const{return mReportableSeverity;}private://!//! \brief returns an appropriate string for prefixing a log message with the given severity//!static const char* severityPrefix(Severity severity){switch (severity){case Severity::kINTERNAL_ERROR: return "[F] ";case Severity::kERROR: return "[E] ";case Severity::kWARNING: return "[W] ";case Severity::kINFO: return "[I] ";case Severity::kVERBOSE: return "[V] ";default: assert(0); return "";}}//!//! \brief returns an appropriate string for prefixing a test result message with the given result//!static const char* testResultString(TestResult result){switch (result){case TestResult::kRUNNING: return "RUNNING";case TestResult::kPASSED: return "PASSED";case TestResult::kFAILED: return "FAILED";case TestResult::kWAIVED: return "WAIVED";default: assert(0); return "";}}//!//! \brief returns an appropriate output stream (cout or cerr) to use with the given severity//!static std::ostream& severityOstream(Severity severity){return severity >= Severity::kINFO ? std::cout : std::cerr;}//!//! \brief method that implements logging test results//!static void reportTestResult(TestAtom const& testAtom, TestResult result){severityOstream(Severity::kINFO) << "&&&& " << testResultString(result) << " " << testAtom.mName << " # "<< testAtom.mCmdline << std::endl;}//!//! \brief generate a command line string from the given (argc, argv) values//!static std::string genCmdlineString(int32_t argc, char const* const* argv){std::stringstream ss;for (int32_t i = 0; i < argc; i++){if (i > 0){ss << " ";}ss << argv[i];}return ss.str();}Severity mReportableSeverity;
}; // class Logger
| 功能 | 接口 | 说明 |
|---|---|---|
| 日志接口实现 | log() | TensorRT 内部回调,打印运行日志 |
| 获取原生 Logger | getTRTLogger() | 给 TensorRT API 使用 |
| 设置日志级别 | setReportableSeverity() | 控制打印详细度 |
| 定义测试 | defineTest() | 创建 TestAtom |
| 测试开始 | reportTestStart() | 打印 &&&& RUNNING ... |
| 测试结束 | reportPass()/reportFail() | 打印结果并返回 exit code |
| 日志输出流 | severityOstream() | 自动选择 cout/cerr |
| 辅助字符串 | severityPrefix() | 加上 [E], [W] 等前缀 |
sample::gLogger就是该类实例化的一个对象
Logger gLogger{Logger::Severity::kINFO};
我们使用的时候,就用传入这个对象即可(可以用getTRTLogger()获取基类型引用)
这个类就不介绍太多,因为有很多封装,可以后面专门研究。
locateFile
在官方案例里是在作用域下的samplesCommon::locateFile
但是我本地的版本并没有在作用域下。
inline std::string locateFile(const std::string& filepathSuffix, const std::vector<std::string>& directories, bool reportError = true);
功能很简单,就是在directories这几个目录下找你给出的filepathSuffix文件,返回完整路径名称
我们的使用如下
locateFile(mParams.onnxFileName, mParams.dataDirs)
会在参数里给出的路径寻找onnx文件
nvinfer1::ITimingCache
ITimingCache 是 TensorRT 提供的一个 引擎构建时的性能缓存接口,用来存储 层(layer)或操作的时间成本测量结果,主要作用是加速 Builder 构建网络引擎 的过程。
在 TensorRT 中,当你用 IBuilder 创建一个 ICudaEngine 时:
-
Builder 会尝试很多 不同的实现策略(不同 kernel、不同卷积算法等)。
-
对每种策略,TensorRT 会在 GPU 上 做一次性能测试(timing),选择最快的。
-
对大型网络,这个过程可能很耗时,尤其是在 FP16/INT8/不同 batch size 下。
ITimingCache 的作用就是:
-
保存这些测试结果,下次构建同样网络时直接复用,避免重复 benchmark。
-
可以保存到文件,跨程序或者跨机器(如果硬件一样)都能复用。
不过我们实际测试的时候并没有使用它,它的详细用法,也得等以后研究了。
setDynamicRange
bool ITensor::setDynamicRange(float min, float max) noexcept;
作用: 设置一个 tensor 在量化(INT8 模式)时对应的最小值与最大值范围。
TensorRT 将据此把浮点数映射为 8-bit 整数。
因为在量化模式(INT8)下,TensorRT 需要知道:
-
不同张量的数值分布;
-
才能合理映射到 8 位整数,减少精度损失。
如果不提供这些信息,TensorRT 就无法正确地做量化推理。
获取 dynamic range 的几种方式:
-
自动校准(推荐)
使用 TensorRT 的IInt8EntropyCalibrator2类自动计算每个 tensor 的动态范围。config->setFlag(BuilderFlag::kINT8); config->setInt8Calibrator(new Int8EntropyCalibrator2(...));TensorRT 会在校准数据上运行一次推理来统计动态范围。
-
手动设置
用setDynamicRange()人工指定(通常在自己写层或定制算子时用)。
samplesCommon::setAllDynamicRanges
setAllDynamicRanges() 会扫描整个网络,把所有还没有设置 dynamic range 的 tensor 自动补齐一个默认范围。
inline void setAllDynamicRanges(nvinfer1::INetworkDefinition* network, float inRange = 2.0F, float outRange = 4.0F)
{// Ensure that all layer inputs have a scale.for (int i = 0; i < network->getNbLayers(); i++){auto layer = network->getLayer(i);for (int j = 0; j < layer->getNbInputs(); j++){nvinfer1::ITensor* input{layer->getInput(j)};// Optional inputs are nullptr here and are from RNN layers.if (input != nullptr && !input->dynamicRangeIsSet()){ASSERT(input->setDynamicRange(-inRange, inRange));}}}// Ensure that all layer outputs have a scale.// Tensors that are also inputs to layers are ingored here// since the previous loop nest assigned scales to them.for (int i = 0; i < network->getNbLayers(); i++){auto layer = network->getLayer(i);for (int j = 0; j < layer->getNbOutputs(); j++){nvinfer1::ITensor* output{layer->getOutput(j)};// Optional outputs are nullptr here and are from RNN layers.if (output != nullptr && !output->dynamicRangeIsSet()){// Pooling must have the same input and output scales.if (layer->getType() == nvinfer1::LayerType::kPOOLING){ASSERT(output->setDynamicRange(-inRange, inRange));}else{ASSERT(output->setDynamicRange(-outRange, outRange));}}}}
}
遍历所有层的输入
-
获取每一层的每个输入 tensor;
-
如果这个 tensor 没有设置 dynamic range,就用
[-inRange, +inRange]来设置。
遍历所有层的输出
-
获取每一层的输出 tensor;
-
如果还没设置 dynamic range:
-
对 Pooling 层 输出,强制与输入相同范围(防止缩放比例不一致);
-
其他层输出设为
[-outRange, +outRange],默认是 [-4, 4]。
-
build
bool SampleOnnxMNIST::build(){auto builder=SampleUniquePtr<nvinfer1::IBuilder>(nvinfer1::createInferBuilder(sample::gLogger.getTRTLogger()));if(!builder){//builder创建失败return false;}auto network=SampleUniquePtr<nvinfer1::INetworkDefinition>(builder->createNetworkV2(0));if(!network){return false;}auto config = SampleUniquePtr<nvinfer1::IBuilderConfig>(builder->createBuilderConfig());if (!config){return false;}auto parser=SampleUniquePtr<nvonnxparser::IParser>(nvonnxparser::createParser(*network,sample::gLogger.getTRTLogger()));if (!parser){return false;}auto timingCache = SampleUniquePtr<nvinfer1::ITimingCache>();auto constructed = constructNetwork(builder, network, config, parser, timingCache);if (!constructed){return false;}auto profileStream=samplesCommon::makeCudaStream();if(!profileStream){return false;}config->setProfileStream(*profileStream);//设置builder执行的流SampleUniquePtr<IHostMemory> plan{builder->buildSerializedNetwork(*network,*config)};if(!plan){return false;}if(timingCache!=nullptr&&!mParams.timingCacheFile.empty()){samplesCommon::updateTimingCacheFile(sample::gLogger.getTRTLogger(),mParams.timingCacheFile,timingCache.get(),*builder);}mRuntime=std::shared_ptr<nvinfer1::IRuntime>(createInferRuntime(sample::gLogger.getTRTLogger()));if(!mRuntime){return false;}mEngine=std::shared_ptr<nvinfer1::ICudaEngine>(mRuntime->deserializeCudaEngine(plan->data(),plan->size()),samplesCommon::InferDeleter());if(!mEngine){return false;}ASSERT(network->getNbInputs()==1);mInputDims=network->getInput(0)->getDimensions();ASSERT(mInputDims.nbDims==4);//[N,C,H,W]ASSERT(network->getNbOutputs() == 1);mOutputDims = network->getOutput(0)->getDimensions();ASSERT(mOutputDims.nbDims == 2);//[N,P]输出batchsize和各类的概率return true;
}
samplesCommon::makeCudaStream
创建一个 CUDA 异步流(cudaStream_t),并用智能指针自动管理其生命周期。
template <typename T>
using SampleUniquePtr = std::unique_ptr<T>;static auto StreamDeleter = [](cudaStream_t* pStream) {if (pStream){static_cast<void>(cudaStreamDestroy(*pStream));delete pStream;}
};inline std::unique_ptr<cudaStream_t, decltype(StreamDeleter)> makeCudaStream()
{std::unique_ptr<cudaStream_t, decltype(StreamDeleter)> pStream(new cudaStream_t, StreamDeleter);if (cudaStreamCreateWithFlags(pStream.get(), cudaStreamNonBlocking) != cudaSuccess){pStream.reset(nullptr);}return pStream;
}
processInput
bool SampleOnnxMNIST::processInput(const samplesCommon::BufferManager& buffers)
{const int inputH = mInputDims.d[2];const int inputW = mInputDims.d[3];// Read a random digit filesrand(unsigned(time(nullptr)));std::vector<uint8_t> fileData(inputH * inputW);mNumber = rand() % 10;//随机找mNumer.pgm这张图片进行推理readPGMFile(locateFile(std::to_string(mNumber) + ".pgm", mParams.dataDirs), fileData.data(), inputH, inputW);// Print an ascii representationsample::gLogInfo << "Input:" << std::endl;for (int i = 0; i < inputH * inputW; i++){sample::gLogInfo << (" .:-=+*#%@"[fileData[i] / 26]) << (((i + 1) % inputW) ? "" : "\n");}sample::gLogInfo << std::endl;//把读入的原始图像数据(uint8_t灰度值)转换成推理所需的浮点输入格式//fileData存的是原始图像的 uint8_t 灰度像素值,范围是 [0, 255]//把它除以 255.0 后,变成 [0.0, 1.0] 范围的浮点数。//1.0 - (...) 为什么要减?//这是 MNIST 特有的一点:原始 MNIST 图像的背景是白的(像素值 255),数字是黑的(像素值 0);但很多神经网络的训练输入习惯是「前景亮、背景暗」。float* hostDataBuffer = static_cast<float*>(buffers.getHostBuffer(mParams.inputTensorNames[0]));for (int i = 0; i < inputH * inputW; i++){hostDataBuffer[i] = 1.0 - float(fileData[i] / 255.0);}return true;
}
samplesCommon::BufferManager
这个类的实现源码太过复杂且多了,就不以源码形式放上来了
BufferManager 是干嘛的
一句话概括:
负责在主机端(Host)和设备端(Device)之间分配、管理、同步所有 I/O 缓冲区。
在 TensorRT 的 C++ 执行流程中,我们通常会这样走:
-
从 ONNX 读取模型 → 构建 engine;
-
用
IExecutionContext创建执行上下文; -
为模型的输入/输出分配 GPU 缓冲;
-
将输入数据从 host 拷贝到 device;
-
执行推理;
-
将结果从 device 拷贝回 host;
-
打印/验证结果。
而步骤 3、4、6 涉及大量繁琐的 cudaMalloc、cudaMemcpy 调用,
所以官方写了一个通用封装类:BufferManager。
类的主要组成结构
它内部主要管理这三类成员:
| 名称 | 类型 | 作用 |
|---|---|---|
ManagedBuffer | 包含 DeviceBuffer 和 HostBuffer | 分别管理 GPU 内存和 CPU 内存 |
mDeviceBindings | std::vector<void*> | 存放所有 GPU 缓冲区地址(传给 enqueueV3() 用) |
mNames | unordered_map<string, int> | 把 tensor 的名字映射到对应的 buffer 索引 |
每个 Tensor(不论输入还是输出)都对应一个 host/device 双缓冲。
构造函数逻辑
构造时,它会自动为所有的网络 I/O tensor 分配空间:
for (int32_t i = 0; i < mEngine->getNbIOTensors(); i++)
{auto name = engine->getIOTensorName(i);nvinfer1::DataType type = mEngine->getTensorDataType(name);size_t vol = samplesCommon::volume(dims);std::unique_ptr<ManagedBuffer> manBuf{new ManagedBuffer()};manBuf->deviceBuffer = DeviceBuffer(vol, type);manBuf->hostBuffer = HostBuffer(vol, type);mDeviceBindings.emplace_back(manBuf->deviceBuffer.data());mManagedBuffers.emplace_back(std::move(manBuf));
}
也就是说——不用你手动 cudaMalloc 了,
所有输入输出的显存和内存都会根据 tensor 形状自动分配。
执行前后拷贝数据的核心函数
两个核心接口是:
void copyInputToDevice();
void copyOutputToHost();
| 函数 | 作用 |
|---|---|
copyInputToDevice() | 把 host 端输入数据复制到 GPU 上 |
copyOutputToHost() | 把推理结束的结果从 GPU 复制回主机 |
为什么推理时要用 getDeviceBindings()
在执行推理时,例如:
context->enqueueV3(stream);
或旧接口:
context->enqueueV2(buffers.getDeviceBindings().data());
这里的参数 bindings 是一个 void** 数组,每个元素是一个 GPU tensor 的地址。
BufferManager 已经帮我们构好了这个数组,所以直接传:
buffers.getDeviceBindings().data()
即可。
否则你得手写一堆 cudaMalloc 和 cudaMemcpy。
readPGMFile
inline void readPGMFile(const std::string& fileName, uint8_t* buffer, int32_t inH, int32_t inW)
{std::ifstream infile(fileName, std::ifstream::binary);SAFE_ASSERT(infile.is_open() && "Attempting to read from a file that is not open.");std::string magic, w, h, max;infile >> magic >> w >> h >> max;infile.seekg(1, infile.cur);infile.read(reinterpret_cast<char*>(buffer), inH * inW);
}
PGM 文件的结构如下(以 ASCII P5 格式为例):
P5
28 28
255
<像素数据...>
上面 3 行分别是:
-
"P5":文件格式标识,表示灰度图(二进制格式) -
"28 28":图像宽度、高度 -
"255":像素最大值(通常是 255)
这里它通过 >> 运算符把这些头部字段都读掉。
跳过一个字节(通常是换行符):PGM 文件在头部结束后,会有一个换行符(\n)。
这一步把文件读取指针往后移动 1 字节,从而跳过这个换行符,准备读数据。
infile.seekg(1, infile.cur);
读取像素数据:
-
从文件中读取
inH × inW个字节(每个像素 1 字节) -
存到
buffer中
换句话说,这里就是把图像的灰度值(0~255)加载进内存。
比如 MNIST 的图片大小为 28×28,那么这里就会读取 784 个字节。
verifyOutput
bool SampleOnnxMNIST::verifyOutput(const samplesCommon::BufferManager& buffers)
{const int outputSize = mOutputDims.d[1];//buffers.getHostBuffer() 返回模型在 host 端保存输出的内存指针。//output 是长度为 outputSize(10)的 float 数组,存储每个类别的 logits(模型未经过 softmax 的原始分数)。float* output = static_cast<float*>(buffers.getHostBuffer(mParams.outputTensorNames[0]));float val{0.0F};int idx{0};// Calculate Softmaxfloat sum{0.0F};//归一化并找最大概率类别for (int i = 0; i < outputSize; i++){output[i] = exp(output[i]);sum += output[i];}sample::gLogInfo << "Output:" << std::endl;for (int i = 0; i < outputSize; i++){output[i] /= sum;val = std::max(val, output[i]);if (val == output[i]){idx = i;}sample::gLogInfo << " Prob " << i << " " << std::fixed << std::setw(5) << std::setprecision(4) << output[i]<< " "<< "Class " << i << ": " << std::string(int(std::floor(output[i] * 10 + 0.5F)), '*')<< std::endl;}sample::gLogInfo << std::endl;//mNumber 是输入图片的真实标签。//如果模型预测的类别 idx 和真实标签相同,并且最大概率大于 0.9,认为结果正确,返回 true。return idx == mNumber && val > 0.9F;
}
infer
bool SampleOnnxMNIST::infer(){samplesCommon::BufferManager buffers(mEngine);auto context=SampleUniquePtr<nvinfer1::IExecutionContext>(mEngine->createExecutionContext());if(!context){return false;}for(int32_t i=0,e=mEngine->getNbIOTensors();i<e;i++){const auto name=mEngine->getIOTensorName(i);context->setTensorAddress(name,buffers.getDeviceBuffer(name));}// Read the input data into the managed buffersASSERT(mParams.inputTensorNames.size() == 1);if (!processInput(buffers)){return false;}// Memcpy from host input buffers to device input buffersbuffers.copyInputToDevice();bool status = context->executeV2(buffers.getDeviceBindings().data());if (!status){return false;}// Memcpy from device output buffers to host output buffersbuffers.copyOutputToHost();// Verify resultsif (!verifyOutput(buffers)){return false;}return true;
}
initializeSampleParams
samplesCommon::OnnxSampleParams initializeSampleParams(const samplesCommon::Args& args)
{samplesCommon::OnnxSampleParams params;if (args.dataDirs.empty()) // Use default directories if user hasn't provided directory paths{params.dataDirs.push_back("data/mnist/");params.dataDirs.push_back("data/samples/mnist/");}else // Use the data directory provided by the user{params.dataDirs = args.dataDirs;}params.onnxFileName = "mnist.onnx";params.inputTensorNames.push_back("Input3");params.outputTensorNames.push_back("Plus214_Output_0");params.dlaCore = args.useDLACore;params.int8 = args.runInInt8;params.fp16 = args.runInFp16;params.bf16 = args.runInBf16;params.timingCacheFile = args.timingCacheFile;return params;
}
printHelpInfo
//!
//! \brief Prints the help information for running this sample
//!
void printHelpInfo()
{std::cout<< "Usage: ./sample_onnx_mnist [-h or --help] [-d or --datadir=<path to data directory>] [--useDLACore=<int>]"<< "[-t or --timingCacheFile=<path to timing cache file]" << std::endl;std::cout << "--help Display help information" << std::endl;std::cout << "--datadir Specify path to a data directory, overriding the default. This option can be used ""multiple times to add multiple directories. If no data directories are given, the default is to use ""(data/samples/mnist/, data/mnist/)"<< std::endl;std::cout << "--useDLACore=N Specify a DLA engine for layers that support DLA. Value can range from 0 to n-1, ""where n is the number of DLA engines on the platform."<< std::endl;std::cout << "--int8 Run in Int8 mode." << std::endl;std::cout << "--fp16 Run in FP16 mode." << std::endl;std::cout << "--bf16 Run in BF16 mode." << std::endl;std::cout << "--timingCacheFile Specify path to a timing cache file. If it does not already exist, it will be "<< "created." << std::endl;
}
main
int main(int argc, char** argv) {samplesCommon::Args args;bool argsOK = samplesCommon::parseArgs(args, argc, argv);if (!argsOK){sample::gLogError << "Invalid arguments" << std::endl;printHelpInfo();return EXIT_FAILURE;}if (args.help){printHelpInfo();return EXIT_SUCCESS;}auto sampleTest = sample::gLogger.defineTest(gSampleName, argc, argv);sample::gLogger.reportTestStart(sampleTest);SampleOnnxMNIST sample(initializeSampleParams(args));sample::gLogInfo << "Building and running a GPU inference engine for Onnx MNIST" << std::endl;if (!sample.build()){return sample::gLogger.reportFail(sampleTest);}if (!sample.infer()){return sample::gLogger.reportFail(sampleTest);}return sample::gLogger.reportPass(sampleTest);
}
1. 解析命令行参数
-
samplesCommon::Args是 TensorRT 示例通用参数结构,包含输入 ONNX 文件路径、batch 大小、工作目录等信息。 -
parseArgs()从命令行读取这些参数 -
如果参数错误或用户加了
--help,就打印帮助信息并退出。
2. 初始化日志系统
-
TensorRT 官方示例自带一个日志框架(
sample::gLogger)。 -
defineTest会注册一次测试信息(例如测试名、参数等)。 -
reportTestStart表示“测试开始”。 -
后面还会调用
reportPass()或reportFail()来输出最终状态。
3. 创建推理样例对象
-
SampleOnnxMNIST是官方定义的类,封装了:-
构建 TensorRT Engine (
build()), -
执行推理 (
infer()), -
验证输出 (
verifyOutput()), -
加载输入图像 (
processInput()).
-
-
initializeSampleParams(args)会把命令行参数打包成结构体传入,比如输入文件路径、输入尺寸、输出 tensor 名称等。
4. 打印信息并构建引擎
5. 执行推理
6. 报告结果
samplesCommon::Args
struct Args
{bool runInInt8{false};bool runInFp16{false};bool runInBf16{false};bool help{false};int32_t useDLACore{-1};int32_t batch{1};std::vector<std::string> dataDirs;std::string saveEngine;std::string loadEngine;bool rowOrder{true};std::string timingCacheFile;
};
| 字段 | 含义 |
|---|---|
std::vector<std::string> dataDirs | 输入图片所在目录,例如 --datadir=./data/mnist/. |
std::string saveEngine | 保存 .engine 文件路径,例如 --saveEngine=mnist_fp16.engine. |
std::string loadEngine | 加载已有的 .engine 文件(跳过重新构建),例如 --loadEngine=mnist_fp16.engine. |
std::string timingCacheFile | 指定 TensorRT timing cache 文件路径,用于保存 layer 优化的性能记录,加快之后的 engine 构建。 |
运行方式
官方给的运行方式如下
./sample_onnx_mnist [-h or --help] [-d or --datadir=<path to data directory>] [--useDLACore=<int>] [--int8 or --fp16]
但是问题是我们需要在datadir文件夹下准备好onnx模型和输入图片
由于官方的输入都是一张图片,所以我们导出模型的时候切记输入形状是[1,1,28,28],而且输入输出的名字官方也定死了,所以要么修改一下上面的案例,要么只能自行导出
导出onnx模型
下面给出最简单的代码示例,注意我这个导出的模型并没有训练,所以如果你直接使用,是达不到要求的90%正确率的。
import torch
import torch.nn as nn
import torch.nn.functional as Fclass SimpleMNIST(nn.Module):def __init__(self):super().__init__()self.conv1 = nn.Conv2d(1, 16, 3, 1)self.conv2 = nn.Conv2d(16, 32, 3, 1)self.fc1 = nn.Linear(32 * 12 * 12, 64) self.fc2 = nn.Linear(64, 10)def forward(self, x):x = F.relu(self.conv1(x))x = F.max_pool2d(F.relu(self.conv2(x)), 2)x = torch.flatten(x, 1)x = F.relu(self.fc1(x))x = self.fc2(x)return x
model = SimpleMNIST()
dummy_input = torch.randn(1, 1, 28, 28)
torch.onnx.export(model, dummy_input, "mnist.onnx",input_names=["Input3"], output_names=["Plus214_Output_0"])
准备输入的图片
示例里设定好了输入的图片必须是pgm形式的MNIST数据集的图片,而且图片名字必须是数字.pgm,比如1.pgm
可以按照如下方式生成
from torchvision import datasets
from PIL import Imagemnist = datasets.MNIST('.', download=True)
for i in range(10):img, label = mnist[i]img.save(f"{label}.pgm")
注意mnist前10张图片的label有重复,所以可能刚好要读取的不是这个label的图片
然后把模型和图片的目录地址通过参数传进去即可
