当前位置: 首页 > news >正文

cap5:YoloV5分割任务的TensorRT部署指南(python版)

《TensorRT全流程部署指南》专栏文章目录:

  • cap1:TensorRT介绍及CUDA环境安装
  • cap2:1000分类的ResNet的TensorRT部署指南(python版)
  • cap3:自定义数据集训练ResNet的TensorRT部署指南(python版)
  • cap4:YoloV5目标检测任务的TensorRT部署指南(python版)
  • cap5:YoloV5分割任务的TensorRT部署指南(python版)

文章目录

  • 1、获取pth模型数据
  • 2、导出ONNX
    • 2.1 直接导出
    • 2.2 测试
  • 3、TensoRT环境搭建
  • 4、转换TensorRT引擎
    • 4.1 使用trtexec工具完成序列化
    • 4.2 使用python的API进行转换
  • 5、推理
    • 5.1 完整代码

在前几章中,我们详细探讨了如何将分类模型ResNet和目标检测模型YOLOv5部署到TensorRT上,相信大家对TensorRT的模型部署已经有了初步的了解。接下来,本文将带领你进一步深入,探索如何将YOLOv5分割模型成功部署到TensorRT。

分割模型的部署流程与之前的分类和目标检测模型相似,唯一不同的是后处理会相对复杂一些。不过,值得注意的是,TensorRT的推理过程仅涉及模型推理,模型输出后的后处理任务由Python完成,和TensorRT无关。对于有一定经验的你来说,这些步骤将会变得轻松易掌握。接下来,我们将逐步解析这些关键环节,帮助你顺利掌握YOLOv5分割任务在TensorRT上的部署技巧。

1、获取pth模型数据

yolov5项目的官方地址在ultralytics/yolov5,这个项目提供了完整的yolo训练框架。使用这个框架可以训练出自己的检测模型。为了方便展示,本文直接使用官方提供的yolov5s-seg.pt预训练模型,地址在:Segmentation。官方的预训练模型的输入是(1,3,640,640)。

2、导出ONNX

2.1 直接导出

实际上导出onnx的脚本在ultralytics/yolov5/export.py中提供了,通过命令可以快速导出onnx模型:

python export.py --weights=yolov5s-seg.pt --imgsz=[640,640] --include==onnx

通过以上命令就可以导出onnx模型。export.py还可以导出更多类型,并定制一些需求,可以浏览export.py学习更多。

2.2 测试

首先,我们通过netron可视化onnx模型。可以看到该分割模型的输入为[1,3,640,640],有两个输出:output0output1
在这里插入图片描述
为了方便后续进行后处理(虽然和TensorRT无关),需要先弄清楚输出数据的含义。可以看到共计有25200个检测框,每个检测框有117的属性:4个bbox数据,1个bbox置信度,80个类别得分,32个mask分割掩码,共计117个值。

output1其实是原型掩码(protos),产生用于分割模型的原型掩码。

通俗来讲,原型掩码相当于最基础的32组合,32个mask就是决定使用哪些原型掩码并组合起来就得到当前bbox的分割结果。

3、TensoRT环境搭建

参考《cap2:1000分类的ResNet的TensorRT部署指南(python版)》中的3、环境搭建部分。

4、转换TensorRT引擎

这一步主要完成红框内的工作,就是将onnx模型序列化为TensorRT的engine模型。在这个过程中会自动根据设备的特异性进行优化,该过程十分缓慢。但是完成后会保存为本地文件,下次可以直接加载使用,不用再次序列化。
在这里插入图片描述
这个过程可以使用基于python的API完成,或者直接使用trtexec工具。

4.1 使用trtexec工具完成序列化

在第3节中下载了TensorRT包,在bin中有trtexec工具。打开终端进入trtexec工具所在文件夹就可以使用该命令工具了。
在这里插入图片描述

转换命令为:
trtexec --onnx=yolov5s-seg.onnx --saveEngine=yolov5s-seg.engine --fp16
onnx指定onnx模型路径,saveEngine知道转换后engine模保存路径,fp16表示使用FP16,这个可以大大提高推理速度。不加–fp16则默认使用FLOAT32

最后会出现PASSED结果,说明成功了,此时就得到engine模型。

4.2 使用python的API进行转换

使用python的API进行转换具有比较固定的流程,只不过可以根据需求比如动态输入等进行相应的设置。最终结果和trtexec转换没有什么区别。

不管什么方式转换得到的engine都可以被python或者c++拿去部署使用,但是要注意:1)版本对应:即转换使用的TensorRT版本和部署推理使用的版本需要一致;2)不支持跨设备:即在哪种设备上转换的就只能在哪种设备上使用,因为转换时根据设备特异性进行了优化。

import tensorrt as trt
# import pycuda.driver as cuda
# import pycuda.autoinit

# TensorRT 需要一个 日志对象,用于输出 WARNING 级别的日志,帮助调试问题
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)

def build_engine(onnx_file_path, engine_file_path):
    with trt.Builder(TRT_LOGGER) as builder, \
         builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)) as network, \
         trt.OnnxParser(network, TRT_LOGGER) as parser:

        # config 配置 TensorRT 编译选项
        builder_config = builder.create_builder_config()
        builder_config.set_flag(trt.BuilderFlag.FP16)   # 设置FP16加速(如果 GPU 支持),提高计算速度并减少显存占用。注释则默认使用FP32
        builder_config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 30)  # 设置最大工作空间(1GB),用于存储中间 Tensor 和计算资源

        # 读取 ONNX 文件
        with open(onnx_file_path, 'rb') as model:
            if not parser.parse(model.read()):
                for error in range(parser.num_errors):
                    print(parser.get_error(error))
                return None

        # 构建 TensorRT 引擎
        serialized_engine = builder.build_serialized_network(network, builder_config)

        # 开始反序列化为可执行引擎
        runtime = trt.Runtime(TRT_LOGGER)
        engine = runtime.deserialize_cuda_engine(serialized_engine)

        # 序列化并保存引擎
        with open(engine_file_path, 'wb') as f:
            f.write(engine.serialize())
        print(f"Saved TensorRT engine to {engine_file_path}")

if __name__ == "__main__":
    onnx_file = "yolov5s-seg.onnx"
    engine_file = "yolov5s-seg.engine"
    build_engine(onnx_file, engine_file)

通过这种方式,我们也成功得到的engine模型。

5、推理

5.1 完整代码

推理代码也十分简单,先给出全部代码:

# 即将更新

推理代码其实非常简单,就是TensorRT部署的代码+python实现的预处理+python实现的后处理。只不过后处理比较麻烦,占了绝大多数的行数而已。其实只要掌握TensorRT推理的核心代码即可。

相关文章:

  • Arkts和Typescript语法上差别
  • CNAPPgoat:一款针对云环境的安全实践靶场
  • 计算机网络(3)TCP格式/连接
  • 扩散模型中的马尔可夫链设计演进:从DDPM到Stable Diffusion全解析
  • 【原创】在ubuntu中搭建gradle开发环境
  • 网工项目理论1.12 高可用性设计
  • 机舱卫生和空气质量改善
  • GUI编程一:相关概念及重要知识
  • 若依Flowable工作流版本监听器使用方法
  • CPP集群聊天服务器开发实践(七):Github上传项目
  • SpringBoot+Vue+数据可视化的动漫妆造服务平台(程序+论文+讲解+安装+调试+售后等)
  • Office word打开加载比较慢处理方法
  • Vue.js 组件开发:构建可复用的 UI 组件
  • KVM设置端口转发
  • SpringMVC重定向接口,参数暴露在url中解决方案!RedirectAttributes
  • 2025年人工智能十大趋势:AI如何塑造未来?
  • asp.net core mvc 富文本编辑的实现
  • matlab 汽车abs的模糊pid和pid控制仿真
  • 美国股市主要指数介绍(Major U.S. Stock Market Indexes):三大股指(中英双语)
  • ubuntu安装docker 无法拉取问题
  • wordpress admin plugin/广州seo推荐
  • 网站建设公司软文/百度关键词点击排名
  • jsp的网站/郑州企业网站优化排名
  • wordpress排版教程/抖音seo优化怎么做
  • 宿州哪有做网站的/b站推广形式
  • 建立应用网站/市场营销方案怎么写