适配 Python 3.9 的 SORT 算法
仅对原版 sort.py 的函数接口做了简单修改,核心思想与处理流程保持不变。
原始源代码链接:https://github.com/abewley/sort
import os
import numpy as np
import glob
import time
import argparse
from filterpy.kalman import KalmanFilter
from scipy.optimize import linear_sum_assignment # 使用 SciPy 实现分配
# Seed NumPy's global RNG so that the random per-track display colours drawn
# later with np.random.rand (in the __main__ block) are the same on every run.
np.random.seed(0)
def simple_test():
    """
    Smoke-test Sort on two hand-made frames.

    Frame 1 holds two detections, frame 2 a single one. For each frame the
    returned tracks and the wall-clock time of ``Sort.update`` are printed.
    """
    tracker = Sort(max_age=1, min_hits=1, iou_threshold=0.3)
    frames = (
        np.array([
            [100, 100, 150, 180, 0.9],
            [200, 150, 240, 210, 0.85],
        ]),
        np.array([
            [110, 120, 170, 210, 0.88],
        ]),
    )
    labels = ("Frame1 trackers:", "Frame2 trackers:")
    for label, detections in zip(labels, frames):
        tic = time.time()
        tracks = tracker.update(detections)
        elapsed = time.time() - tic
        print(label, tracks, "time:", elapsed)
def linear_assignment(cost_matrix):
    """
    Solve the minimum-cost linear assignment problem.

    Thin wrapper around ``scipy.optimize.linear_sum_assignment``.

    Args:
        cost_matrix: 2-D array of costs. In this file the rows are detections
            and the columns are trackers.

    Returns:
        Integer array of shape (k, 2). Each row is a matched (row, column)
        index pair. When nothing can be matched, the shape is still (0, 2).
    """
    row_ind, col_ind = linear_sum_assignment(cost_matrix)
    # np.stack always yields a (k, 2) array, even for k == 0. Building the
    # array from an empty zip() would give shape (0,), which breaks the
    # ``[:, 0]`` / ``[:, 1]`` indexing callers rely on.
    return np.stack((row_ind, col_ind), axis=1)
def iou_batch(bb_test, bb_gt):
    """
    Compute pairwise IoU between two sets of boxes in [x1, y1, x2, y2] format.

    Only the first four columns are used, so rows may carry extra columns
    (e.g. a score).

    Args:
        bb_test: (N, >=4) array-like of boxes.
        bb_gt: (M, >=4) array-like of boxes.

    Returns:
        (N, M) float array where entry [i, j] is IoU(bb_test[i], bb_gt[j]).
        Pairs whose union area is not positive (both boxes degenerate) get
        0 instead of NaN.
    """
    # Broadcast to (N, 1, k) against (1, M, k) to get all N x M pairs at once.
    bb_gt = np.expand_dims(np.asarray(bb_gt, dtype=float), 0)
    bb_test = np.expand_dims(np.asarray(bb_test, dtype=float), 1)

    xx1 = np.maximum(bb_test[..., 0], bb_gt[..., 0])
    yy1 = np.maximum(bb_test[..., 1], bb_gt[..., 1])
    xx2 = np.minimum(bb_test[..., 2], bb_gt[..., 2])
    yy2 = np.minimum(bb_test[..., 3], bb_gt[..., 3])
    w = np.maximum(0., xx2 - xx1)
    h = np.maximum(0., yy2 - yy1)
    inter = w * h

    area_test = (bb_test[..., 2] - bb_test[..., 0]) * (bb_test[..., 3] - bb_test[..., 1])
    area_gt = (bb_gt[..., 2] - bb_gt[..., 0]) * (bb_gt[..., 3] - bb_gt[..., 1])
    union = area_test + area_gt - inter

    # Guard 0/0: a NaN here would later make linear_sum_assignment reject
    # the cost matrix.
    return np.divide(inter, union, out=np.zeros_like(inter), where=union > 0)
def convert_bbox_to_z(bbox):
    """
    Convert a corner box [x1, y1, x2, y2, ...] into the measurement [x, y, s, r].

    (x, y) is the box centre, s = width * height is the area and
    r = width / height is the aspect ratio. Returns a (4, 1) column array.
    """
    left, top, right, bottom = bbox[0], bbox[1], bbox[2], bbox[3]
    width = right - left
    height = bottom - top
    centre_x = left + width / 2.
    centre_y = top + height / 2.
    measurement = [centre_x, centre_y, width * height, width / float(height)]
    return np.array(measurement).reshape((4, 1))
def convert_x_to_bbox(x, score=None):
    """
    Convert a centre-form state [x, y, s, r, ...] into a corner box [x1, y1, x2, y2].

    Args:
        x: state whose first four entries are [x, y, s, r] (centre, area,
            aspect ratio). It may be 1-D or a column vector such as the
            (7, 1) Kalman state; it is flattened before use.
        score: optional value appended as a fifth column.

    Returns:
        (1, 4) array [x1, y1, x2, y2], or (1, 5) with ``score`` appended.
    """
    # Flatten first. With a column state, x[0] etc. are shape-(1,) arrays,
    # and mixing them with a scalar score in np.array() is an inhomogeneous
    # array (an error on recent NumPy).
    state = np.asarray(x, dtype=float).reshape(-1)
    cx, cy, area, ratio = state[0], state[1], state[2], state[3]
    w = np.sqrt(area * ratio)
    h = area / w
    box = [cx - w / 2., cy - h / 2., cx + w / 2., cy + h / 2.]
    if score is None:
        return np.array(box).reshape((1, 4))
    return np.append(box, score).reshape((1, 5))
class KalmanBoxTracker:
    """
    Internal state of a single tracked object, observed through bounding boxes.

    A 7-dimensional constant-velocity Kalman filter is used. Its state is
    [x, y, s, r, vx, vy, vs]: box centre (x, y), area s and aspect ratio r
    (see ``convert_bbox_to_z``), plus velocities for x, y and s. The transition
    matrix gives r no velocity term, so r is modelled as constant.
    """

    count = 0  # class-level counter; each new tracker takes the next id from it

    def __init__(self, bbox):
        """Start a new track from the box ``bbox`` = [x1, y1, x2, y2, ...]."""
        self.kf = kf = KalmanFilter(dim_x=7, dim_z=4)

        # Constant-velocity transition: x, y and s each gain their velocity term.
        transition = np.eye(7, dtype=int)
        for axis in range(3):
            transition[axis, axis + 4] = 1
        kf.F = transition

        # Only the first four state entries [x, y, s, r] are measured.
        kf.H = np.eye(4, 7, dtype=int)

        kf.R[2:, 2:] *= 10.
        kf.P[4:, 4:] *= 1000.  # initial velocities are unobserved: high uncertainty
        kf.P *= 10.
        kf.Q[-1, -1] *= 0.01
        kf.Q[4:, 4:] *= 0.01
        kf.x[:4] = convert_bbox_to_z(bbox)

        self.id = KalmanBoxTracker.count
        KalmanBoxTracker.count += 1
        self.time_since_update = 0  # predict() calls since the last update()
        self.history = []           # boxes predicted since the last update()
        self.hits = 0               # total number of update() calls
        self.hit_streak = 0         # length of the current run of updates
        self.age = 0                # total number of predict() calls

    def update(self, bbox):
        """Correct the filter with the associated detection ``bbox``."""
        self.time_since_update = 0
        self.history = []
        self.hits, self.hit_streak = self.hits + 1, self.hit_streak + 1
        self.kf.update(convert_bbox_to_z(bbox))

    def predict(self):
        """Advance the filter one step and return the predicted box as a (1, 4) array."""
        state = self.kf.x
        # If area + area-velocity would be non-positive, stop the area from
        # changing (zero its velocity) before predicting.
        if state[6] + state[2] <= 0:
            state[6] *= 0.0
        self.kf.predict()
        self.age += 1
        # A predict with no update since the previous one breaks the hit streak.
        if self.time_since_update > 0:
            self.hit_streak = 0
        self.time_since_update += 1
        predicted = convert_x_to_bbox(self.kf.x)
        self.history.append(predicted)
        return predicted

    def get_state(self):
        """Return the current box estimate as a (1, 4) array [x1, y1, x2, y2]."""
        return convert_x_to_bbox(self.kf.x)
def associate_detections_to_trackers(detections, trackers, iou_threshold=0.3):
    """
    Match detections to tracked objects by IoU using the Hungarian algorithm.

    Args:
        detections: (N, >=4) array of detection boxes [x1, y1, x2, y2, ...].
        trackers: (M, >=4) array of predicted tracker boxes in the same format.
        iou_threshold: a Hungarian pair is accepted only if its IoU is at
            least this value; otherwise both sides count as unmatched.

    Returns:
        matches: int array of shape (K, 2); rows are (detection_idx, tracker_idx).
        unmatched_detections: 1-D int array of detection indices.
        unmatched_trackers: 1-D int array of tracker indices.
    """
    if len(trackers) == 0:
        # Nothing to match against: every detection is unmatched. Return 1-D
        # empty trackers, consistent with the general path below.
        return (np.empty((0, 2), dtype=int),
                np.arange(len(detections)),
                np.empty((0,), dtype=int))

    iou_matrix = iou_batch(detections, trackers)
    if min(iou_matrix.shape) > 0:
        # linear_assignment minimises cost, so negate IoU to maximise overlap.
        matched_indices = linear_assignment(-iou_matrix)
    else:
        matched_indices = np.empty((0, 2), dtype=int)

    # Sets give O(1) membership tests instead of scanning the array per index.
    assigned_dets = set(matched_indices[:, 0].tolist())
    assigned_trks = set(matched_indices[:, 1].tolist())
    unmatched_detections = [d for d in range(len(detections)) if d not in assigned_dets]
    unmatched_trackers = [t for t in range(len(trackers)) if t not in assigned_trks]

    matches = []
    for det_idx, trk_idx in matched_indices:
        if iou_matrix[det_idx, trk_idx] < iou_threshold:
            # Hungarian pair with too little overlap: reject it.
            unmatched_detections.append(det_idx)
            unmatched_trackers.append(trk_idx)
        else:
            matches.append((det_idx, trk_idx))

    matches = np.array(matches, dtype=int).reshape(-1, 2)
    return (matches,
            np.array(unmatched_detections, dtype=int),
            np.array(unmatched_trackers, dtype=int))
class Sort:
    """
    SORT multi-object tracker.

    Each frame, every track is advanced with its Kalman filter. Detections are
    then associated to the predicted boxes by IoU (Hungarian algorithm).
    Matched tracks are corrected, and unmatched detections start new tracks.
    """

    def __init__(self, max_age=1, min_hits=3, iou_threshold=0.3):
        """
        Initialise the SORT parameters.

        Args:
            max_age: a track is removed once its ``time_since_update`` exceeds
                this value, i.e. after more than ``max_age`` consecutive
                predictions without an associated detection.
            min_hits: a track is reported once its hit streak reaches this
                value. During the first ``min_hits`` frames, every freshly
                updated track is reported.
            iou_threshold: minimum IoU for a detection/track match.
        """
        self.max_age = max_age
        self.min_hits = min_hits
        self.iou_threshold = iou_threshold
        self.trackers = []
        self.frame_count = 0

    def update(self, dets=None):
        """
        Process the detections of one frame.

        Call this once per frame, even when the frame has no detections
        (pass None or an empty (0, 5) array).

        Args:
            dets: detections as rows [x1, y1, x2, y2, score]; any array-like
                is accepted. ``None`` (the default) or an empty sequence means
                "no detections". The default used to be a shared module-level
                numpy array (mutable default argument); None avoids that.

        Returns:
            (K, 5) array; each row is [x1, y1, x2, y2, track_id] for tracks
            reported this frame. track_id is 1-based.
        """
        if dets is None or len(dets) == 0:
            dets = np.empty((0, 5))
        else:
            dets = np.asarray(dets)
        self.frame_count += 1

        # Predict every existing track forward and collect the predicted boxes.
        trks = np.zeros((len(self.trackers), 5))
        to_del = []
        for t, trk in enumerate(self.trackers):
            pos = trk.predict()[0]
            trks[t] = [pos[0], pos[1], pos[2], pos[3], 0]
            if np.any(np.isnan(pos)):
                to_del.append(t)
        # Drop rows (and their trackers) whose prediction became NaN.
        trks = np.ma.compress_rows(np.ma.masked_invalid(trks))
        for t in reversed(to_del):
            self.trackers.pop(t)

        matched, unmatched_dets, _ = associate_detections_to_trackers(
            dets, trks, self.iou_threshold)

        # Correct matched trackers with their detections.
        for det_idx, trk_idx in matched:
            self.trackers[trk_idx].update(dets[det_idx, :])

        # Start a new tracker for each unmatched detection.
        for det_idx in unmatched_dets:
            self.trackers.append(KalmanBoxTracker(dets[det_idx, :]))

        # Report confirmed tracks and drop stale ones. Walk backwards so pops
        # do not shift the indices still to be visited.
        ret = []
        for idx in range(len(self.trackers) - 1, -1, -1):
            trk = self.trackers[idx]
            d = trk.get_state()[0]
            if (trk.time_since_update < 1) and (
                    trk.hit_streak >= self.min_hits or self.frame_count <= self.min_hits):
                # +1 so that reported ids are positive (MOT convention).
                ret.append(np.concatenate((d, [trk.id + 1])).reshape(1, -1))
            if trk.time_since_update > self.max_age:
                self.trackers.pop(idx)

        if ret:
            return np.concatenate(ret)
        return np.empty((0, 5))
def parse_args():
    """
    Parse the command-line options of the SORT demo from ``sys.argv``.

    Returns:
        argparse.Namespace with display, seq_path, phase, max_age, min_hits
        and iou_threshold.
    """
    parser = argparse.ArgumentParser(description='SORT demo')
    parser.add_argument('--display', dest='display', action='store_true',
                        help='Display online tracker output (slow) [False]')
    # (flag, type, default, help) for the remaining valued options.
    valued_options = (
        ("--seq_path", str, 'data', "Path to detections."),
        ("--phase", str, 'train', "Subdirectory in seq_path."),
        ("--max_age", int, 1,
         "Maximum number of frames to keep alive a track without associated detections."),
        ("--min_hits", int, 3,
         "Minimum number of associated detections before track is initialised."),
        ("--iou_threshold", float, 0.3, "Minimum IOU for match."),
    )
    for flag, kind, default, text in valued_options:
        parser.add_argument(flag, help=text, type=kind, default=default)
    return parser.parse_args()
if __name__ == '__main__':
    # Demo driver: run SORT over every <seq_path>/<phase>/<seq>/det/det.txt
    # and write MOT-format results to output/<seq>.txt.
    # simple_test()  # uncomment to run the built-in two-frame smoke test
    args = parse_args()
    display = args.display
    phase = args.phase
    total_time = 0.0
    total_frames = 0
    colours = np.random.rand(32, 3)  # display colours, indexed by track id % 32
    if display:
        import matplotlib.pyplot as plt
        import matplotlib.patches as patches
        plt.ion()
        fig = plt.figure()
        ax1 = fig.add_subplot(111, aspect='equal')
    os.makedirs('output', exist_ok=True)
    pattern = os.path.join(args.seq_path, phase, '*', 'det', 'det.txt')
    for seq_dets_fn in glob.glob(pattern):
        mot_tracker = Sort(max_age=args.max_age,
                           min_hits=args.min_hits,
                           iou_threshold=args.iou_threshold)
        # ndmin=2 keeps a (rows, cols) array even when det.txt has a single line.
        seq_dets = np.loadtxt(seq_dets_fn, delimiter=',', ndmin=2)
        # The sequence name is the directory matched by '*' in
        # <seq_path>/<phase>/<seq>/det/det.txt. (Searching for '*' in the
        # file name itself always failed and named every sequence 't'.)
        seq = os.path.basename(os.path.dirname(os.path.dirname(seq_dets_fn)))
        if seq_dets.size == 0:
            print("Skipping %s: no detections." % (seq))
            continue
        with open(os.path.join('output', f'{seq}.txt'), 'w') as out_file:
            print("Processing %s." % (seq))
            # Frame numbers in det.txt start at 1.
            for frame in range(1, int(seq_dets[:, 0].max()) + 1):
                # Columns 2..6 are [x, y, w, h, score].
                dets = seq_dets[seq_dets[:, 0] == frame, 2:7]
                dets[:, 2:4] += dets[:, 0:2]  # convert [x,y,w,h] to [x1,y1,x2,y2]
                total_frames += 1
                if display:
                    fn = os.path.join('mot_benchmark', phase, seq, 'img1', f'{frame:06d}.jpg')
                    # plt.imread replaces io.imread: `io` was never imported.
                    im = plt.imread(fn)
                    ax1.imshow(im)
                    plt.title(seq + ' Tracked Targets')
                start_time = time.time()
                trackers = mot_tracker.update(dets)
                cycle_time = time.time() - start_time
                total_time += cycle_time
                for d in trackers:
                    # MOT format expects an integer track id (d[4] is a float).
                    print(f'{frame},{int(d[4])},{d[0]:.2f},{d[1]:.2f},{(d[2]-d[0]):.2f},{(d[3]-d[1]):.2f},1,-1,-1,-1', file=out_file)
                    if display:
                        d = d.astype(np.int32)
                        ax1.add_patch(patches.Rectangle((d[0], d[1]), d[2]-d[0], d[3]-d[1], fill=False, lw=3, ec=colours[d[4] % 32, :]))
                if display:
                    fig.canvas.flush_events()
                    plt.draw()
                    ax1.cla()
    if total_time > 0:
        print("Total Tracking took: %.3f seconds for %d frames or %.1f FPS" % (total_time, total_frames, total_frames / total_time))
    else:
        # Nothing was timed (e.g. no det.txt matched the pattern): avoid 0/0.
        print("Total Tracking took: %.3f seconds for %d frames" % (total_time, total_frames))
    if display:
        print("Note: to get real runtime results run without the option: --display")
该模块以目标检测模型输出的检测框及其置信度作为输入(每行格式为 [x1, y1, x2, y2, score])。下一篇文章将展示一个使用该模块进行跟踪与绘制的 demo。