使用python进行船舶轨迹跟踪
一、系统概述
该系统基于 YOLOv8 深度学习模型和计算机视觉技术,实现对视频或摄像头画面中的船舶进行实时检测、跟踪,并计算船舶航向。支持透视变换校准(鸟瞰图显示)、多目标跟踪、轨迹存储及视频录制功能,适用于港口监控、航道管理、船舶行为分析等场景。
二、依赖库
python
运行
import cv2 # 计算机视觉处理(OpenCV库)
import numpy as np # 数值计算
import time # 时间处理
import os # 文件与目录操作
from datetime import datetime # 日期时间处理
from ultralytics import YOLO # YOLOv8深度学习模型
三、类定义:ShipTracker
3.1 构造函数 __init__
功能
初始化船舶跟踪器,配置视频源、输出参数、YOLOv8 模型及跟踪参数。
参数说明
参数名 | 类型 | 默认值 | 描述 |
---|---|---|---|
video_source | int/str | 0 | 视频源(0 为默认摄像头,或指定视频文件路径) |
save_video | bool | False | 是否保存处理后的视频 |
show_warped | bool | True | 是否显示透视变换后的鸟瞰图 |
model_path | str | yolov8n.pt | YOLOv8 模型路径(默认为 COCO 预训练的小模型) |
内部属性
- 视频源与基础参数:
cap
:视频捕获对象(cv2.VideoCapture
实例)frame_width
/frame_height
:视频帧宽高fps
:帧率
- 输出配置:
output_folder
:输出文件夹(默认output
)out
:视频写入对象(cv2.VideoWriter
实例,仅当save_video=True
时创建)
- 深度学习模型:
model
:YOLOv8 模型实例ship_class_id
:船舶类别 ID(COCO 数据集中为8
)
- 检测参数:
confidence_threshold
:置信度阈值(过滤低置信度检测结果)nms_threshold
:非极大值抑制阈值(过滤重叠检测框)
- 跟踪参数:
trajectories
:存储轨迹的字典(键为船舶 ID,值为轨迹信息)max_disappeared_frames
:允许目标消失的最大帧数(超过则删除轨迹)max_distance
:轨迹匹配的最大距离(像素)min_trajectory_points
:计算航向所需的最小轨迹点数
- 透视变换:
perspective_transform
:透视变换矩阵(校准后生成)warped_width
/warped_height
:鸟瞰图尺寸(默认 800×800)
3.2 方法列表
3.2.1 calibrate_perspective()
- 功能:通过鼠标点击选择 4 个点,校准透视变换矩阵,生成鸟瞰图。
- 操作说明:
- 显示视频第一帧,按顺序点击左上、右上、右下、左下四个点,形成矩形区域。
- 按
q
键退出校准。
- 返回值:
bool
(True
为校准成功,False
为取消或失败)
3.2.2 detect_ships(frame)
- 功能:使用 YOLOv8 模型检测图像中的船舶。
- 输入:
frame
(BGR 格式图像) - 处理流程:
- 调用 YOLOv8 模型进行预测,指定类别为船舶(ID=8)。
- 过滤低于置信度阈值的检测结果。
- 应用非极大值抑制(NMS)消除重叠框。
- 返回值:船舶检测结果列表(每个元素为字典,包含
bbox
、center
、confidence
、class
)
3.2.3 calculate_heading(positions)
- 功能:根据船舶轨迹点计算航向角(0-360 度,0 为正北,顺时针增加)。
- 输入:
positions
(轨迹点列表,每个点为(x, y)
坐标) - 算法逻辑:
- 选择最近的
min_trajectory_points
个点。 - 使用最小二乘法拟合直线。
- 计算直线角度并转换为航向角。
- 选择最近的
- 返回值:航向角(浮点数,单位为度)或
None
(轨迹点不足时)
3.2.4 track_ships(detected_ships)
- 功能:根据检测结果更新船舶轨迹。
- 输入:
detected_ships
(detect_ships
返回的船舶列表) - 算法逻辑:
- 计算现有轨迹与新检测的匹配距离(欧氏距离),优先匹配近距离目标。
- 未匹配的轨迹:若连续消失超过
max_disappeared_frames
,则删除。 - 未匹配的检测:创建新轨迹,分配唯一 ID。
3.2.5 draw_results(frame, ships)
- 功能:在图像上绘制检测框、轨迹、航向及统计信息,支持鸟瞰图显示。
- 输入:
frame
:原始帧ships
:检测到的船舶列表
- 输出:绘制后的结果图像(若
show_warped=True
,则为原始帧与鸟瞰图的横向拼接图)
3.2.6 save_trajectories()
- 功能:将当前所有轨迹数据保存到文本文件,包含 ID、起始时间、轨迹点坐标及平均航向。
- 存储路径:
output_folder/ship_trajectories_时间戳.txt
3.2.7 run()
- 功能:运行跟踪主循环,处理视频流并实时显示结果。
- 操作说明:
- 按
q
键退出程序。 - 按
s
键保存当前轨迹数据。
- 按
- 流程:
- 调用
calibrate_perspective()
进行透视校准(可选)。 - 逐帧读取视频,检测、跟踪船舶,绘制结果。
- 释放资源并关闭窗口。
- 调用
四、主程序入口
python
运行
if __name__ == "__main__":tracker = ShipTracker(video_source=0, # 0为摄像头,或指定视频文件路径(如"ship_video.mp4")save_video=True, # 启用视频录制show_warped=True, # 显示鸟瞰图model_path="yolov8n.pt" # YOLOv8模型路径)tracker.run()
五、使用说明
5.1 环境配置
- 安装依赖库:
bash
pip install opencv-python numpy ultralytics
- 下载 YOLOv8 模型(如
yolov8n.pt
),并指定正确路径。
5.2 透视校准操作
- 运行程序后,会弹出窗口提示选择 4 个点。
- 按顺序点击视频中的矩形区域四角(如水面区域),生成鸟瞰图。
- 校准完成后,右侧会显示鸟瞰图中的船舶轨迹。
5.3 输出文件
- 视频文件:若
save_video=True
,生成output/ship_tracking_时间戳.avi
。 - 轨迹文件:按
s
键生成output/ship_trajectories_时间戳.txt
,包含各 ID 的坐标序列和航向信息。
六、参数调整建议
参数名 | 作用 | 调整场景 |
---|---|---|
confidence_threshold | 过滤低置信度的船舶检测结果 | 目标较小或环境复杂时调高 |
nms_threshold | 控制非极大值抑制的严格程度 | 船舶密集时调低 |
max_disappeared_frames | 目标消失后保留轨迹的帧数 | 船舶被遮挡时间较长时调大 |
max_distance | 轨迹匹配的最大允许距离 | 船舶运动速度快时调大 |
min_trajectory_points | 计算航向所需的最小轨迹点数 | 航向计算不稳定时调大 |
七、注意事项
- YOLOv8 模型需要一定计算资源,建议在 GPU 环境下运行以提高帧率。
- 透视校准的四点应选择实际场景中的矩形区域(如水面边界),以确保鸟瞰图坐标准确。
- 船舶航向计算基于轨迹拟合,需要足够的轨迹点才能保证准确性。
- 若视频帧率较低,可尝试降低
warped_width
或关闭show_warped
以减少计算量。
完整代码
import cv2
import numpy as np
import time
import os
from datetime import datetime
from ultralytics import YOLOclass ShipTracker:def __init__(self, video_source=0, save_video=False, show_warped=True, model_path="D:/06_Python/20250321_Deep_Learning/yolov8n.pt"):"""初始化船舶跟踪器"""# 视频源设置self.video_source = video_sourceself.cap = cv2.VideoCapture(video_source)if not self.cap.isOpened():raise ValueError("无法打开视频源", video_source)# 获取视频的宽度、高度和帧率self.frame_width = int(self.cap.get(3))self.frame_height = int(self.cap.get(4))self.fps = self.cap.get(cv2.CAP_PROP_FPS)# 输出设置self.save_video = save_videoself.output_folder = "output"self.show_warped = show_warped# 创建输出文件夹if not os.path.exists(self.output_folder):os.makedirs(self.output_folder)# 加载YOLOv8模型self.model = YOLO(model_path)self.ship_class_id = 8 # COCO数据集中船的类别ID# 船舶检测参数self.confidence_threshold = 0.5self.nms_threshold = 0.4# 轨迹存储self.trajectories = {} # 存储每艘船的轨迹self.next_ship_id = 1 # 下一个可用的船舶IDself.max_disappeared_frames = 15 # 最大消失帧数self.max_distance = 150 # 最大匹配距离self.min_trajectory_points = 5 # 计算航向所需的最小轨迹点# 透视变换参数self.perspective_transform = Noneself.warped_width = 800self.warped_height = 800# 录制设置self.out = Noneif save_video:timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")output_path = os.path.join(self.output_folder, f"ship_tracking_{timestamp}.avi")fourcc = cv2.VideoWriter_fourcc(*'XVID')self.out = cv2.VideoWriter(output_path, fourcc, self.fps, (self.frame_width, self.frame_height))def calibrate_perspective(self):"""校准透视变换,创建鸟瞰图"""print("请在图像中选择4个点,形成一个矩形区域,用于透视变换")print("按顺序点击:左上、右上、右下、左下")# 读取一帧用于选择点ret, frame = self.cap.read()if not ret:print("无法读取视频帧")return False# 创建窗口并设置鼠标回调cv2.namedWindow("选择透视变换点 (按 'q' 退出)")points = []def click_event(event, x, y, flags, param):if event == cv2.EVENT_LBUTTONDOWN:points.append((x, y))cv2.circle(frame, (x, y), 5, (0, 255, 0), -1)cv2.imshow("选择透视变换点 (按 'q' 退出)", frame)cv2.setMouseCallback("选择透视变换点 (按 'q' 退出)", click_event)# 显示图像并等待点击cv2.imshow("选择透视变换点 (按 'q' 退出)", frame)while len(points) < 4:key = cv2.waitKey(1) & 0xFFif key == ord('q'):cv2.destroyAllWindows()return Falsecv2.destroyAllWindows()# 定义目标矩形src = np.float32(points)dst = np.float32([[0, 0],[self.warped_width, 0],[self.warped_width, self.warped_height],[0, self.warped_height]])# 计算透视变换矩阵self.perspective_transform = cv2.getPerspectiveTransform(src, dst)return Truedef detect_ships(self, frame):"""使用YOLOv8检测图像中的船舶"""# 运行模型预测results = self.model(frame, classes=self.ship_class_id, conf=self.confidence_threshold, iou=self.nms_threshold)# 处理检测结果ships = []for result in results:boxes = result.boxes.cpu().numpy()for box in boxes:x1, y1, x2, y2 = box.xyxy[0].astype(int)conf = box.conf[0]cls = int(box.cls[0])# 计算边界框中心点和宽高w, h = x2 - x1, y2 - y1center = (int(x1 + w/2), int(y1 + h/2))ships.append({'bbox': (x1, y1, w, h),'center': center,'confidence': conf,'class': cls})return shipsdef calculate_heading(self, positions):"""根据轨迹点计算船舶航向"""if len(positions) < self.min_trajectory_points:return None# 选择最近的几个点recent_points = positions[-self.min_trajectory_points:]# 拟合直线x = np.array([p[0] for p in recent_points])y = np.array([p[1] for p in recent_points])# 计算直线拟合A = np.vstack([x, np.ones(len(x))]).Tm, c = np.linalg.lstsq(A, y, rcond=None)[0]# 计算角度(弧度)angle = np.arctan2(1, m) # y轴向下为正# 转换为角度(0-360度,0度为正北,顺时针增加)heading = (np.degrees(angle) + 90) % 360return headingdef track_ships(self, detected_ships):"""跟踪检测到的船舶"""# 计算当前检测点与现有轨迹的距离unmatched_tracks = list(self.trajectories.keys())unmatched_detections = list(range(len(detected_ships)))matches = []# 计算所有可能的匹配for track_id in self.trajectories:trajectory = self.trajectories[track_id]last_position = trajectory['positions'][-1]min_distance = float('inf')min_index = -1for i, ship in enumerate(detected_ships):if i in unmatched_detections:distance = np.sqrt((last_position[0] - ship['center'][0])**2 + (last_position[1] - ship['center'][1])**2)if distance < min_distance and distance < self.max_distance:min_distance = distancemin_index = i# 如果找到匹配if min_index != -1:matches.append((track_id, min_index, min_distance))# 按距离排序,优先处理距离近的匹配matches.sort(key=lambda x: x[2])# 应用匹配for match in matches:track_id, detection_index, _ = matchif track_id in unmatched_tracks and detection_index in unmatched_detections:# 更新轨迹self.trajectories[track_id]['positions'].append(detected_ships[detection_index]['center'])self.trajectories[track_id]['last_seen'] = 0self.trajectories[track_id]['bbox'] = detected_ships[detection_index]['bbox']self.trajectories[track_id]['confidence'] = detected_ships[detection_index]['confidence']# 从待匹配列表中移除unmatched_tracks.remove(track_id)unmatched_detections.remove(detection_index)# 处理未匹配的轨迹for track_id in unmatched_tracks:self.trajectories[track_id]['last_seen'] += 1if self.trajectories[track_id]['last_seen'] > self.max_disappeared_frames:del self.trajectories[track_id]# 处理未匹配的检测结果for detection_index in unmatched_detections:# 创建新轨迹self.trajectories[self.next_ship_id] = {'positions': [detected_ships[detection_index]['center']],'last_seen': 0,'bbox': detected_ships[detection_index]['bbox'],'confidence': detected_ships[detection_index]['confidence'],'start_time': time.time()}self.next_ship_id += 1def draw_results(self, frame, ships):"""在图像上绘制检测和跟踪结果"""output = frame.copy()# 绘制检测到的船舶for ship in ships:x, y, w, h = ship['bbox']cv2.rectangle(output, (x, y), (x + w, y + h), (0, 255, 0), 2)cv2.circle(output, ship['center'], 5, (0, 0, 255), -1)cv2.putText(output, f"Conf: {ship['confidence']:.2f}", (x, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)# 绘制轨迹和航向for track_id, trajectory in self.trajectories.items():positions = trajectory['positions']# 绘制轨迹线for i in range(1, len(positions)):cv2.line(output, positions[i-1], positions[i], (255, 0, 0), 2)# 绘制轨迹点for pos in positions:cv2.circle(output, pos, 3, (255, 0, 0), -1)# 计算并绘制航向heading = self.calculate_heading(positions)if heading is not None:center = positions[-1]# 计算航向线终点heading_rad = np.radians(heading - 90) # 转换为OpenCV坐标系length = 50end_point = (int(center[0] + length * np.cos(heading_rad)),int(center[1] + length * np.sin(heading_rad)))# 绘制航向线cv2.arrowedLine(output, center, end_point, (0, 255, 255), 3, tipLength=0.3)# 显示航向角度cv2.putText(output, f"Heading: {heading:.1f}°", (center[0] + 10, center[1] - 40),cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 2)# 绘制ID和轨迹长度if len(positions) > 0:last_pos = positions[-1]cv2.putText(output, f"ID: {track_id}", (last_pos[0] + 10, last_pos[1] - 20),cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 0, 0), 2)cv2.putText(output, f"Points: {len(positions)}", (last_pos[0] + 10, last_pos[1]),cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 0, 0), 2)# 显示统计信息cv2.putText(output, f"Ships: {len(self.trajectories)}", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2)cv2.putText(output, f"FPS: {int(self.fps)}", (10, 60), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2)# 创建结果显示窗口if self.show_warped and self.perspective_transform is not None:# 创建鸟瞰图warped = cv2.warpPerspective(output, self.perspective_transform, (self.warped_width, self.warped_height))# 合并显示# 调整图像大小使高度一致if output.shape[0] != warped.shape[0]:scale = output.shape[0] / warped.shape[0]new_width = int(warped.shape[1] * scale)warped = cv2.resize(warped, (new_width, output.shape[0]))combined = np.hstack((output, warped))return combinedreturn outputdef save_trajectories(self):"""保存轨迹数据到文件"""timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")output_path = os.path.join(self.output_folder, f"ship_trajectories_{timestamp}.txt")with open(output_path, 'w') as f:f.write("Ship Trajectories\n")f.write(f"Recorded on: {datetime.now()}\n\n")for track_id, trajectory in self.trajectories.items():f.write(f"Ship ID: {track_id}\n")f.write(f"Start Time: {time.ctime(trajectory['start_time'])}\n")f.write(f"Duration: {time.time() - trajectory['start_time']:.2f} seconds\n")f.write(f"Trajectory Points: {len(trajectory['positions'])}\n")# 计算平均航向heading = self.calculate_heading(trajectory['positions'])if heading is not None:f.write(f"Average Heading: {heading:.1f}°\n")f.write("Positions:\n")for pos in trajectory['positions']:f.write(f" ({pos[0]}, {pos[1]})\n")f.write("\n")print(f"轨迹数据已保存到: {output_path}")def run(self):"""运行船舶跟踪系统"""# 首先进行透视校准if not self.calibrate_perspective():print("透视校准失败,使用原始视角")print("开始船舶跟踪...")print("按 'q' 退出,按 's' 保存轨迹数据")frame_count = 0start_time = time.time()while True:ret, frame = self.cap.read()if not ret:break# 计算实际帧率frame_count += 1if frame_count % 10 == 0:elapsed_time = time.time() - start_timeself.fps = frame_count / elapsed_time# 检测船舶ships = self.detect_ships(frame)# 跟踪船舶self.track_ships(ships)# 绘制结果result = self.draw_results(frame, ships)# 保存视频if self.save_video:self.out.write(result)# 显示结果cv2.imshow("船舶轨迹跟踪系统 (按 'q' 退出,按 's' 保存轨迹)", result)# 按键处理key = cv2.waitKey(1) & 0xFFif key == ord('q'):breakelif key == ord('s'):self.save_trajectories()# 释放资源self.cap.release()if self.out:self.out.release()cv2.destroyAllWindows()print("船舶跟踪系统已关闭")# 主程序入口
if __name__ == "__main__":# 创建船舶跟踪器实例tracker = ShipTracker(video_source=0, # 0表示默认摄像头,也可以指定视频文件路径save_video=True, # 是否保存视频show_warped=True, # 是否显示鸟瞰图model_path="D:/06_Python/20250321_Deep_Learning/yolov8n.pt" # YOLOv8模型路径)# 运行跟踪器tracker.run()