当前位置: 首页 > news >正文

补充内容:YOLOv5损失函数解析+代码阅读

  上一篇因为篇幅和时间原因把YOLOv5的损失函数部分比较潦草地略过了,思来想去还是觉得既然开始做了就应该送佛送到西,本着这个原则这篇文章来专门的探讨YOLOv5的loss部分。

问题1:loss函数在哪实现?具体在训练过程中又是怎么实现的呢?

  首先,YOLOv5的损失函数实现在utils/yolo.py文件中。在具体的训练过程中train.py中import compute loss实例来计算损失。

损失函数主要实现文件loss.py

train.py中引入损失函数计算

ComputeLoss实例化1:初始化损失计算类

ComputeLoss实例化2:前向传播

问题二:loss函数具体是什么样的?有哪些组成部分?

1.目标框回归

在 YOLOv5 中,会预先利用k-means算法进行一些锚框,预先定义好的一系列宽高比例的框,它们代表了数据集中常见目标的形状,回归框这一部分是用于修正目标框和真实框的差距。在YOLOv5的回归框中有个重要的创新:即上图所说的sigmiod输出范围在(-0.5,1.5)之间。

具体代码实现:

 # 回归损失pxy = ps[:, :2].sigmoid() * 2. - 0.5  # 预测的中心点pwh = (ps[:, 2:4].sigmoid() * 2) ** 2 * anchors[i]  # 预测的宽高pbox = torch.cat((pxy, pwh), 1)  # 组合为预测框iou = bbox_iou(pbox.T, tbox[i], x1y1x2y2=False, CIoU=True)  # 计算预测框与目标框的IoUlbox += (1.0 - iou).mean()  # 累加IoU损失

这段代码实现的是一种用于目标检测模型中的回归损失计算,主要包括以下几个步骤:

  1. 预测中心点 pxy

    • ps[:, :2].sigmoid():对预测的中心坐标进行sigmoid变换,使值在0到1之间。
    • 乘以2后减0.5,调整预测范围,得到更灵活的中心点预测。
  2. 预测宽高 pwh

    • ps[:, 2:4].sigmoid():对宽高预测值进行sigmoid变换。
    • 乘以2,扩大预测范围。
    • 转为平方** 2,以确保宽高正值且更平滑。
    • 乘以预设的锚点 anchors[i],缩放到实际尺度。
  3. 构建预测框 pbox

    • 将预测的中心点和宽高拼接在一起,形成完整的预测边界框。
  4. 计算预测框与真实目标框之间的IoU:

    • 使用 bbox_iou 函数,计算预测边界框和目标框 tbox[i] 之间的IoU。
    • 参数 x1y1x2y2=False 表示中心点宽高表示法。
    • CIoU=True,使用提升的IoU指标(Complete IoU)。
  5. 计算IoU损失:

    • (1.0 - iou).mean():计算IoU的补集作为损失(IoU越大,损失越小)。
    • 累加到 lbox,累计整个批次的边框回归损失。

这段代码的核心思想是通过预测框与真实框的IoU来衡量边界框回归的准确性,IoU值越大,预测越准确,损失越小,是目标检测模型训练中的常用方式

2.正负样本匹配

目标置信度代码实现:

 # 目标置信度损失score_iou = iou.detach().clamp(0).type(tobj.dtype)  # 处理IoU得分if self.sort_obj_iou:  # 如果需要排序sort_id = torch.argsort(score_iou)b, a, gj, gi, score_iou = b[sort_id], a[sort_id], gj[sort_id], gi[sort_id], score_iou[sort_id]tobj[b, a, gj, gi] = (1.0 - self.gr) + self.gr * score_iou  # 计算目标置信度
 # 分类损失if self.nc > 1:  # 如果有多个类别t = torch.full_like(ps[:, 5:], self.cn, device=device)  # 初始化目标t[range(n), tcls[i]] = self.cp  # 设置正样本lcls += self.BCEcls(ps[:, 5:], t)  # 计算分类的BCE损失
          # 计算目标置信度损失obji = self.BCEobj(pi[..., 4], tobj)lobj += obji * self.balance[i]  # 累加对象损失if self.autobalance:  # 自动平衡损失self.balance[i] = self.balance[i] * 0.9999 + 0.0001 / obji.detach().item()
  1. 预测中心点 pxy

    • ps[:, :2].sigmoid():对预测的中心坐标进行sigmoid变换,使值在0到1之间。
    • 乘以2后减0.5,调整预测范围,得到更灵活的中心点预测。
  2. 预测宽高 pwh

    • ps[:, 2:4].sigmoid():对宽高预测值进行sigmoid变换。
    • 乘以2,扩大预测范围。
    • 转为平方** 2,以确保宽高正值且更平滑。
    • 乘以预设的锚点 anchors[i],缩放到实际尺度。
  3. 构建预测框 pbox

    • 将预测的中心点和宽高拼接在一起,形成完整的预测边界框。
  4. 计算预测框与真实目标框之间的IoU:

    • 使用 bbox_iou 函数,计算预测边界框和目标框 tbox[i] 之间的IoU。
    • 参数 x1y1x2y2=False 表示中心点宽高表示法。
    • CIoU=True,使用提升的IoU指标(Complete IoU)。
  5. 计算IoU损失:

    • (1.0 - iou).mean():计算IoU的补集作为损失(IoU越大,损失越小)。
    • 累加到 lbox,累计整个批次的边框回归损失。

这段代码的核心思想是通过预测框与真实框的IoU来衡量边界框回归的准确性,IoU值越大,预测越准确,损失越小,是目标检测模型训练中的常用方式。

3.整体总结

http://www.dtcms.com/a/569092.html

相关文章:

  • 北仑网站建设培训学校游戏开发需要什么学历
  • 高端装备制造提速,紧固件标准化与智能化升级成为行业新焦点
  • 6项提高电机制造质量的电气测试方案
  • 09_FastMCP 2.x 中文文档之FastMCP高级功能服务器组成详解
  • 工业之“眼”的进化:基于MEMS扫描的主动式3D视觉如何驱动柔性制造
  • 基于管理会计的制造企业运营优化虚拟仿真实验
  • 工业制造领域的ODM、OEM、EMS、JDM、CM、OBM都是啥
  • 建设网站要用什么软件.net程序员网站开发工程师
  • day07(11.4)——leetcode面试经典150
  • java源代码、字节码、jvm、jit、aot的关系
  • JVM 垃圾收集器介绍
  • springcloud:理解springsecurity安全架构与认证链路(二)RBAC 权限模型与数据库设计
  • 自适应网站建设电话网站dns错误
  • 上海网站建设上海迈歌玉树营销网站建设哪家好
  • [5-01-01].第03节:JVM启航 - JVM架构
  • 2024CISCN ezjava复现
  • Cursor 项目实战:AI播客策划助手(二)—— 多轮交互打磨播客文案的技术实现与实践
  • JavaScript的Web APIs 入门到实战(day2):事件监听与交互实现,轻松实现网页交互效果(附练习巩固)
  • 网站建设难么深圳网站制作服
  • 使用vue Template version: 1.3.1时, 设置的env无法正常读取
  • HOT100题打卡第28天——位运算
  • EasyOCR的模型放在了哪里
  • 18、【Ubuntu】【远程开发】技术方案分析:私网ip掩码
  • 做购物网站哪个cms好用企业支付的网站开发费如何入帐
  • 怎样将自己做的网站给别人看微信小程序网站建设
  • 【软考】信息系统项目管理师-质量管理论文范文
  • (T24) 跨时钟域SI->Q path的latch选型
  • 学习记录记录记录记录
  • 【JAVA】基础(一)
  • Coze-AI智能体开发平台4-应用