当前位置: 首页 > news >正文

(7)机器学习小白入门 YOLOv:机器学习模型训练详解

— (1)机器学习小白入门YOLOv :从概念到实践
(2)机器学习小白入门 YOLOv:从模块优化到工程部署
(3)机器学习小白入门 YOLOv: 解锁图片分类新技能
(4)机器学习小白入门YOLOv :图片标注实操手册
(5)机器学习小白入门 YOLOv:数据需求与图像不足应对策略
(6)机器学习小白入门 YOLOv:图片的数据预处理
(7)机器学习小白入门 YOLOv:模型训练详解

一、训练模型的基本概念和原理简介

原理简介

模型通过一个称为反向传播的过程反复进行预测、计算误差和更新参数。在此过程中,模型会调整其内部参数 (weights and biases) 以减少误差。通过多次重复这一循环,模型逐渐提高了准确性。随着时间的推移,它就能学会识别形状、颜色和纹理等复杂模式。
在这里插入图片描述

模型作用:

YOLOv8 是一个用于 目标检测 的深度学习模型。它的任务是在一张图片中识别出图像中存在的对象,并标记它们的位置(通常用边界框表示)以及每个对象的类别标签。

目标函数(Loss):

在训练过程中,我们需要最小化损失函数。YOLOv8 模型的损失主要包括以下几个部分:

  • 分类损失:判断模型输出的物体类别是否正确。
  • 定位损失:衡量边界框位置与真实目标之间的差距。
  • 置信度损失:判断一个预测框是否包含一个目标对象。

优化器的作用:

使用优化器来更新神经网络中的权重参数,以最小化损失函数。常见优化器有 SGDAdamRMSPropAdamW 等。


二、训练模型的步骤详解

以下是基于 Python 和 YOLOv8 训练目标检测模型的整体流程:

步骤 1:导入必要的库

import torch
from ultralytics import YOLO

注意:

  • ultralytics 是官方提供的对YOLOv8的封装。
  • 使用前需通过 pip 安装 ultralytics 包:
    pip install ultralytics
    

步骤 2:加载预训练模型

model = YOLO("yolov8n.pt")  # 加载YOLOv8小型目标检测预训练权重,也可以使用 'yolov8s.pt'、'yolov8m.pt' 等不同大小的模型
  • yolov8n.pt: 小型网络。
  • yolov8m.pt: 中型网络(默认)。
  • yolov8l.pt/yolov8x.pt: 大型网络,检测精度高但推理速度较慢。

步骤 3:定义训练配置

# 定义模型的训练配置参数
config = {"epochs": 100,         # 总共训练多少轮"imgsz": 640,          # 图像缩放大小"batch_size": 16,      # 每次送入网络的数据量(影响训练速度和内存占用)"workers": 8,          # 数据加载线程数,加速数据读取"device": 'cuda',      # 使用GPU进行训练,'cpu'也可以使用CPU进行训练但会更慢"project": "yolov8_training",     # 训练结果保存的根目录名称(可自定义)"name": "train_exp"        # 当前实验的名字,在 project 文件夹中形成子文件夹
}

步骤 4:开始训练模型

results = model.train(data='coco.yaml',           # 数据集配置文件路径,需要按照格式提供数据集信息epochs=config["epochs"],imgsz=config["imgsz"],batch_size=config["batch_size"],workers=config["workers"],device=config["device"],project=config["project"],name=config["name"]
)

data='coco.yaml' 是 Ultralytics 提供的 COCO 数据集配置文件路径,如果你有自定义的数据集,可以自己按照格式创建一个 .yaml 文件。


步骤 5:验证训练结果

# 在训练完成后对模型进行验证
metrics = model.val()
print(metrics)

val() 方法返回的是在验证数据上的检测指标(如 mAP、Recall 等)。


三、优化器的选择和使用(以 AdamW 为例)

YOLOv8 默认使用的是 AdamW 优化器,它是一个改进的 Adam 优化算法,支持 L2 正则化。

示例:

# 如果你需要手动控制训练过程,也可以直接使用 torch.optim.AdamWfrom torch import optim, nn# 假设你已经有了一个网络模型 model
model = Model()  # 自己定义的神经网络结构optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)
criterion = nn.BCELoss()for epoch in range(config["epochs"]):for batch_data, batch_labels in train_loader:outputs = model(batch_data)loss = criterion(outputs, batch_labels)optimizer.zero_grad()loss.backward()optimizer.step()

四、提前停止策略(Early Stopping)的实现

在模型训练过程中,如果验证集上的损失连续多轮没有下降,则可以考虑提前结束训练。

在这里插入图片描述

from sklearn.model_selection import train_test_split
import torch
from ultralytics import YOLO, EarlyStoppingmodel = YOLO("yolov8n.pt")early_stopper = EarlyStopping(patience=10)  # 设置为10轮没有改进则提前停止for epoch in range(100):model.train()train_loss = model.train()  # 训练val_loss = model.val()      # 验证if early_stopper(val_loss):print("Early stopping triggered.")break

五、如何确定训练纪元数(Epochs)

  • 如果是目标检测任务,建议设置为 100 ~ 500
  • 初始可设为 30,然后观察验证损失是否在下降,再适当增加。
  • 每次增加 20 或 50 轮左右即可。

注意:如果模型过拟合,可以通过早停、数据增强(如 CutMix)等方法缓解。


六、训练过程的监控与记录

输出日志:

YOLOv8 的训练会实时输出以下内容:

  • 每个 epoch 中损失的变化。
  • mAP(mean Average Precision),衡量检测准确率的一个指标。
  • 训练用时和每个 epoch 速度。

可视化工具:

使用 TensorBoard 来查看模型的训练过程,可以通过以下方式启用它:

from torch.utils.tensorboard import SummaryWriterwriter = SummaryWriter('runs/train')for i, data in enumerate(train_loader):outputs = model(data)writer.add_scalar("Loss", loss.item(), i)  # 记录损失值

然后使用命令启动 TensorBoard:

tensorboard --logdir runs/train

📌 小结

模块内容
模型选择使用 YOLOv8 的预训练模型 yolov8n.pt
优化器AdamW 是目前最常用的选择,支持权重衰减(防止过拟合)
提前停止若验证损失连续10轮无改进,则终止训练以节省资源
训练纪元数初始设置为50~100,再根据结果增加即可
数据处理需要自定义 data.yaml 文件来描述训练和测试集的路径、类别等信息

http://www.dtcms.com/a/274726.html

相关文章:

  • 「GRPO训练参数详解:理解Batch构成与生成数量的关系」
  • 如何使用数字化动态水印对教育视频进行加密?
  • 学习日记-spring-day46-7.11
  • 【Linux-云原生-笔记】系统引导修复(grub、bios、内核、系统初始化等)
  • USB数据丢包真相:为什么log打印会导致高频USB数据丢包?
  • 数据库系统的基础知识(三)
  • Logback.xml配置详解与实战指南
  • 目标检测中的NMS算法详解
  • Java基础-String常用的方法
  • 关于MySql索引,你需要知道!!!
  • CompletableFuture 详解
  • Java教程:JavaWeb ---MySQL高级
  • Flutter 箭头语法
  • 【世纪龙科技】新能源汽车结构原理教学软件-几何G6
  • OpenCV多种图像哈希算法的实现比较
  • 中国国际会议会展中心模块化解决方案的技术经济分析报告
  • C++中的智能指针(1):unique_ptr
  • 在Python项目中统一处理日志
  • javaweb之相关jar包和前端包下载。
  • AGX Xavier 搭建360环视教程【一、先确认方案】
  • Kafka——应该选择哪种Kafka?
  • 三种方法批量填充订单表中的空白单元格--python,excel vba,excel
  • 【深度学习新浪潮】图像生成有哪些最新进展?
  • linux-base-end
  • 从《哪吒 2》看个人IP的破局之道|创客匠人
  • NodeJs后端常用三方库汇总
  • css——width: fit-content 宽度、自适应
  • lesson10:Python的元组
  • UI前端与数字孪生结合实践探索:智慧农业的精准灌溉系统
  • FastAPI + SQLAlchemy (异步版)连接数据库时,对数据进行加密