NMS代码详解(数据维度变换解析)
NMS(非极大值抑制)是目标检测中不可或缺的后处理步骤,用于消除重复检测框,提升检测精度。本文将深入解析NMS代码,重点剖析数据维度变换的关键细节。
先看完整代码
def nms(pred, conf_thres, iou_thres): """ 非极大值抑制nms Args: pred: 模型输出特征图 conf_thres: 置信度阈值 iou_thres: iou阈值 Returns: 输出后的结果 """ # 筛选出置信度大于阈值的框 box = pred[pred[..., 4] > conf_thres] # 置信度筛选 cls_conf = box[..., 5:] # 提取每个框的类别置信度 cls = [] # 存储每个框的类别索引 # 为每个框确定类别 for i in range(len(cls_conf)): cls.append(int(np.argmax(cls_conf[i]))) # 找到最大类别置信度的索引 # 记录图像内共出现的不同物体类别 total_cls = list(set(cls)) output_box = [] # 存储最终输出的框 # 针对每个预测类别进行处理 for i in range(len(total_cls)): clss = total_cls[i] # 当前处理的类别 cls_box = [] # 存储当前类别的框 temp = box[:, :6] # 提取框的前六个元素[x, y, w, h, conf, class1_conf] # 筛选出当前类别的框 for j in range(len(cls)): # 记录[x, y, w, h, conf, class1_conf]值 if cls[j] == clss: temp[j][5] = clss # 将类别索引赋值给框 cls_box.append(temp[j][:6]) # 添加到当前类别的框列表中 # cls_box 里面是[x, y, w, h, conf, class_id] cls_box = np.array(cls_box) # 转换为NumPy数组 # 将cls_box按置信度从大到小排序 sort_cls_box = sorted(cls_box, key=lambda x: -x[4]) # 得到置信度最大的预测框 max_conf_box = sort_cls_box[0] output_box.append(max_conf_box) # 添加到输出框列表中 sort_cls_box = np.delete(sort_cls_box, 0, 0) # 删除已处理的框 # 对除max_conf_box外其他的框进行非极大值抑制 while len(sort_cls_box) > 0: # 得到当前最大的框 max_conf_box = output_box[-1] del_index = [] # 用于存储需要删除的框的索引 # 遍历剩余的框 for j in range(len(sort_cls_box)): current_box = sort_cls_box[j] # 当前框 # 计算当前框与最大框的IoU iou = get_iou(max_conf_box, current_box) if iou > iou_thres: # 筛选出与当前最大框IoU大于阈值的框的索引 del_index.append(j) # 删除这些索引 sort_cls_box = np.delete(sort_cls_box, del_index, 0) # 如果还有剩余的框,继续处理 if len(sort_cls_box) > 0: output_box.append(sort_cls_box[0]) # 添加下一个置信度最高的框 sort_cls_box = np.delete(sort_cls_box, 0, 0) # 删除已处理的框 return output_box # 返回最终输出的框列表
开始解析
1. 输入
def nms(pred, conf_thres, iou_thres):
-
pred: 模型的输出特征图,通常是一个二维数组。每个元素包含预测框的坐标、置信度和类别信息。具体格式为
[x, y, w, h, conf, class1_conf, class2_conf, ...]
,其中:(x, y)
是框的中心坐标。(w, h)
是框的宽度和高度。conf
是该框的置信度(通常是“此框内有目标”的概率)。class_conf
是每个类别的置信度。
-
conf_thres: 置信度阈值,用于筛选出较低置信度的预测框。
-
iou_thres: IoU(交并比)阈值,用于判断两个框是否重叠。
2. 置信度筛选
box = pred[pred[..., 4] > conf_thres]
这行代码会筛选出置信度大于 conf_thres
的预测框。pred[..., 4]
取出所有框的置信度。
维度变化:
-
原本
pred.shape = (N, 5+K)
-
过滤后
box.shape = (N', 5+K)
,其中N' ≤ N
后面 K
列是每个类别的置信度。
3. 提取类别置信度并确定每个框的“最佳类别”
cls_conf = box[..., 5:] # 提取每个框的类别置信度
cls = [] # 存储每个框的类别索引 # 为每个框确定类别
for i in range(len(cls_conf)): cls.append(int(np.argmax(cls_conf[i]))) # 找到最大类别置信度的索引
-
cls_conf
去掉前 5 列,留下 每个框对所有类别的得分。 -
通过
np.argmax
找到每个框的最大类别置信度,并将其对应的类别索引存入cls
列表中 -
维度:
cls
长度 =N'
,与box
行数一一对应。
4.分类别做 NMS
4.1统计图像中出现的类别集合并处理这些类别的框
# 记录图像内共出现的不同物体类别
total_cls = list(set(cls))
output_box = [] # 存储最终输出的框 # 针对每个预测类别进行处理
for i in range(len(total_cls)): clss = total_cls[i] # 当前处理的类别 cls_box = [] # 存储当前类别的框 temp = box[:, :6] # 提取框的前六个元素[x, y, w, h, conf, class1_conf]
- total_cls使用集合去重,得到无重复的类别 id,例如
[0, 2, 5]
。 - 对于每个类别,创建一个空列表
cls_box
来存储该类别的预测框。 temp
先复制出前 6 列([x, y, w, h, conf, class1_conf]
),最后一列之后将被改成类别 id。
4.2 抽取该类的所有框并按置信度排序
# 筛选出当前类别的框
for j in range(len(cls)): # 记录[x, y, w, h, conf, class1_conf]值 if cls[j] == clss: temp[j][5] = clss # 将类别索引赋值给框 cls_box.append(temp[j][:6]) # 添加到当前类别的框列表中 # cls_box 里面是[x, y, w, h, conf, class_id]
cls_box = np.array(cls_box) # 转换为NumPy数组
# 将cls_box按置信度从大到小排序
sort_cls_box = sorted(cls_box, key=lambda x: -x[4])
- 遍历
cls
列表,筛选出当前类别的框,并将其添加到cls_box
中。把第 6 列改成类别id
,cls_box
最终变成(M, 6)
:[x, y, w, h, conf, class_id]
。 cls_box
转换为 NumPy 数组后,根据置信度降序排序。
4.3 逐步抑制 IOU 大的框
# 从按置信度排序后的框中取出置信度最高的框
max_conf_box = sort_cls_box[0]
# 将该框添加到输出列表中
output_box.append(max_conf_box)
# 从排序列表中删除该框,以便后续处理
sort_cls_box = np.delete(sort_cls_box, 0, 0) # 当排序后的框列表不为空时,继续进行处理
while len(sort_cls_box) > 0: # 取出当前输出列表中最后添加的框(即当前的最大置信度框) max_conf_box = output_box[-1] # 用于记录需要删除的框的索引 del_index = [] # 遍历剩余的框,计算它们与当前最大框的IoU for j in range(len(sort_cls_box)): current_box = sort_cls_box[j] # 当前框 # 计算当前框与最大框之间的IoU iou = get_iou(max_conf_box, current_box) # 如果IoU大于设定的阈值,认为这两个框重叠过多 if iou > iou_thres: # 将该框的索引添加到待删除列表中 del_index.append(j) # 从排序框列表中删除重叠过多的框 sort_cls_box = np.delete(sort_cls_box, del_index, 0) # 如果还有剩余的框,继续处理 if len(sort_cls_box) > 0: # 取出下一个置信度最高的框并添加到输出列表中 output_box.append(sort_cls_box[0]) # 从排序列表中删除该框,以便后续处理 sort_cls_box = np.delete(sort_cls_box, 0, 0)
核心逻辑:
-
取当前剩余列表的 第一项(最高置信度)放到
output_box
。 -
把它与列表中其他框逐一计算 IoU:
若iou > iou_thres
⇒ 认为“与最大框重叠太多”,删除该框索引。 -
列表删完这些框后,如果还剩东西,就再取新的首项,重复步骤 2。
-
直到该类别框列表清空。
4.4 输出
return output_box
-
列表形式,每个元素是长度 6 的
numpy.ndarray
:[x, y, w, h, conf, class_id]
-
含义:对每个类别做完 NMS 后 剩下的保留框。
-
输出的
xywh
(框的中心坐标和宽高)通常是相对于模型输入的缩放尺寸(例如640x640)进行预测的实际像素值。若想直接送到绘图、评估代码,还需要进行一些转换处理,相关处理的文章可以看YOLOv8 预测结果添加面积过滤以及检测框坐标如何是从针对缩放转换为针对原图大小-CSDN博客