当前位置: 首页 > news >正文

NMS代码详解(数据维度变换解析)

NMS(非极大值抑制)是目标检测中不可或缺的后处理步骤,用于消除重复检测框,提升检测精度。本文将深入解析NMS代码,重点剖析数据维度变换的关键细节。

先看完整代码

def nms(pred, conf_thres, iou_thres):  """  非极大值抑制nms  Args:  pred: 模型输出特征图  conf_thres: 置信度阈值  iou_thres: iou阈值  Returns: 输出后的结果  """  # 筛选出置信度大于阈值的框  box = pred[pred[..., 4] > conf_thres]  # 置信度筛选  cls_conf = box[..., 5:]  # 提取每个框的类别置信度  cls = []  # 存储每个框的类别索引  # 为每个框确定类别  for i in range(len(cls_conf)):  cls.append(int(np.argmax(cls_conf[i])))  # 找到最大类别置信度的索引  # 记录图像内共出现的不同物体类别  total_cls = list(set(cls))  output_box = []  # 存储最终输出的框  # 针对每个预测类别进行处理  for i in range(len(total_cls)):  clss = total_cls[i]  # 当前处理的类别  cls_box = []  # 存储当前类别的框  temp = box[:, :6]  # 提取框的前六个元素[x, y, w, h, conf, class1_conf]  # 筛选出当前类别的框  for j in range(len(cls)):  # 记录[x, y, w, h, conf, class1_conf]值  if cls[j] == clss:  temp[j][5] = clss  # 将类别索引赋值给框  cls_box.append(temp[j][:6])  # 添加到当前类别的框列表中  # cls_box 里面是[x, y, w, h, conf, class_id]  cls_box = np.array(cls_box)  # 转换为NumPy数组  # 将cls_box按置信度从大到小排序  sort_cls_box = sorted(cls_box, key=lambda x: -x[4])  # 得到置信度最大的预测框  max_conf_box = sort_cls_box[0]  output_box.append(max_conf_box)  # 添加到输出框列表中  sort_cls_box = np.delete(sort_cls_box, 0, 0)  # 删除已处理的框  # 对除max_conf_box外其他的框进行非极大值抑制  while len(sort_cls_box) > 0:  # 得到当前最大的框  max_conf_box = output_box[-1]  del_index = []  # 用于存储需要删除的框的索引  # 遍历剩余的框  for j in range(len(sort_cls_box)):  current_box = sort_cls_box[j]  # 当前框  # 计算当前框与最大框的IoU  iou = get_iou(max_conf_box, current_box)  if iou > iou_thres:  # 筛选出与当前最大框IoU大于阈值的框的索引  del_index.append(j)  # 删除这些索引  sort_cls_box = np.delete(sort_cls_box, del_index, 0)  # 如果还有剩余的框,继续处理  if len(sort_cls_box) > 0:  output_box.append(sort_cls_box[0])  # 添加下一个置信度最高的框  sort_cls_box = np.delete(sort_cls_box, 0, 0)  # 删除已处理的框  return output_box  # 返回最终输出的框列表

开始解析

1. 输入

def nms(pred, conf_thres, iou_thres):
  • pred: 模型的输出特征图,通常是一个二维数组。每个元素包含预测框的坐标、置信度和类别信息。具体格式为 [x, y, w, h, conf, class1_conf, class2_conf, ...],其中:

    • (x, y) 是框的中心坐标。
    • (w, h) 是框的宽度和高度。
    • conf 是该框的置信度(通常是“此框内有目标”的概率)。
    • class_conf 是每个类别的置信度。
  • conf_thres: 置信度阈值,用于筛选出较低置信度的预测框。

  • iou_thres: IoU(交并比)阈值,用于判断两个框是否重叠。

2. 置信度筛选

box = pred[pred[..., 4] > conf_thres]

这行代码会筛选出置信度大于 conf_thres 的预测框。pred[..., 4] 取出所有框的置信度。

维度变化

  • 原本 pred.shape = (N, 5+K)

  • 过滤后 box.shape = (N', 5+K),其中 N' ≤ N

后面 K 列是每个类别的置信度。

3. 提取类别置信度并确定每个框的“最佳类别”

cls_conf = box[..., 5:]  # 提取每个框的类别置信度  
cls = []  # 存储每个框的类别索引  # 为每个框确定类别  
for i in range(len(cls_conf)):  cls.append(int(np.argmax(cls_conf[i])))  # 找到最大类别置信度的索引  
  • cls_conf 去掉前 5 列,留下 每个框对所有类别的得分

  • 通过 np.argmax 找到每个框的最大类别置信度,并将其对应的类别索引存入 cls 列表中

  • 维度cls 长度 = N',与 box 行数一一对应。

4.分类别做 NMS

4.1统计图像中出现的类别集合并处理这些类别的框

# 记录图像内共出现的不同物体类别  
total_cls = list(set(cls))  
output_box = []  # 存储最终输出的框  # 针对每个预测类别进行处理  
for i in range(len(total_cls)):  clss = total_cls[i]  # 当前处理的类别  cls_box = []  # 存储当前类别的框  temp = box[:, :6]  # 提取框的前六个元素[x, y, w, h, conf, class1_conf]  
  • total_cls使用集合去重,得到无重复的类别 id,例如 [0, 2, 5]
  • 对于每个类别,创建一个空列表 cls_box 来存储该类别的预测框。
  • temp先复制出前 6 列([x, y, w, h, conf, class1_conf]),最后一列之后将被改成类别 id

4.2 抽取该类的所有框并按置信度排序

# 筛选出当前类别的框  
for j in range(len(cls)):  # 记录[x, y, w, h, conf, class1_conf]值  if cls[j] == clss:  temp[j][5] = clss  # 将类别索引赋值给框  cls_box.append(temp[j][:6])  # 添加到当前类别的框列表中  # cls_box 里面是[x, y, w, h, conf, class_id]  
cls_box = np.array(cls_box)  # 转换为NumPy数组  
# 将cls_box按置信度从大到小排序  
sort_cls_box = sorted(cls_box, key=lambda x: -x[4])  
  • 遍历 cls 列表,筛选出当前类别的框,并将其添加到 cls_box 中。把第 6 列改成类别idcls_box 最终变成 (M, 6)[x, y, w, h, conf, class_id]
  • cls_box 转换为 NumPy 数组后,根据置信度降序排序。

4.3 逐步抑制 IOU 大的框

# 从按置信度排序后的框中取出置信度最高的框  
max_conf_box = sort_cls_box[0]  
# 将该框添加到输出列表中  
output_box.append(max_conf_box)  
# 从排序列表中删除该框,以便后续处理  
sort_cls_box = np.delete(sort_cls_box, 0, 0)  # 当排序后的框列表不为空时,继续进行处理  
while len(sort_cls_box) > 0:  # 取出当前输出列表中最后添加的框(即当前的最大置信度框)  max_conf_box = output_box[-1]  # 用于记录需要删除的框的索引  del_index = []  # 遍历剩余的框,计算它们与当前最大框的IoU  for j in range(len(sort_cls_box)):  current_box = sort_cls_box[j]  # 当前框  # 计算当前框与最大框之间的IoU  iou = get_iou(max_conf_box, current_box)  # 如果IoU大于设定的阈值,认为这两个框重叠过多  if iou > iou_thres:  # 将该框的索引添加到待删除列表中  del_index.append(j)  # 从排序框列表中删除重叠过多的框  sort_cls_box = np.delete(sort_cls_box, del_index, 0)  # 如果还有剩余的框,继续处理  if len(sort_cls_box) > 0:  # 取出下一个置信度最高的框并添加到输出列表中  output_box.append(sort_cls_box[0])  # 从排序列表中删除该框,以便后续处理  sort_cls_box = np.delete(sort_cls_box, 0, 0)

核心逻辑:

  1. 取当前剩余列表的 第一项(最高置信度)放到 output_box

  2. 把它与列表中其他框逐一计算 IoU:

    iou > iou_thres ⇒ 认为“与最大框重叠太多”,删除该框索引。
  3. 列表删完这些框后,如果还剩东西,就再取新的首项,重复步骤 2。

  4. 直到该类别框列表清空。

4.4 输出

return output_box
  • 列表形式,每个元素是长度 6 的 numpy.ndarray[x, y, w, h, conf, class_id]

  • 含义:对每个类别做完 NMS 后 剩下的保留框

  • 输出的xywh(框的中心坐标和宽高)通常是相对于模型输入的缩放尺寸(例如640x640)进行预测的实际像素值。若想直接送到绘图、评估代码,还需要进行一些转换处理,相关处理的文章可以看YOLOv8 预测结果添加面积过滤以及检测框坐标如何是从针对缩放转换为针对原图大小-CSDN博客

http://www.dtcms.com/a/283209.html

相关文章:

  • 格密码--Ring-SIS和Ring-LWE
  • 架构解密|一步步打造高可用的 JOCR OCR 识别服务
  • oracle会话控制和存储状态查询
  • pyqt当中splitter.setSizes()不生效
  • C++中vector和list的优缺点对比以及deque
  • PowerJob集群机器数为0问题
  • Python第八章作业(初级)
  • 如何使用VScode使用ssh连接远程服务器不需要输入密码直接登录
  • 27.Hamming 距离
  • transformers基础Data Collator
  • 教程:如何快速查询 A 股实时 K线和5档盘口
  • 今日行情明日机会——20250716
  • Redis深度解析:从缓存到分布式系统的核心引擎
  • 用python实现自动化布尔盲注
  • pytest--1--pytest-mock常用的方法
  • 代码随想录day36dp4
  • 震坤行获取商品SKU操作详解
  • 16路串口光纤通信FPGA项目实现指南
  • Kotlin获取集合中的元素操作
  • Java与Vue精心打造资产设备管理系统,提供源码,适配移动端与后台管理,助力企业高效掌控资产动态,提升管理效能
  • 【Java】JUC并发(synchronized进阶、ReentrantLock可重入锁)
  • 二重循环:输入行数,打印直角三角形和倒直角三角形
  • Java后端开发核心笔记:分层架构、注解与面向对象精髓
  • 基于Android的旅游计划App
  • Web基础 -MYSQL
  • 冷库耗电高的原因,冷链运营者的降本增效的方法
  • LVS四种模式及部署NAT、DR模式集群
  • CD53.【C++ Dev】模拟实现优先级队列(含仿函数)
  • 【计算机网络】数据通讯第二章 - 应用层
  • 深度学习之反向传播