YOLO v5详解(文字版)
序言:最近在整理Yolo v5的流程,我发现网上写的大多很零碎,把推理过程和训练过程混淆。经过我长达三天的整理终于对照代码将Yolo v5详细训练过程,和Yolo v5详细训推理过程总结如下。如果有什么细节问题,根据本篇文章在询问ai。或者需要手写推理代码也可参考推理的后处理来进行参照。希望能帮到需要的人,如果有什么错误,也欢迎指正。
(ps:由于不同版本的Yolo 5也有差别,请大家记得甄别仔细)
Yolo v5网络结构图
Yolo v5详细训练过程
训练过程主要关注特征提取和正样本选择
1.前处理:(较为琐碎,此处只是大框架)
1.1数据增强:
①Mosaic数据增强随机选取4张图像拼接为单张图像(1280×1280)提升小目标检测能力,增加背景复杂度。
②自适应图像缩放:保持原图长宽比缩放至最长边=640
③归一化与通道转换:像素值归一化至 [0,1], 维度转换:H×W×C → C×H×W
1.2自适应锚框计算:k-means聚类生成9个先验框(3个尺度×3种长宽比)
2.模型的前向传播:
2.1Backbone(主干网络提取特征):
①采用Focus 结构 (v6.0 之前): 早期版本在输入后使用 Focus 模块(本质是高效的切片操作 slice切片(2) + concat)。
1将输入图像的空间分辨率(H, W)减少到 1/4(H/2, W/2),同时将通道数增加 4 倍。
2这相当于一个高效的、无参数的“下采样”,保留了更多信息。
②特征提取主干 CSPDarknet53核心包含多个 C3 模块 (本质是CSP Bottleneck )。将特征图分成两部分:一部分直接传递到下一阶段,另一部分经过密集块处理后再与之前传递的部分融合。
1.融合浅层,中层,高层,不同尺寸的特征, 丰富梯度信息。
2.采用C3模块内部有残差过程,只有部分特征经过密集计算,减少了计算量。
3.多处跨阶段直连路径, 增强梯度流减缓梯度消失。
③SPPF(v6.0前SPP):三层池化(5,9,13)然后ConvModule,再concat。多尺度池化融合,适应不同尺寸目标,保留更多细节。串行复用设计减少计算量,提高推理速度。(SPP:是 三个池化层独立计算 → 重复计算严重,运算过慢。)
2.2Neck部分(特征融合)
①采用改进的PANNet:
1.接收了80×80×256(来自第3阶段C3),40×40×512(来自第4阶段C3),20×20×512(SPPF输出)的输出,
2.拥有自上而下,自下而上的过程,融合高层语义特征 和 低层细节特征,提升目标的检测能力。
3.与SPP模块互补,提取多尺度的全局特征,增强模型对不同尺度目标的适应性。
2.3Head(检测头):
①Yolo v5是耦合头(4边界框坐标+1目标置信度+85类别概率),分类回归共享卷积特征,参数量较少,推理更快,但是小目标检测略弱。
②小目标,中目标,大目标.对应输出的数据是后处理中的list={0:{Tensor:1,255,80,80} 1:{Tensor:1,255,40,40} 2:{Tensor:1,2555,20,20}},对应(1, 3, 85, 80, 80)
# [batch, anchor box, (85=4(边界框偏移量 dx,dy,dw,dh)+1 (目标置信度 )+80 num_classes), grid_h网格高,grid_w网格宽]
3.模型的反向传播:
3.1多正样本选择:
为每个真实目标(Ground Truth Box, GT)分配足够多的、高质量的正样本 Anchor
多gird:使用三个不同尺度的特征图(80×80, 40×40,20×20)每个网格点都负责预测。显著提升了多尺度物体检测.
多Anchor:每个网格点又预测 3 个不同长宽比和大小的 Anchor Boxes。现实世界物体的形状多种多样. 提高模型对不同长宽比物体的适应能力。
多正样本ATSS:(一个目标匹配多个网格Anchor)
①初步筛选(基于 Anchor 中心): 对于每个 GT Box:计算 GT Box 的中心点 (gx, gy)。找该中心点落在哪个Grid Cell。
②候选 Anchor 选择:考虑该 Grid Cell (i, j) 以及其相邻的 Grid Cells内的所有 Anchor Template。
③计算这个 GT Box 与步骤 2 中选出的所有候选 Anchor Boxes 的 IoU。选择 IoU超过阈值的Anchor。
④基于GT BOX与Anchor宽高比例的匹配策略(Anchor Template的边*0.25和*4能把GTBox 包起来为正样本,既形状相似度)
3.2构建loss
①边界框回归损失: CIoU Loss。CIoU Loss 比传统的 IoU Loss 或 GIoU Loss 能更全面地衡量框的相似度,收敛更快,定位更精准。
②目标置信度损失:二元交叉熵损失 Focal Loss通过 obj 和 cls 的 BCE 函数的 pos_weight 和 fl_gamma 参数实现)来降低大量负样本对总损失的贡献.
③类别损失:每个类别独立使用 Sigmoid 激活 + BCE Loss ,只对正样本计算此损失。
Yolo v5详细推理过程
推理过程主要关注数据形状和后处理部分。
1.前处理:
1.1数据增强
1.2图像缩放 (Letterbox Resizing)
1.3维度扩展与对齐(BCHW)
2.模型的前向传播:
2.1 Backbone(特征提取[1, 3, 640, 640]):Focus,CSPDarknet53,SPPF/SPP
2.2 Neck (特征融合)改进的PANNet: output ( [1, 256, 80, 80], [1, 512, 40, 40], [1, 1024, 20, 20])
2.3Head (耦合检测头):num_classes=80
1output_p3 = [1, 3, (5+num_classes), 80, 80] -> 变形 为 [1, 3*80*80, (5+num_classes)] = [1, 19200, (5+num_classes)]
2output_p4 = [1, 3, (5+num_classes), 40, 40] -> 变形为 [1, 3*40*40, (5+num_classes)] = [1, 4800, (5+num_classes)]
3output_p5 = [1, 3, (5+num_classes), 20, 20] -> 变形为 [1, 3*20*20, (5+num_classes)] = [1, 1200, (5+num_classes)]
4拼接 (Concatenate):torch.cat([output_p3, output_p4, output_p5], dim=1) -> [1, (19200+4800+1200), (5+num_classes)] = [1, 25200,85]
3.后处理:
A. 解码边界框 (Decode Boxes):
1.Sigmoid 激活: 对 dx, dy, obj 应用 sigmoid 函数,将其约束到 (0, 1) 范围。
2.计算网格中心坐标: 对于每个预测位置 (i, j) (网格坐标),计算其对应的特征图上的中心坐标
3.计算预测框中心: bx = sigmoid(dx) + cx, by = sigmoid(dy) + cy。(sigmoid(dx), sigmoid(dy)) 是相对于网格单元中心的偏移量。
4.计算预测框宽高: bw = anchor_w * exp(dw), bh = anchor_h * exp(dh)。(anchor_w, anchor_h) 是与该预测位置关联的 anchor 模板的宽度和高度。
5.计算绝对坐标 (像素): 将中心坐标 (bx, by) 和宽高 (bw, bh) 转换为图像坐标系下的边界框表示:
B. 应用置信度阈值
1.计算每个预测框的 最终置信度:conf = obj * max(cls_score)。max(cls_score) 是该预测框在所有类别上经过 sigmoid 后的最大类别概率。
2.丢弃所有 conf < conf_thres 的预测框。(大幅减少了候选框数量)
C. 类别概率处理
1.直接使用最大类别概率: 如上所述,在计算最终置信度 conf 时已经用了 max(cls_probs)
2.(可选)多标签模式 (Multi-label): 对每个类别单独设置一个阈值 multi_label_thres (例如 0.25),允许一个框同时属于多个类别
D. 非极大值抑制(NMS)
首先:按类别 class_id 对所有框进行分组(*max_value)然后对于每个类别进行下面过程
1.将该类别的所有框按置信度 conf 从高到低排序。
2.选取置信度最高的框 A 作为保留框。
3.计算框 A 与剩余所有框的 IoU
4.移除所有与框 A 的 IoU 超过设定阈值 iou_thres的框
5.从剩余框中再选取置信度最高的框 B 作为下一个保留框。循环以上操作。
输出:[batch_index, x1, y1, x2, y2, conf, class_id]
E. 尺度还原
1.当前检测框坐标 (x1, y1, x2, y2) 是基于 640x640 Letterbox 图像的。
2.将它们映射回原始未缩放、未填充的图像坐标系。
F. 输出 (Output)
最终输出格式1: list :[x1, y1, x2, y2, confidence, class_id] (绝对坐标)
最终输出格式2: list : [x_center, y_center, width, height, confidence, class_id]