a sort.py demo
这份代码展示了如何使用 sort.py。注意,此处,我将文件名改为 my_sort.py。
你并不能直接 copy 使用,因为环境,包,还有模型。
此处使用 SSD-MobileNetv2 进行物体检测,将结果传入以 np 数组的形式传入sort 模块,经过处理,以 np.empty((0, 5))的格式传出,在绘画模块,提取信息,标识矩形框和物体ID。
#!/home/ncut/miniconda3/envs/tf/bin/python
# -*- coding: utf-8 -*-
import rospy
import tensorflow as tf
import cv2
import numpy as np
import time
from sensor_msgs.msg import Image
from sensor_msgs.msg import CompressedImage
from cv_bridge import CvBridge, CvBridgeError
from my_sort import Sort # 确保 my_sort 模块在 Python 路径下
# --------------------- 模型推理模块 ---------------------
def load_model(model_dir):
"""
加载 TensorFlow SavedModel(例如 ssd-mobilenet-v2 或 efficientdet)
返回推理函数。
"""
model = tf.saved_model.load(model_dir)
infer = model.signatures["serving_default"]
return infer
def preprocess_frame(frame):
"""
预处理输入图像:
- 将 BGR 转换为 RGB(模型输入要求)
- 将图像 resize 为 320x320,并扩展 batch 维度
返回:
input_tensor: 模型输入 tensor
width: 原图宽度
height: 原图高度
"""
img_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
resized_frame = cv2.resize(frame, (320, 320), interpolation=cv2.INTER_AREA)
input_tensor = tf.convert_to_tensor(resized_frame, dtype=tf.uint8)
input_tensor = tf.expand_dims(input_tensor, 0) # 增加 batch 维度
height, width = frame.shape[:2]
return input_tensor, width, height
def run_inference(infer, input_tensor):
"""
利用推理函数执行模型预测,返回检测框和置信度。
"""
detections = infer(input_tensor)
num_detections = int(detections['num_detections'].numpy()[0])
boxes = detections['detection_boxes'].numpy()[0][:num_detections]
scores = detections['detection_scores'].numpy()[0][:num_detections]
return boxes, scores
def convert_detections_to_sort(boxes, scores, width, height, threshold=0.5):
"""
将检测结果(归一化坐标)转换为 SORT 跟踪器所需格式:[x1, y1, x2, y2, score]
"""
sort_inputs = []
for i in range(len(scores)):
if scores[i] < threshold:
continue
ymin, xmin, ymax, xmax = boxes[i]
x1 = int(xmin * width)
x2 = int(xmax * width)
y1 = int(ymin * height)
y2 = int(ymax * height)
sort_inputs.append([x1, y1, x2, y2, scores[i]])
return np.array(sort_inputs)
# --------------------- 跟踪与可视化模块 ---------------------
def draw_tracks(frame, tracks):
"""
在图像上绘制跟踪结果(边框和跟踪ID)。
"""
for track in tracks:
x1, y1, x2, y2, track_id = track.astype(int)
cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
cv2.putText(frame, f"ID: {track_id}", (x1, y1 - 10),
cv2.FONT_HERSHEY_SIMPLEX, 0.75, (0, 255, 0), 2)
return frame
# --------------------- ROS 图像订阅与处理模块 ---------------------
# 全局变量:模型推理函数、SORT 跟踪器、CvBridge 实例
infer = None
tracker = None
bridge = CvBridge()
def image_callback(msg):
"""
ROS 图像回调函数:
- 将 ROS 图像消息转换为 OpenCV 格式
- 进行模型推理和 SORT 跟踪
- 显示带跟踪结果的图像
"""
global infer, tracker, bridge
try:
cv_image = bridge.imgmsg_to_cv2(msg, "bgr8") # raw
#cv_image = bridge.compressed_imgmsg_to_cv2(msg,"bgr8") #compressed compressed_imgmsg_to_cv2
except CvBridgeError as e:
rospy.logerr("CvBridge 转换错误: %s", e)
return
# 图像预处理和模型推理
input_tensor, width, height = preprocess_frame(cv_image)
boxes, scores = run_inference(infer, input_tensor)
detections = convert_detections_to_sort(boxes, scores, width, height, threshold=0.5)
# 更新 SORT 跟踪器并绘制跟踪结果
tracks = tracker.update(detections)
tracked_frame = draw_tracks(cv_image.copy(), tracks)
# 显示带跟踪结果的图像
cv2.imshow("Tracking", tracked_frame)
cv2.waitKey(1)
def main():
global infer, tracker
# 初始化 ROS 节点
rospy.init_node("tracking_inference_node", anonymous=True)
# 加载模型
model_dir = "/home/ncut/models/ssd-mobilenet-v2" # 根据需要更新模型路径
time_before_load = time.time()
infer = load_model(model_dir)
time_after_load = time.time()
rospy.loginfo("模型加载耗时:%.2f 秒", time_after_load - time_before_load)
# 初始化 SORT 跟踪器
tracker = Sort(max_age=1, min_hits=3, iou_threshold=0.3)
# 订阅图像话题 below is a description of the car launch file astra /camera/rgb/image_raw/compressed
# to original photo, topic is /camera/rgb/image_raw average rate is 2Hz 4.02MB/s
# to compressed photo, topic is /camera/rgb/image_raw/compressed average rate is 30Hz 1.25MB/s
rospy.Subscriber("/camera/rgb/image_raw", Image, image_callback) # Image for raw, CompressedImage for compressed
rospy.loginfo("Tracking Inference Node 已启动,订阅话题:/camera/rgb/image_raw")
# ROS 循环等待消息
rospy.spin()
cv2.destroyAllWindows()
if __name__ == '__main__':
try:
main()
except rospy.ROSInterruptException:
pass