当前位置: 首页 > news >正文

利用 YOLOv5-7.0 和 ByteTrack 实现多目标跟踪 — Python Demo 详解

本文介绍如何基于 YOLOv5-7.0 版本及 ByteTrack 开源多目标跟踪算法,快速搭建一个完整的车辆/行人跟踪演示系统。
我们将讲解项目的文件结构、核心代码功能,以及如何一步步运行 Demo 完成视频中的目标检测和跟踪。


目录结构


yolov5-7.0/
├── ByteTrack/                   # ByteTrack 算法核心实现代码
│   ├── basetrack.py            # 跟踪基类
│   ├── byte_tracker.py         # ByteTrack 主跟踪逻辑实现
│   ├── kalman_filter.py        # 卡尔曼滤波器
│   ├── matching.py             # 匹配算法(匈牙利等)
│   ├── timer.py                # 计时工具
│   ├── visualize.py            # 跟踪结果可视化
│   └── __init__.py             # Python 包标识文件
├── example/
│   └── bytetrack.py            # 本文核心 Demo 脚本:YOLOv5 + ByteTrack 演示
├── models/
│   └── experimental.py         # YOLOv5 模型加载和实验功能
├── utils/
│   ├── augmentations.py        # 数据增强函数(letterbox 等)
│   ├── general.py              # YOLOv5 常用工具函数(NMS 等)
│   └── ...
├── weights/
│   └── best.pt                 # 训练好的 YOLOv5 模型权重
├── videos/
│   └── palace.mp4              # 输入测试视频
└── output/
└── result.mp4              # 追踪结果输出视频(运行时生成)

其中ByteTrack文件夹下的所有文件来自于ByteTrack官方的代码复制过来

代码详解及完整示例

下面是整合 YOLOv5 与 ByteTrack 跟踪的 Python Demo 脚本,文件路径为 example/bytetrack.py

import os
import cv2
import sys
import torch
import torchvision
from pathlib import Pathsys.path.append('/home/only/company/yolov5-7.0')  # 添加含有 models/ 的路径
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'ByteTrack'))from ByteTrack.timer import Timer
from models.experimental import attempt_load
from utils.augmentations import letterbox
from utils.general import non_max_suppression
from byte_tracker import BYTETracker
from ByteTrack.visualize import plot_trackingMODEL_PATH = "weights/best.pt"
INPUT_VIDEO = "videos/palace.mp4"
OUTPUT_VIDEO = "output/result.mp4"
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")CONF_THRESHOLD = 0.25
NMS_THRESHOLD = 0.45class Args:def __init__(self):self.track_thresh = 0.5self.track_buffer = 30self.match_thresh = 0.8self.mot20 = Falseself.aspect_ratio_thresh = 1.6self.min_box_area = 10def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None):"""将预测框从 letterbox 图像坐标映射回原始图像坐标"""if ratio_pad is None:# 计算缩放比例和paddinggain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])pad = ((img1_shape[1] - img0_shape[1] * gain) / 2,(img1_shape[0] - img0_shape[0] * gain) / 2)  # x, y paddingelse:gain = ratio_pad[0]pad = ratio_pad[1]coords[:, [0, 2]] -= pad[0]  # x paddingcoords[:, [1, 3]] -= pad[1]  # y paddingcoords[:, :4] /= gaincoords[:, :4] = coords[:, :4].clamp(min=0)return coordsdef postprocess(prediction, num_classes, conf_thre=0.7, nms_thre=0.45):box_corner = prediction.new(prediction.shape)box_corner[:, :, 0] = prediction[:, :, 0] - prediction[:, :, 2] / 2box_corner[:, :, 1] = prediction[:, :, 1] - prediction[:, :, 3] / 2box_corner[:, :, 2] = prediction[:, :, 0] + prediction[:, :, 2] / 2box_corner[:, :, 3] = prediction[:, :, 1] + prediction[:, :, 3] / 2prediction[:, :, :4] = box_corner[:, :, :4]output = [None for _ in range(len(prediction))]for i, image_pred in enumerate(prediction):# print(image_pred.shape)# If none are remaining => process next imageif not image_pred.size(0):continue# Get score and class with highest confidenceclass_conf, class_pred = torch.max(image_pred[:, 5: 5 + num_classes], 1, keepdim=True)conf_mask = (image_pred[:, 4] * class_conf.squeeze() >= conf_thre).squeeze()detections = torch.cat((image_pred[:, :5], class_conf, class_pred.float()), 1)detections = detections[conf_mask]if not detections.size(0):continuenms_out_index = torchvision.ops.batched_nms(detections[:, :4],detections[:, 4] * detections[:, 5],detections[:, 6],nms_thre,)detections = detections[nms_out_index]if output[i] is None:output[i] = detectionselse:output[i] = torch.cat((output[i], detections))return outputif __name__ == "__main__":# load modelmodel = attempt_load(weights=MODEL_PATH, device=DEVICE)model.eval()cap = cv2.VideoCapture(INPUT_VIDEO)# 获取视频参数width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))fps = cap.get(cv2.CAP_PROP_FPS)# 创建视频写入对象os.makedirs(os.path.dirname(OUTPUT_VIDEO), exist_ok=True)fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # 或 'XVID'out_video = cv2.VideoWriter(OUTPUT_VIDEO, fourcc, fps, (width, height))frame_id = 0timer = Timer()args = Args()tracker = BYTETracker(args, frame_rate=30)while True:ret, frame = cap.read()if not ret:breakimg_info = {"id": 0}img_info["height"] = frame.shape[0]img_info["width"] = frame.shape[1]img_info["raw_img"] = frame# 图像预处理img, ratio, (dw, dh) = letterbox(frame, new_shape=(640, 640))  # resize with paddingimg_info["ratio"] = ratioimg = img[:, :, ::-1].transpose(2, 0, 1).copy()  # BGR to RGB, to 3xHxWimg = torch.from_numpy(img).float().div(255.0).unsqueeze(0).to(DEVICE)  # Normalizeheight, width = img.shape[2], img.shape[3]# 推理with torch.no_grad():pred = model(img)[0]detections = postprocess(pred, 5, CONF_THRESHOLD, NMS_THRESHOLD)[0]online_targets = tracker.update(detections, [img_info['height'], img_info['width']], (height, width))online_tlwhs = []online_ids = []online_scores = []for t in online_targets:tlwh = t.tlwhtid = t.track_idvertical = tlwh[2] / tlwh[3] > args.aspect_ratio_threshif tlwh[2] * tlwh[3] > args.min_box_area and not vertical:online_tlwhs.append(tlwh)online_ids.append(tid)online_scores.append(t.score)timer.toc()online_im = plot_tracking(img_info['raw_img'], online_tlwhs, online_ids, frame_id=frame_id + 1, fps=1. / timer.average_time)out_video.write(online_im)frame_id += 1cap.release()out_video.release()cv2.destroyAllWindows()print("视频保存成功:", OUTPUT_VIDEO)

使用说明

  1. 环境依赖

确保安装如下 Python 包(版本可根据需要调整):

pip install torch torchvision opencv-python
  1. 准备模型权重

将训练好的 YOLOv5 权重放置于 weights/best.pt 路径。

  1. 准备测试视频

将待测试视频放置于 videos/palace.mp4

  1. 运行 Demo
python example/bytetrack.py
  1. 查看结果

处理结束后,结果视频文件将保存在 output/result.mp4,你可以用播放器查看带有跟踪 ID 的目标检测视频。


相关文章:

  • 最优秀的佛山网站建设百度热门关键词排名
  • wordpress主题自媒体一号seo关键词快速提升软件官网
  • 公司合法网站域名怎么注册g3云推广
  • 凡科是免费做网站吗seo的培训班
  • 广州网站制作公司联系方式站长之家统计
  • 做3dmax的网站推广网页
  • 降低90%推理成本:腾讯混元+云函数动态扩缩容策略详解
  • c++面向对象编程
  • 【Java开发日记】详细地讲解一下如何保证线程安全性呢?
  • 鸿蒙原子化服务与元服务:轻量化服务的未来之路
  • 湖北理元理律师事务所:科学债务优化如何守护民生底线
  • 提示工程入门指南:如何有效地与大语言模型交互
  • Python Selenium 忽略证书错误
  • MongoDB入门学习(含JAVA客户端)
  • Postman接口测试入门
  • 数据结构进阶 - 第九章 排序
  • 使用 Python 自动化文件获取:从 FTP 到 API 的全面指南
  • 【Cursor 】Cursor 解析江科大倒立摆PID工程源码《00-PID综合测试程序-V1.1》《03-增量式PID定速控制》(Doxygen注释风格)
  • 同步互斥与通信-有缺陷的同步示例FreeRTOS笔记
  • CVPR-2025 | 缩小仿真与现实差距的具身导航新突破!Vid2Sim:从视频到逼真交互式仿真环境的城市导航
  • 【FAQ】HarmonyOS SDK 闭源开放能力 —Account Kit (6)
  • el-select封装下拉加载组件
  • 【Linux学习笔记】进程通信之消息队列和信号量
  • Oracle数据库捕获造成死锁的SQL语句
  • 采集文章+原创AI处理+发布网站详细教程
  • 开疆智能CCLinkIE转ModbusTCP网关连接PCA3200电能表配置案例