当前位置: 首页 > news >正文

开源项目实战学习之YOLO11:ultralytics-cfg-models-nas(十)

👉 点击关注不迷路
👉 点击关注不迷路
👉 点击关注不迷路


文章大纲

    • 1. __init__.py
    • 2. model.py
    • 3. predict.py
    • 4. val.py
    • 5.YOLO-NAS 模型优劣势
      • 5.1 优势
      • 5.2 劣势
    • 6.实际应用案例
      • 6.1 交通领域
      • 6.2 工业领域
      • 6.3 安防领域

  • 在这里插入图片描述
  • 在 YOLO(You Only Look Once)目标检测框架里,models/nas 关联着神经架构搜索(Neural Architecture Search,NAS)相关的模型和代码
    • 神经架构搜索(NAS)概念
      • 传统的深度学习模型架构设计主要依赖专家的经验和大量的实验尝试。
      • NAS 是一种自动化的技术,它能够在给定的搜索空间里自动探寻最优的神经网络架构,从而降低人工设计架构的工作量和难度,并且有可能找到性能更优的模型架构。

1. init.py

  • # 从当前包的 model 模块中导入 NAS 类
    from .model import NAS# 从当前包的 predict 模块中导入 NASPredictor 类
    from .predict import NASPredictor# 从当前包的 val 模块中导入 NASValidator 类
    from .val import NASValidator# 定义 __all__ 变量,用于控制当使用 from package import * 语句时导入的对象
    # 这里指定了三个对象,当使用上述导入语句时,会导入 NASPredictor、NASValidator 和 NAS 这三个对象
    __all__ = "NASPredictor", "NASValidator", "NAS"
    

2. model.py

  • # 从 pathlib 模块导入 Path 类,Path 类可用于处理文件路径和目录路径,提供了跨平台的路径操作方法
    from pathlib import Path# 导入 PyTorch 库,PyTorch 是一个深度学习框架,提供了张量计算、自动求导等功能,用于构建和训练神经网络
    import torch# 从 ultralytics 库的 engine.model 模块导入 Model 类,该类是 ultralytics 框架中模型的基类,自定义的模型可以继承这个类
    from ultralytics.engine.model import Model# 从 ultralytics 库的 utils 模块导入 DEFAULT_CFG_DICT,这通常是一个包含默认配置信息的字典,用于初始化模型或其他操作
    from ultralytics.utils import DEFAULT_CFG_DICT# 从 ultralytics 库的 utils.downloads 模块导入 attempt_download_asset 函数,该函数可用于尝试下载模型资产,如预训练模型权重文件
    from ultralytics.utils.downloads import attempt_download_asset# 从 ultralytics 库的 utils.torch_utils 模块导入 model_info 函数,该函数用于获取模型的相关信息,如模型结构、参数量等
    from ultralytics.utils.torch_utils import model_info# 从当前包的 predict 模块导入 NASPredictor 类,该类可能是用于执行预测任务的预测器类
    from .predict import NASPredictor# 从当前包的 val 模块导入 NASValidator 类,该类可能是用于验证模型性能的验证器类
    from .val import NASValidatorclass NAS(Model):def __init__(self, model: str = "yolo_nas_s.pt") -> None:"""初始化NAS模型,可传入指定的模型文件,默认使用 "yolo_nas_s.pt"。:param model: 模型文件的路径或名称"""# 断言传入的模型文件不能是.yaml或.yml格式,因为YOLO - NAS模型仅支持预训练模型assert Path(model).suffix not in {".yaml", ".yml"}, "YOLO-NAS models only support pre-trained models."# 调用父类Model的构造函数,初始化模型并指定任务类型为目标检测super().__init__(model, task="detect")def _load(self, weights: str, task=None) -> None:"""加载模型权重。:param weights: 模型权重文件的路径或名称:param task: 任务类型(此参数在当前方法中未使用)"""import super_gradients# 获取权重文件的后缀suffix = Path(weights).suffixif suffix == ".pt":# 如果是.pt文件,尝试下载该文件并使用torch.load加载模型self.model = torch.load(attempt_download_asset(weights))elif suffix == "":# 如果没有后缀,使用super_gradients库从预训练模型中获取指定名称的模型,预训练权重基于COCO数据集self.model = super_gradients.training.models.get(weights, pretrained_weights="coco")# 重写模型的forward方法,忽略额外的参数def new_forward(x, *args, **kwargs):"""忽略额外的__call__参数,只传递输入x给模型的原始forward方法。:param x: 输入数据:param args: 额外的位置参数:param kwargs: 额外的关键字参数:return: 模型的输出"""return self.model._original_forward(x)# 保存模型的原始forward方法self.model._original_forward = self.model.forward# 将模型的forward方法替换为新定义的方法self.model.forward = new_forward# 标准化模型属性,以便后续使用# 定义fuse方法,返回模型本身,不进行实际的融合操作self.model.fuse = lambda verbose=True: self.model# 定义模型的步长为32self.model.stride = torch.tensor([32])# 定义模型的类别名称,通过枚举模型的类别名称列表生成字典self.model.names = dict(enumerate(self.model._class_names))# 定义is_fused方法,返回False,表示模型未融合self.model.is_fused = lambda: False# 定义yaml属性为空字典,用于info方法self.model.yaml = {}# 定义pt_path属性为权重文件路径,用于导出模型self.model.pt_path = weights# 定义task属性为目标检测,用于导出模型self.model.task = "detect"# 合并默认配置字典和自定义配置字典,作为模型的参数,用于导出模型self.model.args = {**DEFAULT_CFG_DICT, **self.overrides}# 将模型设置为评估模式self.model.eval()def info(self, detailed: bool = False, verbose: bool = True):"""打印模型的信息。:param detailed: 是否打印详细信息:param verbose: 是否输出详细日志:return: 模型信息"""return model_info(self.model, detailed=detailed, verbose=verbose, imgsz=640)@propertydef task_map(self):"""返回一个字典,将任务类型映射到相应的预测器和验证器类。:return: 任务映射字典"""return {"detect": {"predictor": NASPredictor, "validator": NASValidator}}
    

3. predict.py

  • # 导入 PyTorch 库,它是一个用于深度学习的强大框架,提供了张量计算、自动求导等功能
    import torch# 从 ultralytics 库的 models.yolo.detect.predict 模块导入 DetectionPredictor 类
    # DetectionPredictor 是 ultralytics 中用于目标检测预测的基类
    from ultralytics.models.yolo.detect.predict import DetectionPredictor# 从 ultralytics 库的 utils 模块导入 ops 工具模块
    # 该模块包含了一些常用的操作函数,例如坐标转换等
    from ultralytics.utils import ops# 定义 NASPredictor 类,它继承自 DetectionPredictor 类
    # 这意味着 NASPredictor 类将拥有 DetectionPredictor 类的属性和方法,并且可以对其进行扩展或重写
    class NASPredictor(DetectionPredictor):def postprocess(self, preds_in, img, orig_imgs):"""对模型的预测结果进行后处理。参数:preds_in (list): 模型的原始预测结果。img (torch.Tensor): 经过预处理后的输入图像张量。orig_imgs (list or torch.Tensor): 原始输入图像。返回:经过后处理的预测结果。"""# 使用 ops.xyxy2xywh 函数将预测框的坐标从 [x1, y1, x2, y2](左上角和右下角坐标)# 转换为 [x, y, w, h](中心点坐标和宽高)格式boxes = ops.xyxy2xywh(preds_in[0][0])# 将转换后的边界框坐标与类别分数进行拼接# preds_in[0][1] 是类别分数# 拼接后的张量在最后一个维度上进行拼接,然后对维度进行重新排列,以便后续处理preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1)# 调用父类 DetectionPredictor 的 postprocess 方法,对拼接和调整后的预测结果进行进一步的后处理# 例如非极大值抑制等操作return super().postprocess(preds, img, orig_imgs)
    

4. val.py

  • # 导入 PyTorch 库,用于深度学习中的张量计算、模型构建等操作
    import torch# 从 ultralytics 库的 models.yolo.detect 模块导入 DetectionValidator 类
    # 该类是用于目标检测验证的基类
    from ultralytics.models.yolo.detect import DetectionValidator# 从 ultralytics 库的 utils 模块导入 ops 工具模块
    # 该模块包含了一些常用的操作函数,例如坐标转换、非极大值抑制等
    from ultralytics.utils import ops# 定义 __all__ 变量,指定当使用 from module import * 语句时,要导入的对象
    __all__ = ["NASValidator"]class NASValidator(DetectionValidator):def postprocess(self, preds_in):"""对模型的原始预测结果进行后处理,主要应用非极大值抑制来过滤预测框。参数:preds_in (torch.Tensor): 模型的原始预测结果,通常是一个包含边界框坐标和类别分数的张量返回:torch.Tensor: 经过后处理(如非极大值抑制)后的预测结果"""# 使用 ops.xyxy2xywh 函数将预测框的坐标从 [x1, y1, x2, y2](左上角和右下角坐标)# 转换为 [x, y, w, h](中心点坐标和宽高)格式boxes = ops.xyxy2xywh(preds_in[0][0])# 将转换后的边界框坐标与类别分数进行拼接# preds_in[0][1] 是类别分数# 拼接后的张量在最后一个维度上进行拼接,然后对维度进行重新排列,以便后续处理preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1)# 调用父类 DetectionValidator 的 postprocess 方法,对拼接和调整后的预测结果进行进一步的后处理# 父类的 postprocess 方法通常会执行非极大值抑制等操作,以去除重叠和低置信度的预测框return super().postprocess(preds)
    

5.YOLO-NAS 模型优劣势

  • YOLO - NAS 模型是一种基于深度学习的目标检测模型,与其他目标检测模型相比,具有以下优势和劣势:

5.1 优势

  • 高效的检测速度: YOLO - NAS 采用了神经架构搜索(NAS)技术来优化网络结构,使其能够在各种硬件平台上实现高效的推理速度。这意味着它可以快速处理图像或视频流,适用于实时性要求较高的应用场景,如自动驾驶、视频监控等
  • 高精度: 通过在大规模数据集上进行训练和优化,YOLO - NAS 能够学习到丰富的图像特征,从而在目标检测任务中取得较高的精度。在一些公开的数据集上,YOLO - NAS 的性能表现优于许多传统的目标检测模型。
  • 多尺度检测能力: YOLO - NAS 模型能够处理不同尺度的目标物体,无论是小目标还是大目标,都能有较好的检测效果。它通过在不同的特征层上进行检测,实现了对多尺度目标的自适应感知,提高了检测的全面性和准确性。
  • 端到端的检测: YOLO - NAS 是一个端到端的目标检测模型,从输入图像到输出检测结果,整个过程可以在一个模型中完成,无需像一些传统方法那样进行复杂的步骤组合,如候选区域生成、特征提取和分类等。这种端到端的设计简化了检测流程,提高了模型的实用性和可部署性。

5.2 劣势

  • 对小目标检测的局限性: 尽管 YOLO - NAS 具有多尺度检测能力,但在一些极端情况下,对于非常小的目标,其检测效果可能不如专门针对小目标优化的模型。这是因为小目标在图像中所占像素较少,特征相对不明显,容易被模型忽略或误判
  • 训练成本较高: YOLO - NAS 模型通常具有较大的规模和复杂的网络结构,需要大量的计算资源和时间来进行训练。这不仅要求训练设备具备强大的 GPU 性能,还需要较长的训练时间才能使模型收敛到较好的性能。对于一些资源有限的用户或场景,训练 YOLO - NAS 模型可能会面临一定的困难。
  • 对复杂背景的适应性: 在一些复杂背景的图像中,YOLO - NAS 模型可能会受到干扰,导致检测精度下降。例如,当目标物体与背景的颜色、纹理等特征较为相似时,模型可能难以准确地将目标从背景中分离出来,从而出现误检或漏检的情况。

6.实际应用案例

6.1 交通领域

  • **航拍车辆检测:**随着航拍技术发展,从航拍图像中检测车辆需求迫切。YOLO - NAS 可利用如 UAVDT 或 Dota 等公开航拍图像数据集进行训练,用于检测航拍图像中的车辆,在交通管理、城市规划和环境监测等方面发挥作用,帮助分析交通流量、规划道路建设以及监测城市发展对环境的影响等。
  • 智能交通监控: 部署在道路监控摄像头中的 YOLO - NAS 模型,能够实时监测道路上的车辆、行人、交通标志和信号灯等目标。可以实现车辆流量统计、违章行为识别(如闯红灯、违规变道、超速等),为交通管理部门提供数据支持,以便优化交通信号控制、制定交通管理策略,提高道路通行效率和安全性。

6.2 工业领域

  • 产品质量检测: 在工业生产线上,YOLO - NAS 可用于检测产品的外观缺陷,如电子产品的外壳划痕、裂缝,汽车零部件的表面瑕疵等。通过对生产过程中的产品图像进行实时分析,快速识别出有缺陷的产品,实现自动化的质量检测,提高生产效率和产品质量,降低人工检测的成本和误差。
  • 物流仓储管理: 在物流仓库中,YOLO-NAS 可以帮助识别货物、托盘、货架等目标物体。用于货物的定位与跟踪,实现自动化的库存盘点、货物分拣和上架指引等功能,提高仓储管理的效率和准确性,减少人工操作的错误和时间成本。

6.3 安防领域

  • 视频监控与入侵检测: 在安防监控系统中,YOLO-NAS 模型可以对监控视频中的人员、物体进行实时检测和跟踪。能够及时发现异常行为,如非法闯入、徘徊、物品遗留等,实现智能安防预警,提高安防监控的效率和准确性,减少人力监控的盲区和疲劳带来的漏检问题。
  • 周界防范与行为分析: 用于保护重要区域的周界安全,如机场、监狱、军事基地等。通过对周界监控视频的分析,YOLO-NAS 可以识别是否有人试图翻越围墙、穿越警戒线等行为,并及时发出警报,保障区域的安全。

相关文章:

  • AVInputFormat 再分析
  • 1penl配置
  • 【LeetCode Hot100】二分查找篇
  • 【Go类库分享】mcp-go Go搭建MCP服务
  • 将Airtable导入NocoDB
  • Python functools.partial 函数深度解析与实战应用
  • 【C/C++】Linux的futex锁
  • 音视频开发技术总结报告
  • 小土堆pytorch数据加载概念以及实战
  • StandardCopyOption 还有哪些其他可用的常量?
  • 为什么要做异地监控组网?
  • 洛谷P6136 【模板】普通平衡树(数据加强版)
  • quantization-大模型权重量化简介
  • 【LLaMA-Factory实战】Web UI快速上手:可视化大模型微调全流程
  • Python 学习
  • react18基础速成
  • mysql安装,操作详解,适用于所有版本
  • 神经网络基础-从零开始搭建一个神经网络
  • Python实例题:Python获取房天下数据
  • 【算法基础】快速排序算法 - JAVA
  • 国内外数十支搜救犬队伍齐聚三明,进行废墟搜救等实战
  • 安徽安庆市委书记张祥安调研假日经济和旅游安全工作
  • 一周文化讲座|那些年的年青人
  • 武汉大学新闻与传播学院已由“80后”副院长吴世文主持工作
  • 5月人文社科中文原创好书榜|巫蛊:中国文化的历史暗流
  • 微软上财季净利增长18%:云业务增速环比提高,业绩指引高于预期