当前位置: 首页 > news >正文

Python实现ONNXRuntime推理YOLOv11模型

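本文用 Python 和 ONNX Runtime 推理 YOLOv11 模型,重点在于 ONNX Runtime 推理之后的后处理部分。Ultralytics 导出的检测模型输出形状为 (1, 4 + 类别数, 8400):前 4 行是相对于 640×640 输入的框中心坐标和宽高(已按步长解码),其余每行对应一个类别的得分,且已经过 sigmoid;YOLOv8/YOLO11 没有单独的目标置信度(objectness)通道,因此不能再对得分做一次 sigmoid。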

1、安装依赖

pip install opencv-python onnxruntime numpy

2、ONNX模型导出(可选)

from ultralytics import YOLO
 
# Load the model to export -- keep only ONE of these lines. Assigning both
# loads the official weights and then immediately throws them away.
# model = YOLO("yolo11n.pt")  # official pretrained model
model = YOLO("best.pt")  # custom-trained model

# Export to ONNX in the default FP32 precision. FP32 is the safe choice for
# CPU inference with ONNX Runtime: a half=True export (GPU only) declares
# float16 inputs, and ONNX Runtime rejects float32 tensors fed to it.
model.export(format="onnx", simplify=True)

3、Python推理代码

import cv2
import numpy as np
import onnxruntime as ort
from math import exp
 
# ---------------------------------------------------------------------------
# Configuration constants
# ---------------------------------------------------------------------------
# Class names, indexed by the model's class id.
CLASSES = ['class1']
class_num = len(CLASSES)

# Fixed seed so every run draws the same random colour for each class.
np.random.seed(1)
COLORS = np.random.uniform(0, 255, size=(len(CLASSES), 3))  # one colour triple per class

# Network input resolution (the detector resizes images to this size).
input_imgH = 640
input_imgW = 640

# One detection head per stride; each head's map size is the input size
# divided by that stride -> [[80, 80], [40, 40], [20, 20]] for 640x640.
strides = [8, 16, 32]
headNum = len(strides)
mapSize = [[input_imgH // stride, input_imgW // stride] for stride in strides]

# Grid-cell centres, filled in by YOLODetector.generate_meshgrid().
meshgrid = []
 
 
class DetectBox:
    """One detection: a class index, its confidence and an axis-aligned box.

    The box is stored as its top-left (``xmin``, ``ymin``) and bottom-right
    (``xmax``, ``ymax``) corners.
    """

    def __init__(self, classId, score, xmin, ymin, xmax, ymax):
        # What was detected and how confident the model is.
        self.classId, self.score = classId, score
        # Top-left corner.
        self.xmin, self.ymin = xmin, ymin
        # Bottom-right corner.
        self.xmax, self.ymax = xmax, ymax
 

 
class YOLODetector:
    """ONNX Runtime wrapper for an Ultralytics YOLO11 detection model.

    Typical use::

        detector = YOLODetector("model.onnx", conf_thresh=0.5, iou_thresh=0.45)
        boxes, scores, class_ids = detector.detect("image.jpg")
    """

    def __init__(self, model_path='./yolov11n.onnx', conf_thresh=0.5, iou_thresh=0.45):
        """Create an ONNX Runtime session for *model_path*.

        Args:
            model_path: path to the exported ``.onnx`` file.
            conf_thresh: minimum class score a box needs to be kept (strict ``>``).
            iou_thresh: IoU above which a lower-scoring box of the same class
                is suppressed during NMS.
        """
        self.model_path = model_path
        self.conf_thresh = conf_thresh
        self.iou_thresh = iou_thresh
        self.ort_session = ort.InferenceSession(self.model_path)
        self.generate_meshgrid()

    @staticmethod
    def sigmoid(x):
        """Return the logistic sigmoid ``1 / (1 + e^-x)`` of a scalar.

        For negative inputs the algebraically equivalent ``e^x / (1 + e^x)``
        is used, so ``math.exp`` never overflows (the naive form raises
        ``OverflowError`` for x below about -709).
        """
        if x >= 0:
            return 1 / (1 + exp(-x))
        z = exp(x)
        return z / (1 + z)

    @staticmethod
    def preprocess_image(img_src, resize_w, resize_h):
        """Resize to ``resize_w x resize_h``, convert BGR->RGB, scale to [0, 1].

        The resize stretches the image (no letterbox padding), which is why
        ``postprocess`` rescales x and y by separate factors. Returns an
        HxWx3 float32 array.
        """
        image = cv2.resize(img_src, (resize_w, resize_h), interpolation=cv2.INTER_LINEAR)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = image.astype(np.float32)
        image /= 255.0
        return image

    def generate_meshgrid(self):
        """Rebuild the module-level ``meshgrid`` list of grid-cell centres.

        For each of the ``headNum`` feature maps in ``mapSize`` it stores the
        centre of every cell, row by row, as consecutive ``j + 0.5, i + 0.5``
        (x, y) values. The list is refilled in place rather than appended to,
        so creating several detectors no longer keeps growing it.

        Note: ``postprocess`` does not use this grid (the exported model
        already outputs decoded boxes); it is retained for compatibility.
        """
        meshgrid[:] = [
            coord
            for index in range(headNum)
            for i in range(mapSize[index][0])
            for j in range(mapSize[index][1])
            for coord in (j + 0.5, i + 0.5)
        ]

    def iou(self, xmin1, ymin1, xmax1, ymax1, xmin2, ymin2, xmax2, ymax2):
        """Return the intersection-over-union of two corner-format boxes.

        Returns 0.0 when the union area is not positive (e.g. two zero-area
        boxes), instead of dividing by zero.
        """
        xmin = max(xmin1, xmin2)
        ymin = max(ymin1, ymin2)
        xmax = min(xmax1, xmax2)
        ymax = min(ymax1, ymax2)

        innerWidth = max(0, xmax - xmin)
        innerHeight = max(0, ymax - ymin)

        innerArea = innerWidth * innerHeight
        area1 = (xmax1 - xmin1) * (ymax1 - ymin1)
        area2 = (xmax2 - xmin2) * (ymax2 - ymin2)
        total = area1 + area2 - innerArea
        if total <= 0:
            return 0.0

        return innerArea / total

    def nms(self, detectResult):
        """Greedy per-class non-maximum suppression.

        Boxes are visited in descending score order (stable for ties). A box
        is kept unless an already-kept box of the same class overlaps it with
        IoU > ``self.iou_thresh``. The input boxes are not modified.

        Returns:
            The kept boxes, highest score first.
        """
        kept = []
        for cand in sorted(detectResult, key=lambda box: box.score, reverse=True):
            suppressed = any(
                k.classId == cand.classId
                and self.iou(k.xmin, k.ymin, k.xmax, k.ymax,
                             cand.xmin, cand.ymin, cand.xmax, cand.ymax) > self.iou_thresh
                for k in kept
            )
            if not suppressed:
                kept.append(cand)
        return kept

    def postprocess(self, out, img_h, img_w):
        """Turn raw ONNX Runtime outputs into NMS-filtered ``DetectBox`` objects.

        Expects the standard Ultralytics YOLOv8/YOLO11 detection export, whose
        first output has shape ``(1, 4 + num_classes, num_anchors)``:

        * rows 0-3: box centre x, centre y, width, height in pixels of the
          ``input_imgW x input_imgH`` network input (already decoded, so no
          grid/stride arithmetic is needed);
        * rows 4+: one score per class, already passed through a sigmoid.

        There is no separate objectness row: row 4 is the score of class 0
        and must not be sigmoided again (that squeezes every score into
        roughly [0.5, 0.73] and makes ``conf_thresh`` meaningless).

        Args:
            out: list returned by ``InferenceSession.run``.
            img_h, img_w: size of the original image; boxes are rescaled to
                it and clipped to its bounds.

        Returns:
            List of ``DetectBox`` (score = best class score, classId = its
            index), highest score first, after per-class NMS.
        """
        output = np.asarray(out[0][0], dtype=np.float32)  # drop batch dim -> (4 + nc, N)
        class_scores = output[4:, :]
        if class_scores.shape[0] == 0 or class_scores.shape[1] == 0:
            return []

        # Best class per anchor, and its score.
        class_ids = class_scores.argmax(axis=0)
        scores = class_scores.max(axis=0)
        keep = scores > self.conf_thresh
        if not keep.any():
            return []

        cx, cy, w, h = output[0:4, keep]
        scale_w = img_w / input_imgW
        scale_h = img_h / input_imgH
        # Centre/size -> corners, rescaled to the original image and clipped to it.
        xmin = np.maximum(0, (cx - w / 2) * scale_w)
        ymin = np.maximum(0, (cy - h / 2) * scale_h)
        xmax = np.minimum(img_w, (cx + w / 2) * scale_w)
        ymax = np.minimum(img_h, (cy + h / 2) * scale_h)

        detections = [
            DetectBox(classId=int(c), score=float(s),
                      xmin=float(x0), ymin=float(y0), xmax=float(x1), ymax=float(y1))
            for c, s, x0, y0, x1, y1 in zip(class_ids[keep], scores[keep], xmin, ymin, xmax, ymax)
        ]
        return self.nms(detections)

    def detect(self, img_path):
        """Run detection on an image path or an already-loaded image array.

        Args:
            img_path: a path string passed to ``cv2.imread``, or an image
                array (presumably BGR HxWx3, as ``cv2.imread`` returns).

        Returns:
            ``(boxes, scores, class_ids)``: boxes are ``[xmin, ymin, xmax, ymax]``
            integer pixel coordinates in the original image.

        Raises:
            FileNotFoundError: if *img_path* is a path OpenCV cannot read.
        """
        if isinstance(img_path, str):
            orig = cv2.imread(img_path)
            if orig is None:  # imread reports failure by returning None
                raise FileNotFoundError(f"Could not read image: {img_path}")
        else:
            orig = img_path

        img_h, img_w = orig.shape[:2]
        image = self.preprocess_image(orig, input_imgW, input_imgH)

        # Use the model's own input name and dtype instead of assuming
        # 'images' / float32, so FP16 (half=True) exports also work.
        model_input = self.ort_session.get_inputs()[0]
        input_dtype = np.float16 if model_input.type == 'tensor(float16)' else np.float32
        # HWC -> NCHW, contiguous.
        blob = np.ascontiguousarray(image.transpose((2, 0, 1))[np.newaxis], dtype=input_dtype)

        pred_results = self.ort_session.run(None, {model_input.name: blob})

        # Print the model output shape (debugging aid).
        print(f"Model output shape: {pred_results[0].shape}")

        predbox = self.postprocess(pred_results, img_h, img_w)

        boxes = [[int(b.xmin), int(b.ymin), int(b.xmax), int(b.ymax)] for b in predbox]
        scores = [b.score for b in predbox]
        class_ids = [b.classId for b in predbox]
        return boxes, scores, class_ids

    def draw_detections(self, image, boxes, scores, class_ids, mask_alpha=0.3):
        """Draw translucent box fills, box outlines and "label NN%" captions.

        Parameters:
        - image: input image (not modified; a new image is returned).
        - boxes: [xmin, ymin, xmax, ymax] integer boxes.
        - scores: confidence score per box.
        - class_ids: class index per box; selects the colour and class name.
        - mask_alpha: opacity of the filled box overlay.
        """
        img_height, img_width = image.shape[:2]
        short_side = min(img_height, img_width)
        font_size = short_side * 0.001
        # int() alone gives 0 for images under 1000 px; keep strokes >= 1.
        text_thickness = max(1, int(short_side * 0.001))

        # Translucent fills: paint solid boxes on a copy and blend it in
        # first, so the outlines and captions drawn afterwards stay opaque.
        mask_img = image.copy()
        for class_id, box in zip(class_ids, boxes):
            x1, y1, x2, y2 = box[0], box[1], box[2], box[3]
            cv2.rectangle(mask_img, (x1, y1), (x2, y2), COLORS[class_id].tolist(), -1)
        det_img = cv2.addWeighted(mask_img, mask_alpha, image, 1 - mask_alpha, 0)

        for class_id, box, score in zip(class_ids, boxes, scores):
            color = COLORS[class_id].tolist()
            x1, y1, x2, y2 = box[0], box[1], box[2], box[3]

            # Bounding box outline.
            cv2.rectangle(det_img, (x1, y1), (x2, y2), color, 2)

            caption = f'{CLASSES[class_id]} {int(score * 100)}%'
            (tw, th), _ = cv2.getTextSize(text=caption, fontFace=cv2.FONT_HERSHEY_SIMPLEX,
                                          fontScale=font_size, thickness=text_thickness)
            th = int(th * 1.2)

            # Caption sits above the box, or just inside it when the box
            # touches the top edge and there is no room above.
            text_y = y1 if y1 - th >= 0 else y1 + th
            cv2.rectangle(det_img, (x1, text_y - th), (x1 + tw, text_y), color, -1)
            cv2.putText(det_img, caption, (x1, text_y), cv2.FONT_HERSHEY_SIMPLEX, font_size,
                        (255, 255, 255), text_thickness, cv2.LINE_AA)

        return det_img
    
if __name__ == "__main__":
    # Demo: detect objects in one image, print the results, then save and
    # display an annotated copy.
    model_path = './models/epoch_100_N.onnx'
    img_path = "/mnt/d/data/001/SD-001-00920/SI-001-01708.png"

    detector = YOLODetector(model_path=model_path, conf_thresh=0.55, iou_thresh=0.7)

    # Read the image once and reuse the array for both detection and
    # drawing. cv2.imread returns None (it does not raise) on a bad path,
    # so check it explicitly.
    image = cv2.imread(img_path)
    if image is None:
        raise SystemExit(f"Could not read image: {img_path}")

    boxes, scores, class_ids = detector.detect(image)
    print("检测到的框:", boxes)
    print("检测到的分数:", scores)
    print("检测到的类别:", class_ids)

    result_img = detector.draw_detections(image, boxes, scores, class_ids)
    cv2.imwrite('result.jpg', result_img)
    cv2.imshow('Detection Results', result_img)
    cv2.waitKey(0)
    cv2.destroyAllWindows()

4、yolo11架构图

(原文此处为 YOLO11 架构图,图片未能保留。)

相关文章:

  • AI 如何重塑数据湖的未来
  • git原理与常用命令及其使用
  • 数学建模:MATLAB卷积神经网络
  • 【嵌入式学习】触发器 - ADC - DAC
  • 微信 MMTLS 协议详解(五):加密实现
  • 【嵌入式硬件测试之道连载之第三章:核心处理器的选型与应用】
  • IS-IS原理与配置
  • Nexus L2 L3基本配置
  • 【Java SE】抽象类/方法、模板设计模式
  • 【递归,搜索与回溯算法篇】- 名词解释
  • 从X光片生成合成计算机断层扫描(CT)样成像的策略:一项范围审查|文献速递-医学影像人工智能进展
  • 【C++】sort函数的两种用法
  • 分布式容器技术是什么
  • 解决python配置文件类configparser.ConfigParser,插入、读取数据,自动转为小写的问题
  • AGI成立的条件
  • 算法及数据结构系列 - 回溯算法
  • 嵌入式芯片与系统设计竞赛,值得参加吗?如何选题?需要学什么?怎么准备?
  • QT开发(4)--各种方式实现HelloWorld
  • centos 7 搭建FTP user-list用户列表
  • LeetCode算法题(Go语言实现)_07
  • 伤员回归新援融入,海港逆转海牛重回争冠集团
  • 经济日报刊文:品牌经营不能让情怀唱“独角戏”
  • 工程院院士葛世荣获聘任为江西理工大学校长
  • 《中国人民银行业务领域数据安全管理办法》发布,6月30日起施行
  • 七方面118项任务,2025年知识产权强国建设推进计划印发
  • 陕西澄城樱桃在上海推介,向长三角消费者发出“甜蜜之邀”