Multi-Object Tracking with YOLOv5-7.0 and ByteTrack: A Python Demo Walkthrough
This article shows how to build a complete vehicle/pedestrian tracking demo on top of YOLOv5 release 7.0 and the open-source ByteTrack multi-object tracking algorithm.
We walk through the project layout and the core code, then run the demo step by step to detect and track objects in a video.
Directory Structure
yolov5-7.0/
├── ByteTrack/                # core ByteTrack algorithm implementation
│   ├── basetrack.py          # base track class
│   ├── byte_tracker.py       # main ByteTrack tracking logic
│   ├── kalman_filter.py      # Kalman filter
│   ├── matching.py           # association algorithms (Hungarian, etc.)
│   ├── timer.py              # timing utility
│   ├── visualize.py          # tracking result visualization
│   └── __init__.py           # marks the directory as a Python package
├── example/
│   └── bytetrack.py          # the demo script covered in this article: YOLOv5 + ByteTrack
├── models/
│   └── experimental.py       # YOLOv5 model loading and experimental features
├── utils/
│   ├── augmentations.py      # data augmentation functions (letterbox, etc.)
│   ├── general.py            # common YOLOv5 utilities (NMS, etc.)
│   └── ...
├── weights/
│   └── best.pt               # trained YOLOv5 model weights
├── videos/
│   └── palace.mp4            # input test video
└── output/
    └── result.mp4            # output video with tracking results (generated at runtime)
All files under the ByteTrack/ directory are copied verbatim from the official ByteTrack repository.
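One caveat about copying these files: in the official repository they live inside the yolox package (under yolox/tracker/), so a couple of internal imports reference that package and will fail once the files stand alone. A minimal sketch of the adjustment, assuming imports roughly as they appear in the official byte_tracker.py (the exact lines depend on the commit you copied from):

# ByteTrack/byte_tracker.py -- imports approximately as in the official repo:
#   from yolox.tracker import matching
#   from .kalman_filter import KalmanFilter
#   from .basetrack import BaseTrack, TrackState
# After copying into this standalone ByteTrack/ package, point them at the package itself:
from . import matching
from .kalman_filter import KalmanFilter
from .basetrack import BaseTrack, TrackState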
Code Walkthrough and Complete Example
Below is the Python demo script that glues YOLOv5 detection to ByteTrack tracking. It lives at example/bytetrack.py.
import os
import sys

import cv2
import torch
import torchvision

sys.path.append('/home/only/company/yolov5-7.0')  # path containing models/ and utils/
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'ByteTrack'))

from models.experimental import attempt_load
from utils.augmentations import letterbox

from ByteTrack.timer import Timer
from ByteTrack.visualize import plot_tracking
from byte_tracker import BYTETracker

MODEL_PATH = "weights/best.pt"
INPUT_VIDEO = "videos/palace.mp4"
OUTPUT_VIDEO = "output/result.mp4"
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

CONF_THRESHOLD = 0.25
NMS_THRESHOLD = 0.45


class Args:
    """Tracking hyperparameters expected by BYTETracker."""

    def __init__(self):
        self.track_thresh = 0.5         # detection score threshold for first-stage association
        self.track_buffer = 30          # frames to keep lost tracks alive
        self.match_thresh = 0.8         # IoU matching threshold
        self.mot20 = False
        self.aspect_ratio_thresh = 1.6  # filter out overly vertical boxes
        self.min_box_area = 10          # filter out tiny boxes


def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None):
    """Map boxes from letterboxed image coordinates back to the original image.

    Note: unused below -- BYTETracker.update() rescales boxes internally.
    """
    if ratio_pad is None:
        # compute scale ratio and padding
        gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])
        pad = ((img1_shape[1] - img0_shape[1] * gain) / 2,
               (img1_shape[0] - img0_shape[0] * gain) / 2)  # x, y padding
    else:
        gain = ratio_pad[0]
        pad = ratio_pad[1]
    coords[:, [0, 2]] -= pad[0]  # x padding
    coords[:, [1, 3]] -= pad[1]  # y padding
    coords[:, :4] /= gain
    coords[:, :4] = coords[:, :4].clamp(min=0)
    return coords


def postprocess(prediction, num_classes, conf_thre=0.7, nms_thre=0.45):
    """Convert raw YOLOv5 output (cx, cy, w, h) to (x1, y1, x2, y2) and run class-aware NMS."""
    box_corner = prediction.new(prediction.shape)
    box_corner[:, :, 0] = prediction[:, :, 0] - prediction[:, :, 2] / 2
    box_corner[:, :, 1] = prediction[:, :, 1] - prediction[:, :, 3] / 2
    box_corner[:, :, 2] = prediction[:, :, 0] + prediction[:, :, 2] / 2
    box_corner[:, :, 3] = prediction[:, :, 1] + prediction[:, :, 3] / 2
    prediction[:, :, :4] = box_corner[:, :, :4]

    output = [None for _ in range(len(prediction))]
    for i, image_pred in enumerate(prediction):
        # if nothing remains, process the next image
        if not image_pred.size(0):
            continue
        # get the class with the highest confidence
        class_conf, class_pred = torch.max(image_pred[:, 5:5 + num_classes], 1, keepdim=True)
        conf_mask = (image_pred[:, 4] * class_conf.squeeze() >= conf_thre).squeeze()
        # detections: (x1, y1, x2, y2, obj_conf, class_conf, class_pred)
        detections = torch.cat((image_pred[:, :5], class_conf, class_pred.float()), 1)
        detections = detections[conf_mask]
        if not detections.size(0):
            continue

        nms_out_index = torchvision.ops.batched_nms(
            detections[:, :4],
            detections[:, 4] * detections[:, 5],
            detections[:, 6],
            nms_thre,
        )
        detections = detections[nms_out_index]
        if output[i] is None:
            output[i] = detections
        else:
            output[i] = torch.cat((output[i], detections))
    return output


if __name__ == "__main__":
    # load the model
    model = attempt_load(weights=MODEL_PATH, device=DEVICE)
    model.eval()

    cap = cv2.VideoCapture(INPUT_VIDEO)
    # read video properties
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = cap.get(cv2.CAP_PROP_FPS)

    # create the video writer
    os.makedirs(os.path.dirname(OUTPUT_VIDEO), exist_ok=True)
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # or 'XVID'
    out_video = cv2.VideoWriter(OUTPUT_VIDEO, fourcc, fps, (width, height))

    frame_id = 0
    timer = Timer()
    args = Args()
    tracker = BYTETracker(args, frame_rate=30)

    while True:
        ret, frame = cap.read()
        if not ret:
            break
        img_info = {"id": 0}
        img_info["height"] = frame.shape[0]
        img_info["width"] = frame.shape[1]
        img_info["raw_img"] = frame

        # preprocessing: resize with padding, BGR -> RGB, HWC -> CHW, normalize to [0, 1]
        img, ratio, (dw, dh) = letterbox(frame, new_shape=(640, 640))
        img_info["ratio"] = ratio
        img = img[:, :, ::-1].transpose(2, 0, 1).copy()
        img = torch.from_numpy(img).float().div(255.0).unsqueeze(0).to(DEVICE)
        input_h, input_w = img.shape[2], img.shape[3]  # letterboxed input size

        # inference
        timer.tic()  # without this, toc() below would report a meaningless FPS
        with torch.no_grad():
            pred = model(img)[0]
        # num_classes=5 must match the number of classes best.pt was trained with
        detections = postprocess(pred, 5, CONF_THRESHOLD, NMS_THRESHOLD)[0]

        # postprocess() returns None for a frame with no detections
        if detections is not None:
            online_targets = tracker.update(detections,
                                            [img_info['height'], img_info['width']],
                                            (input_h, input_w))
        else:
            online_targets = []

        online_tlwhs = []
        online_ids = []
        online_scores = []
        for t in online_targets:
            tlwh = t.tlwh
            tid = t.track_id
            vertical = tlwh[2] / tlwh[3] > args.aspect_ratio_thresh
            if tlwh[2] * tlwh[3] > args.min_box_area and not vertical:
                online_tlwhs.append(tlwh)
                online_ids.append(tid)
                online_scores.append(t.score)
        timer.toc()

        online_im = plot_tracking(img_info['raw_img'], online_tlwhs, online_ids,
                                  frame_id=frame_id + 1, fps=1. / timer.average_time)
        out_video.write(online_im)
        frame_id += 1

    cap.release()
    out_video.release()
    cv2.destroyAllWindows()
    print("Video saved to:", OUTPUT_VIDEO)
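A detail worth calling out: the script never calls its own scale_coords(), because BYTETracker.update() maps boxes back to original-image coordinates itself using the two size arguments you pass in. Per the official byte_tracker.py, update() accepts either 5-column [x1, y1, x2, y2, score] or 7-column detections; with 7 columns the track score is obj_conf * class_conf. A rough sketch of what the tracker sees each frame, plus an optional, hypothetical refinement (dw and dh are the letterbox paddings from the main loop):

# What tracker.update() receives each frame:
#   detections           -- [N, 7] tensor: x1, y1, x2, y2, obj_conf, class_conf, class_id,
#                           in letterboxed (model-input) coordinates
#   [orig_h, orig_w]     -- original frame size
#   (input_h, input_w)   -- model input size
# Internally it scores and rescales roughly like this:
#   scores = dets[:, 4] * dets[:, 5]              # obj_conf * class_conf
#   scale  = min(input_h / orig_h, input_w / orig_w)
#   bboxes = dets[:, :4] / scale                  # back to original-image scale

# Optional refinement (an assumption, not part of the original demo): YOLOv5's
# letterbox() centers its padding, which the division above does not undo, so for
# pixel-exact boxes you could subtract the padding before calling update():
if detections is not None:
    detections[:, [0, 2]] -= dw
    detections[:, [1, 3]] -= dh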
Usage
- Environment dependencies:
  Make sure the following Python packages are installed (adjust versions as needed):
  pip install torch torchvision opencv-python
- Prepare the model weights
  Place your trained YOLOv5 weights at weights/best.pt (no custom model yet? see the note after this list).
- Prepare a test video
  Place the video you want to process at videos/palace.mp4.
- Run the demo
  From the yolov5-7.0 root directory, so the relative paths in the script resolve:
  python example/bytetrack.py
- Check the result
  Once processing finishes, the output video is saved to output/result.mp4; open it in any player to see the detections annotated with tracking IDs.
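If you do not have a custom best.pt yet, you can sanity-check the whole pipeline with an official COCO-pretrained checkpoint such as yolov5s.pt from the YOLOv5 releases page. This is a sketch under that assumption; the only code change is the class count passed to postprocess(), since the COCO models predict 80 classes:

MODEL_PATH = "weights/yolov5s.pt"  # assumed: official checkpoint downloaded into weights/

# ... and in the main loop, match num_classes to the checkpoint:
detections = postprocess(pred, 80, CONF_THRESHOLD, NMS_THRESHOLD)[0]  # COCO has 80 classes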