YOLOv8支持旋转框检测(OBB)任务随记
文章目录
- 前言
- 一、数据集制作
- 二、训练
- 2.1 数据增强的处理
- 2.2 头网络的输出 (scale - S)
- 2.3 损失函数的计算
- 三、推理
- 3.1 推理后处理的角度解析
- 参考资料
前言
u版的yolov8本身是支持obb任务,也就是旋转框检测任务的,这篇文章就是总结出如何快速上手使用yolov8-obb训练,以及解释训练和推理时obb任务比一般hbb任务(垂直框检测)的一些差异。
一、数据集制作
处理的过程中,把1:
所有坐标看为segments
;
也就是说把旋转框的坐标处理成分割任务的坐标。
且把只有4个坐标进行重采样,变成100个点:
代码位置:/ultralytics/ultralytics/data/dataset.py
二、训练
需要注意:
数据集制作为四个角点的形式,然后训练的时候利用预测的角度,把预测的x1y1x2y2转换为xywh,然后去进行iou计算loss,并没有直接计算angle的loss。
角度转换规则为:目标框相对于水平轴的逆时针旋转角度。
model_type = "yolov8s-obb"
print(f"model_type: {model_type}")model = YOLO(f'{model_type}.yaml').load(f'{model_type}.pt') # build from YAML and transfer weights# Train the model
model.train(data=yaml_file, epochs=100, imgsz=640, batch=8, patience=5000)
注:
- 对应模型类型要选yolov8的obb版本
- 这里
yaml_file
需要有垂直框检测的基础,数据集格式按照以上方式制作即可。
2.1 数据增强的处理
ultralytics/ultralytics/data/dataset.py
/data/thomascai/codes/ultralytics/ultralytics/data/augment.py
主要原理是对于segments
进行处理,最后再统一转为中心点、宽高和角度。
如:_mosaic4
ultralytics/ultralytics/data/augment.py
处理完之后,在Format中进行整体转换
ultralytics/ultralytics/data/augment.py
轮廓点 -> (cx, cy), (w, h), angle
ultralytics/ultralytics/utils/ops.py
这就是在训练数据迭代的box标签(归一化后的):如:0.0602, 0.3016, 0.0090, 0.0100, 0.3214 — 表示目标中心点的x坐标、y坐标、宽、高和角度。
2.2 头网络的输出 (scale - S)
ultralytics/ultralytics/nn/modules/head.py
OBB 输出2个元素
- 元素1:主输出 outputs
含义:[torch.Size([8, 65, 128, 128]),torch.Size([8, 65, 64, 64]),torch.Size([8, 65, 32, 32]) ]
-
8:batch size
-
65:每个位置的输出通道数,计算方式是:
65 = num_classes + self.reg_max * 4= 1 + 16 * 4 ← 示例中类别数为1 其中: 4 表示:[x1, y1, x2, y2](左上和右下点),16表示其到anchor点的距离分布
-
H x W:对应特征图大小(即感受野分辨率)
-
- 元素2:输出角度
含义:torch.Size([8, 1, 21504])
- 8:batch size
- 1:通道维度(可能用于统一格式)
- 21504:所有特征图上所有位置的预测结果拉平后的数量
你可以理解为:21504 = 128x128 + 64x64 + 32x32 = 16384 + 4096 + 1024
2.3 损失函数的计算
ultralytics/ultralytics/utils/loss.py
重点是把预测的坐标分布先通过dfl计算还原成
-> x1,y1,x2,y2 然后再通过 anchor和预测的angle 变换成
-> x,y,w,h (中心点xy和宽高)
(然后跟真实框坐obb版本iou)
注意:这里有个trick,就是在计算dfl loss的时候。
真实值通过简单的加减得到左上和右下的坐标,然后再计算与anchor点的距离作为真实值;
然后预测的分布和真实值去做dfl loss计算;
需要注意的是,这里是一个近似解法,因为真实值其实是旋转框,但因为是预测与anchor点的距离,所以这里比较接近,yolov8就这么近似解了。
target_ltrb 的 shape为:torch.Size([8, 21504, 4])
pred_dist 的 shape为:torch.Size([8, 21504, 64])
三、推理
3.1 推理后处理的角度解析
ultralytics/ultralytics/nn/modules/head.py
推理阶段,在detect的head网络中,走推理分支,然后直接得到中心点坐标、宽高+角度,这个就是传入后处理的主要元素(实际还有其他的,但主要用到这个)。
参考资料
- 仓库地址:https://github.com/ultralytics/ultralytics/tree/main/ultralytics
- 数据集制作参考:https://docs.ultralytics.com/zh/datasets/obb/#usage
- 训练和推理:https://docs.ultralytics.com/zh/tasks/obb/#dataset-format
以上,感谢阅读,AI路上分享所见所得,关注不迷路。
∼Onepersongofaster,agroupofpeoplecangofurther∼\sim_{One\ person\ go\ faster,\ a\ group\ of\ people\ can\ go\ further}\sim∼One person go faster, a group of people can go further∼