当前位置: 首页 > news >正文

适配 Python 3.9 的 SORT 算法

对原始 sort.py 的函数接口做了简单修改,核心思想和处理流程保持不变。

源代码链接:https://github.com/abewley/sort

import os
import numpy as np
import glob
import time
import argparse
from filterpy.kalman import KalmanFilter
from scipy.optimize import linear_sum_assignment  # 使用 SciPy 实现分配

np.random.seed(0)

def simple_test():
    """Run the tracker on two hand-made frames and print the results with timings.

    Frame 1 contains two detections and frame 2 contains one; every row is
    [x1, y1, x2, y2, score].
    """
    tracker = Sort(max_age=1, min_hits=1, iou_threshold=0.3)

    frames = (
        ("Frame1 trackers:", np.array([
            [100, 100, 150, 180, 0.9],
            [200, 150, 240, 210, 0.85],
        ])),
        ("Frame2 trackers:", np.array([
            [110, 120, 170, 210, 0.88],
        ])),
    )

    for label, detections in frames:
        t0 = time.time()
        tracks = tracker.update(detections)
        elapsed = time.time() - t0
        print(label, tracks, "time:", elapsed)


def linear_assignment(cost_matrix):
    """Solve the rectangular linear assignment problem (minimum total cost).

    Thin wrapper around ``scipy.optimize.linear_sum_assignment``.

    Args:
        cost_matrix: 2-D array of costs with shape (n_rows, n_cols).

    Returns:
        Integer array of shape (k, 2) whose rows are ``[row_index, col_index]``
        pairs, k = min(n_rows, n_cols). The result is always 2-D, even when
        k == 0.
    """
    row_ind, col_ind = linear_sum_assignment(cost_matrix)
    # np.stack keeps the (k, 2) shape for k == 0; building the array from a
    # list of zipped pairs would collapse an empty result to shape (0,).
    return np.stack((row_ind, col_ind), axis=1)

def iou_batch(bb_test, bb_gt):
    """
    Compute pairwise IoU between two sets of boxes in [x1, y1, x2, y2] form.

    Only the first four columns of each row are read, so rows may carry extra
    columns such as a score.

    Args:
        bb_test: array of shape (N, >=4).
        bb_gt: array of shape (M, >=4).

    Returns:
        Float array of shape (N, M); entry [i, j] is the IoU of bb_test[i] and
        bb_gt[j]. Pairs whose union area is not positive (degenerate or
        inverted boxes) get IoU 0 instead of NaN/inf.
    """
    bb_gt = np.expand_dims(bb_gt, 0)      # (1, M, C)
    bb_test = np.expand_dims(bb_test, 1)  # (N, 1, C)

    xx1 = np.maximum(bb_test[..., 0], bb_gt[..., 0])
    yy1 = np.maximum(bb_test[..., 1], bb_gt[..., 1])
    xx2 = np.minimum(bb_test[..., 2], bb_gt[..., 2])
    yy2 = np.minimum(bb_test[..., 3], bb_gt[..., 3])
    inter = np.maximum(0., xx2 - xx1) * np.maximum(0., yy2 - yy1)

    area_test = (bb_test[..., 2] - bb_test[..., 0]) * (bb_test[..., 3] - bb_test[..., 1])
    area_gt = (bb_gt[..., 2] - bb_gt[..., 0]) * (bb_gt[..., 3] - bb_gt[..., 1])
    union = area_test + area_gt - inter

    # Guard the division: a NaN entry would make linear_sum_assignment reject
    # the (negated) IoU matrix downstream.
    return np.divide(inter, union, out=np.zeros(inter.shape), where=union > 0)

def convert_bbox_to_z(bbox):
    """
    Convert a box [x1, y1, x2, y2] into the column vector [x, y, s, r].

    x, y are the box centre, s = w * h is the area and r = w / h is the
    aspect ratio. Entries after the fourth (e.g. a score) are ignored.

    Returns:
        np.ndarray of shape (4, 1).
    """
    x1, y1, x2, y2 = bbox[0], bbox[1], bbox[2], bbox[3]
    width = x2 - x1
    height = y2 - y1
    centre_x = x1 + width / 2.
    centre_y = y1 + height / 2.
    area = width * height
    aspect = width / float(height)
    return np.array([centre_x, centre_y, area, aspect]).reshape((4, 1))

def convert_x_to_bbox(x, score=None):
    """
    Convert a state [x, y, s, r, ...] back into a box [x1, y1, x2, y2].

    Args:
        x: array-like whose first four entries are centre x, centre y, area s
            and aspect ratio r. Column vectors such as the (7, 1) Kalman state
            are accepted; only the first four entries are read.
        score: optional value appended as a fifth column.

    Returns:
        np.ndarray of shape (1, 4), or (1, 5) when ``score`` is given.
    """
    # Flatten first: with a column-vector state every x[i] is a length-1
    # array, and mixing those with a scalar score yields a ragged sequence
    # that NumPy refuses to turn into an array.
    x = np.asarray(x, dtype=float).ravel()
    w = np.sqrt(x[2] * x[3])
    h = x[2] / w
    box = [x[0] - w / 2., x[1] - h / 2., x[0] + w / 2., x[1] + h / 2.]
    if score is None:
        return np.array(box).reshape((1, 4))
    return np.array(box + [score]).reshape((1, 5))

class KalmanBoxTracker:
    """
    Internal state of a single tracked object, estimated from bounding-box observations.

    Wraps a filterpy ``KalmanFilter`` with a 7-dimensional state
    [x, y, s, r, vx, vy, vs] -- box centre, area, aspect ratio and the
    per-step change of x, y and s (the aspect ratio has no velocity term,
    see ``F``) -- and a 4-dimensional measurement [x, y, s, r].
    """
    # Class-wide counter that hands out track IDs. It is shared by all
    # instances and nothing in this file resets it, so IDs keep increasing
    # across Sort instances created in the same process.
    count = 0
    def __init__(self, bbox):
        """Start a track from ``bbox`` = [x1, y1, x2, y2, ...] (entries after the fourth are ignored)."""
        # Constant-velocity model: x, y and s each advance by their velocity every step.
        self.kf = KalmanFilter(dim_x=7, dim_z=4)
        self.kf.F = np.array([[1, 0, 0, 0, 1, 0, 0],
                              [0, 1, 0, 0, 0, 1, 0],
                              [0, 0, 1, 0, 0, 0, 1],
                              [0, 0, 0, 1, 0, 0, 0],
                              [0, 0, 0, 0, 1, 0, 0],
                              [0, 0, 0, 0, 0, 1, 0],
                              [0, 0, 0, 0, 0, 0, 1]])
        # Measurement matrix: only the first four state entries [x, y, s, r] are observed.
        self.kf.H = np.array([[1, 0, 0, 0, 0, 0, 0],
                              [0, 1, 0, 0, 0, 0, 0],
                              [0, 0, 1, 0, 0, 0, 0],
                              [0, 0, 0, 1, 0, 0, 0]])
        # Inflate measurement noise for area and aspect ratio relative to the centre.
        self.kf.R[2:, 2:] *= 10.
        self.kf.P[4:, 4:] *= 1000.  # high uncertainty for the initial velocities, which are never observed
        self.kf.P *= 10.
        # Reduce process noise on the velocities; the area velocity (last
        # entry) is scaled by both lines, i.e. by 0.01 * 0.01 overall.
        self.kf.Q[-1, -1] *= 0.01
        self.kf.Q[4:, 4:] *= 0.01

        # Seed [x, y, s, r] from the first box; the velocity entries are not set here.
        self.kf.x[:4] = convert_bbox_to_z(bbox)
        self.time_since_update = 0  # predict() calls since the last update()
        self.id = KalmanBoxTracker.count
        KalmanBoxTracker.count += 1
        self.history = []  # boxes returned by predict() since the last update()
        self.hits = 0  # total number of update() calls
        self.hit_streak = 0  # consecutive update()s with no missed frame in between
        self.age = 0  # total number of predict() calls

    def update(self, bbox):
        """Correct the state with an observed ``bbox`` = [x1, y1, x2, y2, ...].

        Resets ``time_since_update`` and ``history`` and increments the hit counters.
        """
        self.time_since_update = 0
        self.history = []
        self.hits += 1
        self.hit_streak += 1
        self.kf.update(convert_bbox_to_z(bbox))

    def predict(self):
        """Advance the state one step and return the predicted box.

        Returns:
            np.ndarray of shape (1, 4), [x1, y1, x2, y2]; it is also appended
            to ``history``.
        """
        # If the next step would drive the area (s + vs) to zero or below,
        # zero the area velocity so the predicted area is carried over unchanged.
        if (self.kf.x[6] + self.kf.x[2]) <= 0:
            self.kf.x[6] *= 0.0
        self.kf.predict()
        self.age += 1
        # A predict() without an update() since the previous predict() is a
        # missed frame, which breaks the hit streak.
        if self.time_since_update > 0:
            self.hit_streak = 0
        self.time_since_update += 1
        self.history.append(convert_x_to_bbox(self.kf.x))
        return self.history[-1]

    def get_state(self):
        """Return the current box estimate as a (1, 4) array [x1, y1, x2, y2]."""
        return convert_x_to_bbox(self.kf.x)

def associate_detections_to_trackers(detections, trackers, iou_threshold=0.3):
    """
    Match detections to existing trackers by maximising total IoU.

    Args:
        detections: array of shape (N, >=4), rows [x1, y1, x2, y2, ...].
        trackers: array of shape (M, >=4), predicted tracker boxes.
        iou_threshold: minimum IoU for an assigned pair to count as a match.

    Returns:
        (matches, unmatched_detections, unmatched_trackers):
        ``matches`` is an int array of shape (K, 2) whose rows are
        [detection_index, tracker_index]; the other two are 1-D int arrays of
        indices. Pairs chosen by the assignment but with IoU below
        ``iou_threshold`` are reported as unmatched on both sides.
    """
    if len(trackers) == 0:
        return (np.empty((0, 2), dtype=int),
                np.arange(len(detections)),
                np.empty((0,), dtype=int))

    iou_matrix = iou_batch(detections, trackers)
    if min(iou_matrix.shape) > 0:
        # linear_assignment minimises cost, so negate IoU to maximise it.
        matched_indices = linear_assignment(-iou_matrix)
    else:
        matched_indices = np.empty((0, 2), dtype=int)

    # Sets give O(1) membership tests instead of scanning an index column each time.
    assigned_dets = set(matched_indices[:, 0].tolist())
    assigned_trks = set(matched_indices[:, 1].tolist())
    unmatched_detections = [d for d in range(len(detections)) if d not in assigned_dets]
    unmatched_trackers = [t for t in range(len(trackers)) if t not in assigned_trks]

    matches = []
    for d, t in matched_indices:
        if iou_matrix[d, t] < iou_threshold:
            unmatched_detections.append(d)
            unmatched_trackers.append(t)
        else:
            matches.append((d, t))

    # reshape keeps the (0, 2) shape when nothing matched.
    matches = np.array(matches, dtype=int).reshape(-1, 2)
    return (matches,
            np.array(unmatched_detections, dtype=int),
            np.array(unmatched_trackers, dtype=int))

class Sort:
    """Simple Online and Realtime Tracking (SORT) multi-object tracker."""

    def __init__(self, max_age=1, min_hits=3, iou_threshold=0.3):
        """
        Configure the tracker.

        Args:
            max_age: a track is dropped once its frames-without-a-match
                counter exceeds this value.
            min_hits: consecutive matched frames required before a track is
                reported (waived during the first ``min_hits`` frames).
            iou_threshold: minimum IoU for a detection/track match.
        """
        self.max_age = max_age
        self.min_hits = min_hits
        self.iou_threshold = iou_threshold
        self.trackers = []
        self.frame_count = 0

    def update(self, dets=None):
        """
        Advance the tracker by one frame.

        Call this once per frame, including frames without detections.

        Args:
            dets: numpy array of detections with rows [x1, y1, x2, y2, score].
                ``None`` or any empty array means "no detections this frame".

        Returns:
            Array of shape (K, 5) whose rows are [x1, y1, x2, y2, track_id],
            with track_id >= 1. K need not equal the number of detections.
        """
        # Normalise "no detections" to a 2-D (0, 5) array so the row indexing
        # below also works for inputs like np.array([]) of shape (0,).
        dets = np.empty((0, 5)) if dets is None else np.asarray(dets)
        if dets.size == 0:
            dets = np.empty((0, 5))
        self.frame_count += 1

        # Predict every existing track one frame forward.
        trks = np.zeros((len(self.trackers), 5))
        to_del = []
        for t, trk in enumerate(self.trackers):
            pos = trk.predict()[0]
            trks[t] = [pos[0], pos[1], pos[2], pos[3], 0]
            if np.any(np.isnan(pos)):
                to_del.append(t)
        # Drop tracks whose prediction became NaN from both the box array and
        # self.trackers so that row indices stay aligned with the list.
        trks = np.ma.compress_rows(np.ma.masked_invalid(trks))
        for t in reversed(to_del):
            self.trackers.pop(t)
        matched, unmatched_dets, unmatched_trks = associate_detections_to_trackers(
            dets, trks, self.iou_threshold)

        # Correct each matched track with its detection.
        for det_idx, trk_idx in matched:
            self.trackers[trk_idx].update(dets[det_idx, :])

        # Start a new track for every unmatched detection.
        for i in unmatched_dets:
            self.trackers.append(KalmanBoxTracker(dets[i, :]))

        # Report confirmed tracks (newest first, as before), then prune stale
        # ones without mutating the list while iterating over it.
        ret = []
        for trk in reversed(self.trackers):
            if trk.time_since_update < 1 and (
                    trk.hit_streak >= self.min_hits or self.frame_count <= self.min_hits):
                d = trk.get_state()[0]
                ret.append(np.concatenate((d, [trk.id + 1])).reshape(1, -1))  # +1 keeps IDs positive
        self.trackers[:] = [trk for trk in self.trackers
                            if trk.time_since_update <= self.max_age]

        if ret:
            return np.concatenate(ret)
        return np.empty((0, 5))

def parse_args():
    """Parse the SORT demo's command-line options from ``sys.argv``.

    Returns:
        argparse.Namespace with ``display``, ``seq_path``, ``phase``,
        ``max_age``, ``min_hits`` and ``iou_threshold``.
    """
    parser = argparse.ArgumentParser(description='SORT demo')
    parser.add_argument('--display', dest='display', action='store_true',
                        help='Display online tracker output (slow) [False]')

    # (flag, type, default, help) for every valued option, in declaration order.
    valued_options = [
        ("--seq_path", str, 'data', "Path to detections."),
        ("--phase", str, 'train', "Subdirectory in seq_path."),
        ("--max_age", int, 1,
         "Maximum number of frames to keep alive a track without associated detections."),
        ("--min_hits", int, 3,
         "Minimum number of associated detections before track is initialised."),
        ("--iou_threshold", float, 0.3, "Minimum IOU for match."),
    ]
    for flag, kind, default, text in valued_options:
        parser.add_argument(flag, help=text, type=kind, default=default)

    return parser.parse_args()

if __name__ == '__main__':
    # simple_test()  # uncomment to run the built-in two-frame smoke test instead
    args = parse_args()
    display = args.display
    phase = args.phase
    total_time = 0.0
    total_frames = 0
    colours = np.random.rand(32, 3)  # colours used when drawing tracks

    if display:
        import matplotlib.pyplot as plt
        import matplotlib.patches as patches
        plt.ion()
        fig = plt.figure()
        ax1 = fig.add_subplot(111, aspect='equal')

    os.makedirs('output', exist_ok=True)
    # Expected layout: <seq_path>/<phase>/<sequence>/det/det.txt
    pattern = os.path.join(args.seq_path, phase, '*', 'det', 'det.txt')
    for seq_dets_fn in glob.glob(pattern):
        mot_tracker = Sort(max_age=args.max_age,
                           min_hits=args.min_hits,
                           iou_threshold=args.iou_threshold)
        # ndmin=2 keeps a single-line detection file 2-D for the column indexing below.
        seq_dets = np.loadtxt(seq_dets_fn, delimiter=',', ndmin=2)
        # The sequence name is the directory two levels above det.txt; the
        # expanded glob result contains no '*' to locate it by.
        seq = os.path.basename(os.path.dirname(os.path.dirname(seq_dets_fn)))

        with open(os.path.join('output', f'{seq}.txt'), 'w') as out_file:
            print("Processing %s." % (seq))
            for frame in range(1, int(seq_dets[:, 0].max()) + 1):  # frame numbers start at 1
                dets = seq_dets[seq_dets[:, 0] == frame, 2:7]
                dets[:, 2:4] += dets[:, 0:2]  # convert [x, y, w, h] to [x1, y1, x2, y2]
                total_frames += 1

                if display:
                    fn = os.path.join('mot_benchmark', phase, seq, 'img1', f'{frame:06d}.jpg')
                    # skimage.io is not imported by this file; matplotlib, already
                    # loaded for display, can read the frame image itself.
                    im = plt.imread(fn)
                    ax1.imshow(im)
                    plt.title(seq + ' Tracked Targets')

                start_time = time.time()
                trackers = mot_tracker.update(dets)
                cycle_time = time.time() - start_time
                total_time += cycle_time

                for d in trackers:
                    # Track IDs are integral; write them without a trailing ".0" (MOT format).
                    print(f'{frame},{int(d[4])},{d[0]:.2f},{d[1]:.2f},{(d[2]-d[0]):.2f},{(d[3]-d[1]):.2f},1,-1,-1,-1', file=out_file)
                    if display:
                        d = d.astype(np.int32)
                        ax1.add_patch(patches.Rectangle((d[0], d[1]), d[2]-d[0], d[3]-d[1], fill=False, lw=3, ec=colours[d[4] % 32, :]))

                if display:
                    fig.canvas.flush_events()
                    plt.draw()
                    ax1.cla()

    # Avoid ZeroDivisionError when no sequence/frame was processed.
    fps = total_frames / total_time if total_time > 0 else 0.0
    print("Total Tracking took: %.3f seconds for %d frames or %.1f FPS" % (total_time, total_frames, fps))
    if display:
        print("Note: to get real runtime results run without the option: --display")

该模块以目标检测模型输出的检测框及其置信度作为输入,每行格式为 [x1, y1, x2, y2, score]。下一篇文章会展示一个使用该模块进行跟踪和绘制的 demo。

相关文章:

  • C语言运算符优先级速记口诀
  • 基于查表法的 CRC8 / CRC16 / CRC32校验解析
  • PowerBI 条形图显示数值和百分比
  • Vue使用el-table给每一行数据上面增加一行自定义合并行
  • C++算法(1):stringstream详解,高效字符串处理与类型转换的利器
  • 【NLP 55、强化学习与NLP】
  • OpenHarmony Camera开发指导(二):相机设备管理(ArkTS)
  • ALOPS智能化运维管理平台
  • 浅析Centos7安装Oracle12数据库
  • 详解正则表达式中的?:、?= 、 ?! 、?<=、?<!
  • 火语言RPA--增加减少时间
  • EN控制同步整流WD1020 ,3.0V-21V 的宽 VIN 输入范围,0.9V-20V 的宽输出电压范围
  • Android activity属性taskAffinity的作用
  • call、bind、apply
  • MySQL Error Log
  • 【第16届】蓝桥杯C++b组--记录一次被薄纱的心情
  • 艾伦·图灵:计算机科学与人工智能之父
  • 08-JVM 面试题-mk
  • KWDB创作者计划—KWDB认知引擎:数据流动架构与时空感知计算的范式突破
  • 20250412_代码笔记_CVRProblemDef
  • 谷多网站/免费推广网站注册入口
  • 缪斯设计上海/seo搜索引擎优化内容
  • 日本可以自己做网站吗?/网络推广员上班靠谱吗
  • 怎么做期货网站/湖南百度推广开户
  • dw对网站建设有哪些作用/好用的种子搜索引擎
  • 专业做能源招聘的网站/windows优化大师靠谱吗