当前位置: 首页 > news >正文

YOLOX 检测头以及后处理

前言

详解yolox/models/yolo_head.py

源码地址:Megvii-BaseDetection/YOLOX: YOLOX is a high-performance anchor-free YOLO, exceeding yolov3~v5 with MegEngine, ONNX, TensorRT, ncnn, and OpenVINO supported. Documentation: https://yolox.readthedocs.io/


检测头

针对三个尺度特征使用三个检测头。

每个检测头中,将预测的结果(reg_output,obj_output,cls_output)在第一维进行拼接,变成[bs,5+num_classes,H,W],例如在20*20这个尺度下,那么结果就是[bs,5+num_classes,20,20]

针对每一个尺度的特征图,单独进行处理。

最后使用列表,保存各尺度的输出。

下面的操作,都是在for循环中,即只针对单尺度的操作。 


坐标解码

获取某一个尺度下的输出output进行解码。

(1)网格的生成与缓存

grid = self.grids[k]  # 尝试获取缓存的网格

# 如果网格尺寸不匹配当前特征图(或初次生成)
if grid.shape[2:4] != output.shape[2:4]:
    yv, xv = meshgrid([torch.arange(hsize), torch.arange(wsize)])  # 生成坐标网格
    grid = torch.stack((xv, yv), 2).view(1, 1, hsize, wsize, 2).type(dtype)
    self.grids[k] = grid  # 缓存网格

(2)输出张量重塑

output = output.view(batch_size, 1, n_ch, hsize, wsize)
output = output.permute(0, 1, 3, 4, 2).reshape(batch_size, hsize * wsize, -1)
grid = grid.view(1, -1, 2)

将 [B,5+C,H,W] → [B,H*W,5+C],方便逐anchor处理。

(3)预测框解码

output[..., :2] = (output[..., :2] + grid) * stride  # 中心点解码
output[..., 2:4] = torch.exp(output[..., 2:4]) * stride  # 宽高解码

中心点解码output[..., :2] 是模型预测的偏移量,加上网格坐标后乘以stride得到原图位置。

宽高解码output[..., 2:4] 是log空间预测值,通过expstride还原为宽高。

 x_{center}=(dx+grid_x)\times stride

y_{center}=(dy+grid_y)\times stride

w=e^{dw}\times stride

h=e^{dh}\times stride

(4) 返回

返回该尺度对应网格以及还原会原图尺寸的output

  • output形状:[B,H*W,5+C],只是其中的5表示的边框信息从20*20根据stride=32还原到原图尺寸(640*640)
  • grid形状为:[1,H*W,2]
  • x_shifts.append(grid[:, :, 0])   # 保存当前尺度所有网格的x坐标 [1,H*W,1]
  • y_shifts.append(grid[:, :, 1]) # 保存当前尺度所有网格的y坐标  [1,H*W,1]
  • expanded_strides.append( torch.zeros(1, grid.shape[1]).fill_(stride_this_level) .type_as(xin[0]) :存储每个网格对应的步长,用于缩放坐标到原图尺寸,例如这里为全为32的[1,20*20]的张量。

损失计算

(1)损失函数配置

(2)进入损失计算

  • 将每一个尺度下的output在第一维度拼接,例如第一个尺度20*20,[bs,20*20,5+C],第二个尺度40*40,第三个尺度80*80,拼接在一起的outputs形状为[bs,20*20+40*40+80*80,5+C]
  • lables表示一个batch下的GT真实标注信息

(3)提取还原回原尺寸的预测结果

 获取该批次下每张图片的目标数量

注意这里的labels形状为[batch_size,max_gt_num,5],这里的第一维是max_gt_num而不是该图片实际的GT数量。

  • 处理要求:深度学习的输入/输出必须是统一形状的张量(如 [B, N, 5]),因为PyTorch/TensorFlow等框架依赖静态计算图,无法直接处理可变长度的数据。
  • 填充(Padding):若一个batch中不同图像的GT数量不同(如图像1有3个GT,图像2有5个GT),必须通过填充(通常用0)将所有图像的GT数量对齐到相同的最大值(如max_gt_num=5),否则无法堆叠成量。

获取所有尺度下的anchor数量(20*20+40*40+80*80

 将所有尺度的网格信息拼接

(4) 遍历一个批次下的每张图片

获取每张图片的目标数量

如果没有目标,那么是由target都为零

如果图片有目标:

①  

 ② SimOTA动态正样本分配

③ 目标构造

分类目标:软标签(用IoU值加权one-hot标签,使高质量预测对分类损失贡献更大)

回归目标:

置信度目标:

L1辅助目标(可选):

(5)加权总损失

get_assignments

这个函数使用 SimOTA 动态标签分配策略,完成以下任务:

  1. 筛选候选正样本:通过几何约束(中心点距离)初步过滤锚点。
  2. 计算匹配代价:综合分类损失、回归损失(IoU)和几何关系。
  3. 动态分配正样本:为每个 GT 框选择代价最低的 Top-K 预测框。

(1)几何约束筛选

通过几何约束(中心点距离)筛选可能与 GT 框匹配的锚点(Anchor),减少候选正样本数量

输入参数:

gt_bboxes_per_image表示这张图片上所有gt的坐标框

 (1)计算锚点的实际中心坐标

  • x_shifts + 0.5:将网格坐标转换到锚点中心(如网格点 0 → 实际中心 0.5)。
  • 乘以 expanded_strides:将坐标映射回输入图像的尺度(如 stride=8 时,网格点 1 → 图像坐标 (1.5 * 8) = 12

(2) 定义GT框中心区域半径,并计算GT框的中心区域边界

gt_bboxes_per_image[:, 0:1]和gt_bboxes_per_image[:, 1:2]分别先取GT的中心坐标xy,然后根据半径确定中心区域的范围

(3)判断锚点是否在中心区域

这里是用的还原会输入图像尺寸的网格中心点去和GT的中心点做对比,判断网格中哪些位置是在GT的中心区域的。这样操作可以提前一步筛选掉不符合的预测框。

(4)生成筛选结果

(2) 筛选候选预测框

bboxes_preds_per_image在传入函数之前,就已经提取为每一张图片表示的预测框坐标,cls和obj没有提取,所以这里加了batch_idx索引

 (3)计算IOU 

 

boxes_a表示gt,输入形状为[num_gt,4];boxes_b表示pred,输入形状为[num_pred,4];在这里我们假设gt只有一个框,预测有两个框,即boxes_a[1,4],boxes_b[2,4]

① 计算 boxes_a 的角点

boxes_a[:, None :2]:将 boxes_a 的中心坐标扩展 [1, 1,4]请只取最后一维的前两个元素,即boxes_a[:, None :2]形状为[1, 1,2];boxes_a_x2y2同理

结果

boxes_a_x1y1形状 [1, 1, 2]。

boxes_a_x2y2形状 [1,1, 2]

② 计算 boxes_b 的角点

boxes_b_x1y1[2, 2],boxes_b_x2y2[2, 2];

计算交集区域(tl 和 br

广播过程:

boxes_a_x1y1[num_a,1,2]-->[num_a,num_b,2]

boxes_b_x1y1[num_b,2]--->[1,num_b,2]---->[num_a,num_b,2]

对boxes_b的第一框:

对boxes_b的第二个框:

结果:tl[num_a,num_b,2];br[num_a,num_b,2]

检查交际有效性

逐个比较tl与br,

结果:en[[1, 0]],形状 [num_a,num_b]

计算交集面积

  • br-tl:右下角坐标减左上角坐标(得到wh);
  • torch.prod(br - tl, 2)沿最后一个维度求乘积,即宽高方向

  • 乘上有效掩码,强制无效交集(en=0)的面积为0

最后返回IoU  [num_a,num_b]

对数处理

pair_wise_ious: 两组边界框之间计算的IOU矩阵,值为[0,1]范围内的浮点数(如0.8)。

    对数损失的梯度更陡峭(尤其是低IOU时),有助于模型快速聚焦于难例样本。通过非线性变换增强对低质量预测的敏感性。

    (4)GT独热编码

    gt_classes = labels[batch_idx, :num_gt, 0] #   提取当前图像所有GT框的类别ID

     F.one_hot()函数

    PyTorch提供的one-hot编码函数,第一个参数是要编码的类别索引,第二个参数self.num_classes指定总类别数

    举例:若gt_classes=[1,2,0],num_classes=4

    • 输出将是: [[0,1,0,0], [0,0,1,0], [1,0,0,0]]

    (5) 类别预测损失

    计算预测类别分数真实类别标签之间的二元交叉熵损失,通常用于:

    • 筛选正负样本(Anchor与GT的匹配质量)
    • 辅助计算最终损失函数

    ① 混合分数计算

    sigmoid(): 将原始logits转换为概率(0-1之间)

    • cls_preds_: 类别预测分数(形状如 [num_anchors, num_classes]
    • obj_preds_: 目标存在性预测分数(形状如 [num_anchors, 1]

    乘积操作: 将类别置信度与目标存在性置信度相乘(例如:某锚框预测"狗"的概率0.8 × 存在物体概率0.9 = 综合置信度0.72)

    sqrt(): 几何平均的数学等效,平衡两类分数的量纲差异

    同时考虑"是什么类别"和"是否有物体"的双重不确定性,避免高分预测无物体的锚框。

    ② 损失计算

    • 预测值:通过unsqueeze(0).repeat()扩展维度以匹配GT形状
    • 真实值:通过unsqueeze(1).repeat()确保与预测值对齐

    二元交叉熵:逐类别计算预测概率与GT的差异,GT为one-hot编码(如[0,1,0]表示类别1)

    reduction="none":保留每个锚框-类别对的独立损失值;sum(-1):合并所有类别的损失,得到每个锚框的总分类损失

    (6)  构建匹配代价函数matching cost

    对geometry_relation取反,原本是0,1,乘上了float(1e6),硬性排除空间位置不合理的匹配

    (7) SimOTA算法完成预测框(predictions)和真实框(ground truths)的匹配

     ① 初始化匹配矩阵

    ② 动态确定每个gt的匹配数量

    pair_wise_ious[num_gt,num_anchors_True]--->n_candidate_k=min(10,num_anchors_True]

    • topk_ious:每个gt对应的前K个最高IOU值,形状[num_gt, n_candidate_k]
    • _:忽略的索引值(不需要具体位置)

    为什么要sum?:IOU值在[0,1]区间,多个高IOU预测框的sum会更大,表示这个gt需要更多匹配

    torch.clamp(..., min=1)确保每个gt至少有1个匹配

    dynamic_ks[gt_idx]决定为每个gt选择多少cost最低的预测框

    ③ 初步匹配

    每个真实框(ground truth)动态选择最优的预测框

    ④ 处理多对一冲突

    对这些预测框,仅保留与cost最小的真实框的匹配

    生成最终匹配结果

    • g_mask_inboxes标记最终被选为正样本的预测框。
    • matched_gt_inds记录每个正样本匹配的真实框索引。
    • pred_ious_this_matching计算每个匹配对的IOU值。

    http://www.dtcms.com/a/113231.html

    相关文章:

  • 联网汽车陷入网络安全危机
  • 贪心算法之任务选择问题
  • mmap函数的概念和使用方案
  • 爬楼梯问题-动态规划
  • 3536 矩形总面积
  • leetcode4.寻找两个正序数组中的中位数
  • 类 和 对象 的介绍
  • 2024 .11-2025.3 一些新感悟
  • 【33期获取股票数据API接口】如何用Python、Java等五种主流语言实例演示获取股票行情api接口之沪深A股当天逐笔交易数据及接口API说明文档
  • 【2020】【论文笔记】相变材料与超表面——
  • 使用Cusor 生成 Figma UI 设计稿
  • 数据库并发控制问题
  • 麒麟系统桌面版本v10安装教程
  • 【动手学深度学习】卷积神经网络(CNN)入门
  • 低代码开发平台:飞帆画 echarts 柱状图
  • pygame里live2d的使用方法(live2d-py)
  • 人工智能与计算机技术赋能高中教育数字化教学模式的构建与实践
  • Git 分布式版本控制工具
  • 【ROS2】〇、ROS2的安装
  • 神经网络与深度学习:案例与实践——第三章(2)
  • 3D图像重建中Bundle Adjustment的推导与实现
  • Shell脚本笔记
  • Java第三节:新手如何用idea创建java项目
  • #SVA语法滴水穿石# (004)关于 ended 和 triggered 用法
  • Java HttpURLConnection修仙指南:从萌新到HTTP请求大能的渡劫手册
  • #SVA语法滴水穿石# (005)关于 问号表达式(condition ? expr1 : expr2)
  • Arduino示例代码讲解:ADXL3xx 加速传感器
  • Java 类型转换和泛型原理(JVM 层面)
  • 论定制开发开源 AI 智能名片 S2B2C 商城小程序源码在零售变革中的角色与价值
  • 基于sklearn实现文本摘要思考