YOLOX 检测头以及后处理
前言
详解yolox/models/yolo_head.py
源码地址:Megvii-BaseDetection/YOLOX: YOLOX is a high-performance anchor-free YOLO, exceeding yolov3~v5 with MegEngine, ONNX, TensorRT, ncnn, and OpenVINO supported. Documentation: https://yolox.readthedocs.io/
检测头
针对三个尺度特征使用三个检测头。
在每个检测头中,将预测的结果(reg_output,obj_output,cls_output)在第一维进行拼接,变成[bs,5+num_classes,H,W],例如在20*20这个尺度下,那么结果就是[bs,5+num_classes,20,20]
针对每一个尺度的特征图,单独进行处理。
最后使用列表,保存各尺度的输出。
下面的操作,都是在for循环中,即只针对单尺度的操作。
坐标解码
获取某一个尺度下的输出output进行解码。
(1)网格的生成与缓存
grid = self.grids[k] # 尝试获取缓存的网格
# 如果网格尺寸不匹配当前特征图(或初次生成)
if grid.shape[2:4] != output.shape[2:4]:
yv, xv = meshgrid([torch.arange(hsize), torch.arange(wsize)]) # 生成坐标网格
grid = torch.stack((xv, yv), 2).view(1, 1, hsize, wsize, 2).type(dtype)
self.grids[k] = grid # 缓存网格
(2)输出张量重塑
output = output.view(batch_size, 1, n_ch, hsize, wsize)
output = output.permute(0, 1, 3, 4, 2).reshape(batch_size, hsize * wsize, -1)
grid = grid.view(1, -1, 2)
将 [B,5+C,H,W]
→ [B,H*W,5+C]
,方便逐anchor处理。
(3)预测框解码
output[..., :2] = (output[..., :2] + grid) * stride # 中心点解码
output[..., 2:4] = torch.exp(output[..., 2:4]) * stride # 宽高解码
中心点解码:output[..., :2]
是模型预测的偏移量,加上网格坐标后乘以stride
得到原图位置。
宽高解码:output[..., 2:4]
是log空间预测值,通过exp
和stride
还原为宽高。
(4) 返回
返回该尺度对应网格以及还原会原图尺寸的output
- output形状:
[B,H*W,5+C],只是其中的5表示的边框信息从20*20根据stride=32还原到原图尺寸(640*640)
grid形状为:[1,H*W,2]
- x_shifts.append(grid[:, :, 0]) # 保存当前尺度所有网格的x坐标 [1,H*W,1]
- y_shifts.append(grid[:, :, 1]) # 保存当前尺度所有网格的y坐标 [1,H*W,1]
- expanded_strides.append( torch.zeros(1, grid.shape[1]).fill_(stride_this_level) .type_as(xin[0]) :存储每个网格对应的步长,用于缩放坐标到原图尺寸,例如这里为全为32的[1,20*20]的张量。
损失计算
(1)损失函数配置
(2)进入损失计算
- 将每一个尺度下的output在第一维度拼接,例如第一个尺度20*20,[bs,20*20,5+C],第二个尺度40*40,第三个尺度80*80,拼接在一起的outputs形状为[bs,20*20+40*40+80*80,5+C]
- lables表示一个batch下的GT真实标注信息
(3)提取还原回原尺寸的预测结果
获取该批次下每张图片的目标数量
注意这里的labels形状为[batch_size,max_gt_num,5],这里的第一维是max_gt_num而不是该图片实际的GT数量。
- 处理要求:深度学习的输入/输出必须是统一形状的张量(如
[B, N, 5]
),因为PyTorch/TensorFlow等框架依赖静态计算图,无法直接处理可变长度的数据。- 填充(Padding):若一个batch中不同图像的GT数量不同(如图像1有3个GT,图像2有5个GT),必须通过填充(通常用0)将所有图像的GT数量对齐到相同的最大值(如
max_gt_num=5
),否则无法堆叠成量。
获取所有尺度下的anchor数量(20*20+40*40+80*80)
将所有尺度的网格信息拼接
(4) 遍历一个批次下的每张图片
获取每张图片的目标数量
如果没有目标,那么是由target都为零
如果图片有目标:
①
② SimOTA动态正样本分配
③ 目标构造
分类目标:软标签(用IoU值加权one-hot标签,使高质量预测对分类损失贡献更大)
回归目标:
置信度目标:
L1辅助目标(可选):
(5)加权总损失
get_assignments
这个函数使用 SimOTA 动态标签分配策略,完成以下任务:
- 筛选候选正样本:通过几何约束(中心点距离)初步过滤锚点。
- 计算匹配代价:综合分类损失、回归损失(IoU)和几何关系。
- 动态分配正样本:为每个 GT 框选择代价最低的 Top-K 预测框。
(1)几何约束筛选
通过几何约束(中心点距离)筛选可能与 GT 框匹配的锚点(Anchor),减少候选正样本数量
输入参数:
gt_bboxes_per_image表示这张图片上所有gt的坐标框
(1)计算锚点的实际中心坐标
x_shifts + 0.5
:将网格坐标转换到锚点中心(如网格点0
→ 实际中心0.5
)。- 乘以
expanded_strides
:将坐标映射回输入图像的尺度(如stride=8
时,网格点1
→ 图像坐标(1.5 * 8) = 12
)(2) 定义GT框中心区域半径,并计算GT框的中心区域边界
gt_bboxes_per_image[:, 0:1]和gt_bboxes_per_image[:, 1:2]分别先取GT的中心坐标xy,然后根据半径确定中心区域的范围
(3)判断锚点是否在中心区域
这里是用的还原会输入图像尺寸的网格中心点去和GT的中心点做对比,判断网格中哪些位置是在GT的中心区域的。这样操作可以提前一步筛选掉不符合的预测框。
(4)生成筛选结果
(2) 筛选候选预测框
bboxes_preds_per_image在传入函数之前,就已经提取为每一张图片表示的预测框坐标,cls和obj没有提取,所以这里加了batch_idx索引
(3)计算IOU
boxes_a表示gt,输入形状为[num_gt,4];boxes_b表示pred,输入形状为[num_pred,4];在这里我们假设gt只有一个框,预测有两个框,即boxes_a[1,4],boxes_b[2,4]
① 计算
boxes_a
的角点
boxes_a[:, None :2]
:将boxes_a
的中心坐标扩展[1, 1,4]请只取最后一维的前两个元素,即boxes_a[:, None :2]形状为[1, 1,2];
boxes_a_x2y2同理结果:
boxes_a_x1y1
形状
[1, 1, 2]。
boxes_a_x2y2
形状[1,1, 2]
。② 计算
boxes_b
的角点
boxes_b_x1y1[2, 2],boxes_b_x2y2[2, 2];
③计算交集区域(
tl
和br
)
广播过程:
boxes_a_x1y1[num_a,1,2]-->[num_a,num_b,2]
boxes_b_x1y1[num_b,2]--->[1,num_b,2]---->[num_a,num_b,2]
对boxes_b的第一框:
对boxes_b的第二个框:
结果:tl[num_a,num_b,2];br[num_a,num_b,2]
④检查交际有效性
逐个比较tl与br,
结果:
en
:[[1, 0]]
,形状[num_a,num_b]
⑤计算交集面积
- br-tl:右下角坐标减左上角坐标(得到wh);
torch.prod(br - tl, 2)沿最后一个维度求乘积,
即宽高方向乘上有效掩码,强制无效交集(
en=0
)的面积为0⑥ 最后返回IoU [num_a,num_b]
⑦对数处理
pair_wise_ious
: 两组边界框之间计算的IOU矩阵,值为[0,1]范围内的浮点数(如0.8)。
对数损失的梯度更陡峭(尤其是低IOU时),有助于模型快速聚焦于难例样本。通过非线性变换增强对低质量预测的敏感性。
(4)GT独热编码
gt_classes = labels[batch_idx, :num_gt, 0] # 提取当前图像所有GT框的类别ID
F.one_hot()函数
PyTorch提供的one-hot编码函数,第一个参数是要编码的类别索引,第二个参数self.num_classes指定总类别数
举例:若gt_classes=[1,2,0],num_classes=4
- 输出将是: [[0,1,0,0], [0,0,1,0], [1,0,0,0]]
(5) 类别预测损失
计算预测类别分数
与真实类别标签
之间的二元交叉熵损失,通常用于:
- 筛选正负样本(Anchor与GT的匹配质量)
- 辅助计算最终损失函数
① 混合分数计算
sigmoid()
: 将原始logits转换为概率(0-1之间)
cls_preds_
: 类别预测分数(形状如[num_anchors, num_classes]
)obj_preds_
: 目标存在性预测分数(形状如[num_anchors, 1]
)乘积操作: 将类别置信度与目标存在性置信度相乘(例如:某锚框预测"狗"的概率0.8 × 存在物体概率0.9 = 综合置信度0.72)
sqrt()
: 几何平均的数学等效,平衡两类分数的量纲差异同时考虑"是什么类别"和"是否有物体"的双重不确定性,避免高分预测无物体的锚框。
② 损失计算
- 预测值:通过
unsqueeze(0).repeat()
扩展维度以匹配GT形状- 真实值:通过
unsqueeze(1).repeat()
确保与预测值对齐二元交叉熵:逐类别计算预测概率与GT的差异,GT为one-hot编码(如
[0,1,0]
表示类别1)reduction="none":保留每个锚框-类别对的独立损失值;sum(-1):合并所有类别的损失,得到每个锚框的总分类损失
(6) 构建匹配代价函数matching cost
对geometry_relation取反,原本是0,1,乘上了float(1e6),硬性排除空间位置不合理的匹配
(7) SimOTA算法完成预测框(predictions)和真实框(ground truths)的匹配
① 初始化匹配矩阵
② 动态确定每个gt的匹配数量
pair_wise_ious[num_gt,num_anchors_True]--->n_candidate_k=min(10,num_anchors_True]
topk_ious
:每个gt对应的前K个最高IOU值,形状[num_gt, n_candidate_k]
_
:忽略的索引值(不需要具体位置)
为什么要sum?:IOU值在[0,1]区间,多个高IOU预测框的sum会更大,表示这个gt需要更多匹配
torch.clamp(..., min=1)
确保每个gt至少有1个匹配
用
dynamic_ks[gt_idx]
决定为每个gt选择多少cost最低的预测框③ 初步匹配
每个真实框(ground truth)动态选择最优的预测框
④ 处理多对一冲突
对这些预测框,仅保留与
cost
最小的真实框的匹配
⑤ 生成最终匹配结果
g_mask_inboxes
标记最终被选为正样本的预测框。matched_gt_inds
记录每个正样本匹配的真实框索引。pred_ious_this_matching
计算每个匹配对的IOU值。