Unverified · commit c84153a3 authored by Feng Ni, committed by GitHub

[MOT] Add OC_SORT tracker (#6272)

* add ocsort tracker

* add ocsort deploy

* merge develop

* fix ocsort tracker codes

* fix doc, test=document_fix

* fix doc, test=document_fix
Parent 7bfddb0e
......@@ -20,10 +20,9 @@
| MOT-17 half train | YOLOv3 | 608x608 | - | 42.7 | 49.5 | 54.8 | - |[config](./bytetrack_yolov3.yml) |
| MOT-17 half train | PP-YOLOE-l | 640x640 | - | 52.9 | 50.4 | 59.7 | - |[config](./bytetrack_ppyoloe.yml) |
| MOT-17 half train | PP-YOLOE-l | 640x640 |PPLCNet| 52.9 | 51.7 | 58.8 | - |[config](./bytetrack_ppyoloe_pplcnet.yml) |
| **mot17_ch** | YOLOX-x | 800x1440| - | 61.9 | 77.3 | 71.6 | - |[config](./bytetrack_yolox.yml) |
| **mix_mot_ch** | YOLOX-x | 800x1440| - | 61.9 | 77.3 | 71.6 | - |[config](./bytetrack_yolox.yml) |
| **mix_det** | YOLOX-x | 800x1440| - | 65.4 | 84.5 | 77.4 | - |[config](./bytetrack_yolox.yml) |
**Notes:**
- For configs and docs of the detection task, see [detector](detector/).
......@@ -43,7 +42,7 @@
**Notes:**
- The model weight download URLs are the ```det_weights``` and ```reid_weights``` fields in the config file; they are downloaded automatically when you run the ```tools/eval_mot.py``` evaluation command. A ```reid_weights``` of None means it is not used; ByteTrack does not use ReID weights by default.
- **MOT17-half train** is a dataset built from the images and annotations of the first half of the frames of each video in the MOT17 train set (7 sequences in total). For accuracy validation you can evaluate on **MOT17-half val**, built from the second half of the frames of each video. The dataset can be downloaded from [this link](https://dataset.bj.bcebos.com/mot/MOT17.zip) and unzipped into the `dataset/mot/` folder.
- **mix_det** is a joint dataset composed of MOT17, CrowdHuman, Cityscapes and ETHZ. For its format and directory layout see [this link](https://github.com/ifzhang/ByteTrack#data-preparation); place it under `dataset/mot/`. For accuracy validation you can evaluate on **MOT17-half val**.
- **mix_mot_ch** is a joint dataset composed of MOT17 and CrowdHuman; **mix_det** is a joint dataset composed of MOT17, CrowdHuman, Cityscapes and ETHZ. For their format and directory layout see [this link](https://github.com/ifzhang/ByteTrack#data-preparation); place them under `dataset/mot/`. For accuracy validation you can evaluate on **MOT17-half val**.
- ByteTrack training trains the detector alone on the MOT dataset; inference assembles the tracker to evaluate the MOT metrics. The standalone detection model can also be evaluated with detection metrics.
- For export and deployment, ByteTrack exports the detection model alone and then runs it with the assembled tracker; see [PP-Tracking](../../../deploy/pptracking/python/README.md).
......@@ -122,7 +121,7 @@ Step 2: Export the ReID model (optional, not needed by default)
CUDA_VISIBLE_DEVICES=0 python tools/export_model.py -c configs/mot/deepsort/reid/deepsort_pplcnet.yml -o reid_weights=https://paddledet.bj.bcebos.com/models/mot/deepsort/deepsort_pplcnet.pdparams
```
### 4. Predict with the exported model in Python
### 5. Predict with the exported model in Python
```bash
python deploy/pptracking/python/mot_sde_infer.py --model_dir=output_inference/ppyoloe_crn_l_36e_640x640_mot17half/ --tracker_config=deploy/pptracking/python/tracker_config.yml --video_file=mot17_demo.mp4 --device=GPU --save_mot_txts
......
Simplified Chinese | [English](README.md)
# OC_SORT (Observation-Centric SORT: Rethinking SORT for Robust Multi-Object Tracking)
## Contents
- [Introduction](#introduction)
- [Model Zoo](#model-zoo)
- [Getting Started](#getting-started)
- [Citation](#citation)
## Introduction
[OC_SORT](https://arxiv.org/abs/2203.14360) (Observation-Centric SORT: Rethinking SORT for Robust Multi-Object Tracking) is an observation-centric multi-object tracker. Configs for several commonly used detectors are provided here for reference. Differences in training dataset, input scale, number of training epochs, NMS threshold and so on all lead to differences in model accuracy and speed; please adapt them to your own needs.
## Model Zoo
### OC_SORT results on the MOT-17 half val set
| Detector training set | Detector | Input size | ReID | Det mAP | MOTA | IDF1 | FPS | Config |
| :-------- | :----- | :----: | :----:|:------: | :----: |:-----: |:----:|:----: |
| MOT-17 half train | PP-YOLOE-l | 640x640 | - | 52.9 | 50.1 | 62.6 | - |[config](./ocsort_ppyoloe.yml) |
| **mot17_ch** | YOLOX-x | 800x1440| - | 61.9 | 75.5 | 77.0 | - |[config](./ocsort_yolox.yml) |
**Notes:**
- The model weight download URLs are the ```det_weights``` and ```reid_weights``` fields in the config file; they are downloaded automatically when you run the evaluation command. OC_SORT does not need ```reid_weights``` by default.
- **MOT17-half train** is a dataset built from the images and annotations of the first half of the frames of each video in the MOT17 train set (7 sequences in total). For accuracy validation you can evaluate on **MOT17-half val**, built from the second half of the frames of each video. The dataset can be downloaded from [this link](https://dataset.bj.bcebos.com/mot/MOT17.zip) and unzipped into the `dataset/mot/` folder.
- **mix_mot_ch** is a joint dataset composed of MOT17 and CrowdHuman; **mix_det** is a joint dataset composed of MOT17, CrowdHuman, Cityscapes and ETHZ. For their format and directory layout see [this link](https://github.com/ifzhang/ByteTrack#data-preparation); place them under `dataset/mot/`. For accuracy validation you can evaluate on **MOT17-half val**.
- OC_SORT training trains the detector alone on the MOT dataset; inference assembles the tracker to evaluate the MOT metrics (see the sketch after these notes). The standalone detection model can also be evaluated with detection metrics.
- For export and deployment, OC_SORT exports the detection model alone and then runs it with the assembled tracker; see [PP-Tracking](../../../deploy/pptracking/python).
- OC_SORT is the main tracking solution of pipeline analysis projects such as PP-Human and PP-Vehicle; for usage see [Pipeline](../../../deploy/pipeline) and [MOT](../../../deploy/pipeline/docs/tutorials/mot.md).
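The assembled tracker can also be driven directly from Python. Below is a minimal sketch of the `OCSORTTracker` update loop with hypothetical detection values; the import path matches the `ppdet.modeling.mot.tracker` module touched by this PR:
```python
import numpy as np
from ppdet.modeling.mot.tracker import OCSORTTracker

tracker = OCSORTTracker(det_thresh=0.4, max_age=30, min_hits=3)

# pred_dets: [N, 6] array of 'cls_id, score, x0, y0, x1, y1' (hypothetical values)
pred_dets = np.array([[0, 0.92, 100., 120., 180., 300.],
                      [0, 0.85, 400., 100., 470., 290.]])
online_targets = tracker.update(pred_dets)  # [M, 6]: 'x0, y0, x1, y1, score, id'
for x0, y0, x1, y1, score, tid in online_targets:
    print(int(tid), x0, y0, x1, y1, score)
```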
## Getting Started
### 1. Training
Start training and evaluation with the following command:
```bash
python -m paddle.distributed.launch --log_dir=ppyoloe --gpus 0,1,2,3,4,5,6,7 tools/train.py -c configs/mot/bytetrack/detector/ppyoloe_crn_l_36e_640x640_mot17half.yml --eval --amp
```
### 2. Evaluation
#### 2.1 Evaluate detection
```bash
CUDA_VISIBLE_DEVICES=0 python tools/eval.py -c configs/mot/bytetrack/detector/ppyoloe_crn_l_36e_640x640_mot17half.yml
```
**Notes:**
- Detection is evaluated with ```tools/eval.py```; tracking is evaluated with ```tools/eval_mot.py```.
#### 2.2 Evaluate tracking
```bash
CUDA_VISIBLE_DEVICES=0 python tools/eval_mot.py -c configs/mot/ocsort/ocsort_ppyoloe.yml --scaled=True
# or
CUDA_VISIBLE_DEVICES=0 python tools/eval_mot.py -c configs/mot/ocsort/ocsort_yolox.yml --scaled=True
```
**Notes:**
- `--scaled` indicates whether the coordinates in the model output have already been scaled back to the original image. Set it to False for the JDE YOLOv3 detector and to True for general detection models; the default is False.
- Tracking results are saved in `{output_dir}/mot_results/`, one txt per video sequence; each line of a txt is `frame,id,x1,y1,w,h,score,-1,-1,-1`. `{output_dir}` can be set via `--output_dir`. A parsing sketch follows these notes.
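For example, a result txt can be loaded back with NumPy. This is a minimal sketch; the sequence file name `MOT17-02-SDP.txt` is only an assumption:
```python
import numpy as np

# each line: frame,id,x1,y1,w,h,score,-1,-1,-1
results = np.loadtxt("output/mot_results/MOT17-02-SDP.txt", delimiter=",")
results = np.atleast_2d(results)  # guard against single-line files
for frame, tid, x1, y1, w, h, score, *_ in results:
    print(int(frame), int(tid), x1, y1, w, h, score)
```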
### 3. Inference
Predict on a video with a single GPU and save the result as a video:
```bash
# download the demo video
wget https://bj.bcebos.com/v1/paddledet/data/mot/demo/mot17_demo.mp4
CUDA_VISIBLE_DEVICES=0 python tools/infer_mot.py -c configs/mot/ocsort/ocsort_yolox.yml --video_file=mot17_demo.mp4 --scaled=True --save_videos
```
**Notes:**
- Make sure [ffmpeg](https://ffmpeg.org/ffmpeg.html) is installed first. On Linux (Ubuntu) it can be installed directly with: `apt-get update && apt-get install -y ffmpeg`.
- `--scaled` indicates whether the coordinates in the model output have already been scaled back to the original image. Set it to False for the JDE YOLOv3 detector and to True for general detection models.
### 4. Export the inference model
Step 1: Export the detection model
```bash
CUDA_VISIBLE_DEVICES=0 python tools/export_model.py -c configs/mot/bytetrack/detector/yolox_x_24e_800x1440_mix_det.yml -o weights=https://paddledet.bj.bcebos.com/models/mot/yolox_x_24e_800x1440_mix_det.pdparams
```
### 5. Predict with the exported model in Python
```bash
python deploy/pptracking/python/mot_sde_infer.py --model_dir=output_inference/yolox_x_24e_800x1440_mix_det/ --tracker_config=deploy/pptracking/python/tracker_config.yml --video_file=mot17_demo.mp4 --device=GPU --save_mot_txts
```
**Notes:**
- Before running, manually change the tracker type in `tracker_config.yml` to `type: OCSORTTracker` (a programmatic alternative is sketched below).
- The tracking model predicts on videos rather than single images. By default it saves a video with the visualized tracking results; add `--save_mot_txts` (one txt per video) or `--save_mot_txt_per_img` (one txt per image) to save the tracking result txt files, or `--save_images` to save the visualized images.
- Each line of a tracking result txt is `frame,id,x1,y1,w,h,score,-1,-1,-1`.
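As an alternative to hand-editing, the tracker type can be switched programmatically. A minimal sketch, assuming PyYAML is installed (note that `safe_dump` drops the comments in the file):
```python
import yaml

cfg_path = "deploy/pptracking/python/tracker_config.yml"
with open(cfg_path) as f:
    cfg = yaml.safe_load(f)
cfg["type"] = "OCSORTTracker"  # switch from JDETracker to OC_SORT
with open(cfg_path, "w") as f:
    yaml.safe_dump(cfg, f)
```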
## Citation
```
@article{cao2022observation,
title={Observation-Centric SORT: Rethinking SORT for Robust Multi-Object Tracking},
author={Cao, Jinkun and Weng, Xinshuo and Khirodkar, Rawal and Pang, Jiangmiao and Kitani, Kris},
journal={arXiv preprint arXiv:2203.14360},
year={2022}
}
```
# This config is an assembled config for ByteTrack MOT, used as eval/infer mode for MOT.
_BASE_: [
'../bytetrack/detector/ppyoloe_crn_l_36e_640x640_mot17half.yml',
'../bytetrack/_base_/mot17.yml',
'../bytetrack/_base_/ppyoloe_mot_reader_640x640.yml'
]
weights: output/ocsort_ppyoloe/model_final
log_iter: 20
snapshot_epoch: 2
metric: MOT # eval/infer mode; set 'COCO' for training mode
num_classes: 1
architecture: ByteTrack
pretrain_weights: https://bj.bcebos.com/v1/paddledet/models/ppyoloe_crn_l_300e_coco.pdparams
ByteTrack:
detector: YOLOv3 # PP-YOLOE version
reid: None
tracker: OCSORTTracker
det_weights: https://bj.bcebos.com/v1/paddledet/models/mot/ppyoloe_crn_l_36e_640x640_mot17half.pdparams
reid_weights: None
YOLOv3:
backbone: CSPResNet
neck: CustomCSPPAN
yolo_head: PPYOLOEHead
post_process: ~
# Tracking requires higher quality boxes, so NMS score_threshold will be higher
PPYOLOEHead:
fpn_strides: [32, 16, 8]
grid_cell_scale: 5.0
grid_cell_offset: 0.5
static_assigner_epoch: -1 # 100
use_varifocal_loss: True
loss_weight: {class: 1.0, iou: 2.5, dfl: 0.5}
static_assigner:
name: ATSSAssigner
topk: 9
assigner:
name: TaskAlignedAssigner
topk: 13
alpha: 1.0
beta: 6.0
nms:
name: MultiClassNMS
nms_top_k: 1000
keep_top_k: 100
score_threshold: 0.1 # 0.01 in original detector
nms_threshold: 0.4 # 0.6 in original detector
OCSORTTracker:
det_thresh: 0.4 # 0.6 in yolox ocsort
max_age: 30
min_hits: 3
iou_threshold: 0.3
delta_t: 3
inertia: 0.2
vertical_ratio: 0
min_box_area: 0
use_byte: False
# MOTDataset for MOT evaluation and inference
EvalMOTDataset:
!MOTImageFolder
dataset_dir: dataset/mot
data_root: MOT17/images/half
keep_ori_im: True # set as True in DeepSORT and ByteTrack
TestMOTDataset:
!MOTImageFolder
dataset_dir: dataset/mot
keep_ori_im: True # set True to save visualized images or video
# This config is an assembled config for ByteTrack MOT, used as eval/infer mode for MOT.
_BASE_: [
'../bytetrack/detector/yolox_x_24e_800x1440_mix_det.yml',
'../bytetrack/_base_/mix_det.yml',
'../bytetrack/_base_/yolox_mot_reader_800x1440.yml'
]
weights: output/ocsort_yolox/model_final
log_iter: 20
snapshot_epoch: 2
metric: MOT # eval/infer mode
num_classes: 1
architecture: ByteTrack
pretrain_weights: https://bj.bcebos.com/v1/paddledet/models/yolox_x_300e_coco.pdparams
ByteTrack:
detector: YOLOX
reid: None
tracker: OCSORTTracker
det_weights: https://bj.bcebos.com/v1/paddledet/models/mot/yolox_x_24e_800x1440_mix_mot_ch.pdparams
reid_weights: None
depth_mult: 1.33
width_mult: 1.25
YOLOX:
backbone: CSPDarkNet
neck: YOLOCSPPAN
head: YOLOXHead
input_size: [800, 1440]
size_stride: 32
size_range: [18, 22] # multi-scale range [576*1024 ~ 800*1440], w/h ratio=1.8
CSPDarkNet:
arch: "X"
return_idx: [2, 3, 4]
depthwise: False
YOLOCSPPAN:
depthwise: False
# Tracking requires higher quality boxes, so NMS score_threshold will be higher
YOLOXHead:
l1_epoch: 20
depthwise: False
loss_weight: {cls: 1.0, obj: 1.0, iou: 5.0, l1: 1.0}
assigner:
name: SimOTAAssigner
candidate_topk: 10
use_vfl: False
nms:
name: MultiClassNMS
nms_top_k: 1000
keep_top_k: 100
score_threshold: 0.1
nms_threshold: 0.7
# For speed while keeping a high mAP, you can set 'nms_top_k' to 1000 and 'keep_top_k' to 100; mAP will drop about 0.1%.
# For a high-speed demo, you can set 'score_threshold' to 0.25 and 'nms_threshold' to 0.45, but mAP will drop a lot.
OCSORTTracker:
det_thresh: 0.6
max_age: 30
min_hits: 3
iou_threshold: 0.3
delta_t: 3
inertia: 0.2
vertical_ratio: 1.6
min_box_area: 100
use_byte: False
# MOTDataset for MOT evaluation and inference
EvalMOTDataset:
!MOTImageFolder
dataset_dir: dataset/mot
data_root: MOT17/images/half
keep_ori_im: True # set as True in DeepSORT and ByteTrack
TestMOTDataset:
!MOTImageFolder
dataset_dir: dataset/mot
keep_ori_im: True # set True to save visualized images or video
......@@ -2,7 +2,8 @@
# The tracker of MOT JDE Detector (such as FairMOT) is exported together with the model.
# Here 'min_box_area' and 'vertical_ratio' are set for pedestrian, you can modify for other objects tracking.
type: JDETracker
type: OCSORTTracker # choose one tracker in ['JDETracker', 'OCSORTTracker']
# BYTETracker
JDETracker:
......@@ -13,3 +14,15 @@ JDETracker:
match_thres: 0.9
min_box_area: 0
vertical_ratio: 0 # 1.6 for pedestrian
OCSORTTracker:
det_thresh: 0.4
max_age: 30
min_hits: 3
iou_threshold: 0.3
delta_t: 3
inertia: 0.2
vertical_ratio: 0
min_box_area: 0
use_byte: False
......@@ -117,7 +117,7 @@ python deploy/pptracking/python/mot_sde_infer.py --model_dir=mot_ppyoloe_l_36e_p
## 3. Export and inference for the ByteTrack model
## 3. Export and inference for the ByteTrack and OC_SORT models
### 3.1 Export the inference model
```bash
# export the PP-YOLOE pedestrian detection model
......@@ -136,7 +136,8 @@ python deploy/pptracking/python/mot_sde_infer.py --model_dir=output_inference/pp
python deploy/pptracking/python/mot_sde_infer.py --model_dir=output_inference/ppyoloe_crn_l_36e_640x640_mot17half/ --reid_model_dir=output_inference/deepsort_pplcnet/ --tracker_config=deploy/pptracking/python/tracker_config.yml --video_file=mot17_demo.mp4 --device=GPU --threshold=0.5 --save_mot_txts --save_images
```
**Notes:**
- Before running, make sure the tracker type in `tracker_config.yml` is `type: JDETracker`.
- To run the ByteTrack model, make sure the tracker type in `tracker_config.yml` is `type: JDETracker`.
- Switch the tracker type in `tracker_config.yml` to `type: OCSORTTracker` to run the OC_SORT model.
- The ByteTrack model runs with the exported detector plus a separately configured `--tracker_config` file. For real-time tracking it does not need a ReID model; `--reid_model_dir` is the path of the exported ReID model, empty by default, and whether to add it depends on the observed effect.
- The tracking model predicts on videos rather than single images. By default it saves a video with the visualized tracking results; add `--save_mot_txts` (one txt per video) to save the tracking result txt files, or `--save_images` to save the visualized images.
- Each line of a tracking result txt is `frame,id,x1,y1,w,h,score,-1,-1,-1`.
......
......@@ -14,6 +14,8 @@
from . import jde_matching
from . import deepsort_matching
from . import ocsort_matching
from .jde_matching import *
from .deepsort_matching import *
from .ocsort_matching import *
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is based on https://github.com/noahcao/OC_SORT/blob/master/trackers/ocsort_tracker/association.py
"""
import os
import numpy as np
def iou_batch(bboxes1, bboxes2):
"""
From SORT: Computes IOU between two bboxes in the form [x1,y1,x2,y2]
"""
bboxes2 = np.expand_dims(bboxes2, 0)
bboxes1 = np.expand_dims(bboxes1, 1)
xx1 = np.maximum(bboxes1[..., 0], bboxes2[..., 0])
yy1 = np.maximum(bboxes1[..., 1], bboxes2[..., 1])
xx2 = np.minimum(bboxes1[..., 2], bboxes2[..., 2])
yy2 = np.minimum(bboxes1[..., 3], bboxes2[..., 3])
w = np.maximum(0., xx2 - xx1)
h = np.maximum(0., yy2 - yy1)
wh = w * h
o = wh / ((bboxes1[..., 2] - bboxes1[..., 0]) *
(bboxes1[..., 3] - bboxes1[..., 1]) +
(bboxes2[..., 2] - bboxes2[..., 0]) *
(bboxes2[..., 3] - bboxes2[..., 1]) - wh)
return o
def speed_direction_batch(dets, tracks):
tracks = tracks[..., np.newaxis]
CX1, CY1 = (dets[:, 0] + dets[:, 2]) / 2.0, (dets[:, 1] + dets[:, 3]) / 2.0
CX2, CY2 = (tracks[:, 0] + tracks[:, 2]) / 2.0, (
tracks[:, 1] + tracks[:, 3]) / 2.0
dx = CX1 - CX2
dy = CY1 - CY2
norm = np.sqrt(dx**2 + dy**2) + 1e-6
dx = dx / norm
dy = dy / norm
return dy, dx # size: num_track x num_det
def linear_assignment(cost_matrix):
try:
import lap
_, x, y = lap.lapjv(cost_matrix, extend_cost=True)
return np.array([[y[i], i] for i in x if i >= 0])
except ImportError:
from scipy.optimize import linear_sum_assignment
x, y = linear_sum_assignment(cost_matrix)
return np.array(list(zip(x, y)))
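# Note: both lap.lapjv and scipy's linear_sum_assignment minimize total cost,
# so callers negate similarity matrices (e.g. linear_assignment(-iou_matrix));
# the result is an array of matched [row_idx, col_idx] pairs.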
def associate(detections, trackers, iou_threshold, velocities, previous_obs,
vdc_weight):
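# Cost design (observation-centric momentum, OCM): the assignment below
# maximizes IoU plus a velocity-direction consistency term. `velocities` are
# the tracks' motion directions, `previous_obs` their observations delta_t
# frames ago, and `vdc_weight` (the tracker's `inertia`) scales the angle cost.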
if (len(trackers) == 0):
return np.empty(
(0, 2), dtype=int), np.arange(len(detections)), np.empty(
(0, 5), dtype=int)
Y, X = speed_direction_batch(detections, previous_obs)
inertia_Y, inertia_X = velocities[:, 0], velocities[:, 1]
inertia_Y = np.repeat(inertia_Y[:, np.newaxis], Y.shape[1], axis=1)
inertia_X = np.repeat(inertia_X[:, np.newaxis], X.shape[1], axis=1)
diff_angle_cos = inertia_X * X + inertia_Y * Y
diff_angle_cos = np.clip(diff_angle_cos, a_min=-1, a_max=1)
diff_angle = np.arccos(diff_angle_cos)
diff_angle = (np.pi / 2.0 - np.abs(diff_angle)) / np.pi
valid_mask = np.ones(previous_obs.shape[0])
valid_mask[np.where(previous_obs[:, 4] < 0)] = 0
iou_matrix = iou_batch(detections, trackers)
scores = np.repeat(
detections[:, -1][:, np.newaxis], trackers.shape[0], axis=1)
# iou_matrix = iou_matrix * scores # a trick that sometimes works; we don't encourage this
valid_mask = np.repeat(valid_mask[:, np.newaxis], X.shape[1], axis=1)
angle_diff_cost = (valid_mask * diff_angle) * vdc_weight
angle_diff_cost = angle_diff_cost.T
angle_diff_cost = angle_diff_cost * scores
if min(iou_matrix.shape) > 0:
a = (iou_matrix > iou_threshold).astype(np.int32)
if a.sum(1).max() == 1 and a.sum(0).max() == 1:
matched_indices = np.stack(np.where(a), axis=1)
else:
matched_indices = linear_assignment(-(iou_matrix + angle_diff_cost))
else:
matched_indices = np.empty(shape=(0, 2))
unmatched_detections = []
for d, det in enumerate(detections):
if (d not in matched_indices[:, 0]):
unmatched_detections.append(d)
unmatched_trackers = []
for t, trk in enumerate(trackers):
if (t not in matched_indices[:, 1]):
unmatched_trackers.append(t)
# filter out matched with low IOU
matches = []
for m in matched_indices:
if (iou_matrix[m[0], m[1]] < iou_threshold):
unmatched_detections.append(m[0])
unmatched_trackers.append(m[1])
else:
matches.append(m.reshape(1, 2))
if (len(matches) == 0):
matches = np.empty((0, 2), dtype=int)
else:
matches = np.concatenate(matches, axis=0)
return matches, np.array(unmatched_detections), np.array(unmatched_trackers)
......@@ -16,8 +16,10 @@ from . import base_jde_tracker
from . import base_sde_tracker
from . import jde_tracker
from . import deepsort_tracker
from . import ocsort_tracker
from .base_jde_tracker import *
from .base_sde_tracker import *
from .jde_tracker import *
from .deepsort_tracker import *
from .ocsort_tracker import *
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is based on https://github.com/noahcao/OC_SORT/blob/master/trackers/ocsort_tracker/ocsort.py
"""
import numpy as np
from filterpy.kalman import KalmanFilter
from ..matching.ocsort_matching import associate, linear_assignment, iou_batch
def k_previous_obs(observations, cur_age, k):
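# Return the observation from `k` frames before `cur_age`; if none of the last
# `k` frames has one, fall back to the most recent observation, and return the
# [-1, ...] placeholder when there are no observations at all.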
if len(observations) == 0:
return [-1, -1, -1, -1, -1]
for i in range(k):
dt = k - i
if cur_age - dt in observations:
return observations[cur_age - dt]
max_age = max(observations.keys())
return observations[max_age]
def convert_bbox_to_z(bbox):
"""
Takes a bounding box in the form [x1,y1,x2,y2] and returns z in the form
[x,y,s,r] where x,y is the centre of the box and s is the scale/area and r is
the aspect ratio
"""
w = bbox[2] - bbox[0]
h = bbox[3] - bbox[1]
x = bbox[0] + w / 2.
y = bbox[1] + h / 2.
s = w * h # scale is just area
r = w / float(h + 1e-6)
return np.array([x, y, s, r]).reshape((4, 1))
def convert_x_to_bbox(x, score=None):
"""
Takes a bounding box in the centre form [x,y,s,r] and returns it in the form
[x1,y1,x2,y2] where x1,y1 is the top left and x2,y2 is the bottom right
"""
w = np.sqrt(x[2] * x[3])
h = x[2] / w
if score is None:
return np.array(
[x[0] - w / 2., x[1] - h / 2., x[0] + w / 2.,
x[1] + h / 2.]).reshape((1, 4))
else:
score = np.array([score])
return np.array([
x[0] - w / 2., x[1] - h / 2., x[0] + w / 2., x[1] + h / 2., score
]).reshape((1, 5))
def speed_direction(bbox1, bbox2):
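# unit direction (dy, dx) of the center displacement from bbox1 to bbox2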
cx1, cy1 = (bbox1[0] + bbox1[2]) / 2.0, (bbox1[1] + bbox1[3]) / 2.0
cx2, cy2 = (bbox2[0] + bbox2[2]) / 2.0, (bbox2[1] + bbox2[3]) / 2.0
speed = np.array([cy2 - cy1, cx2 - cx1])
norm = np.sqrt((cy2 - cy1)**2 + (cx2 - cx1)**2) + 1e-6
return speed / norm
class KalmanBoxTracker(object):
"""
This class represents the internal state of individual tracked objects observed as bbox.
Args:
bbox (np.array): bbox in [x1,y1,x2,y2,score] format.
delta_t (int): delta_t of previous observation
"""
count = 0
def __init__(self, bbox, delta_t=3):
self.kf = KalmanFilter(dim_x=7, dim_z=4)
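# Constant-velocity state x = [u, v, s, r, du, dv, ds]: box center (u, v),
# area s and aspect ratio r, plus velocities for u, v and s; r is assumed
# constant, so only the first 4 of the 7 states are measured (dim_z=4).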
self.kf.F = np.array([[1, 0, 0, 0, 1, 0, 0], [0, 1, 0, 0, 0, 1, 0],
[0, 0, 1, 0, 0, 0, 1], [0, 0, 0, 1, 0, 0, 0],
[0, 0, 0, 0, 1, 0, 0], [0, 0, 0, 0, 0, 1, 0],
[0, 0, 0, 0, 0, 0, 1]])
self.kf.H = np.array([[1, 0, 0, 0, 0, 0, 0], [0, 1, 0, 0, 0, 0, 0],
[0, 0, 1, 0, 0, 0, 0], [0, 0, 0, 1, 0, 0, 0]])
self.kf.R[2:, 2:] *= 10.
self.kf.P[4:, 4:] *= 1000.
# give high uncertainty to the unobservable initial velocities
self.kf.P *= 10.
self.kf.Q[-1, -1] *= 0.01
self.kf.Q[4:, 4:] *= 0.01
self.score = bbox[4]
self.kf.x[:4] = convert_bbox_to_z(bbox)
self.time_since_update = 0
self.id = KalmanBoxTracker.count
KalmanBoxTracker.count += 1
self.history = []
self.hits = 0
self.hit_streak = 0
self.age = 0
"""
NOTE: [-1,-1,-1,-1,-1] is a compromising placeholder for non-observation status, the same for the return of
function k_previous_obs. It is ugly and I do not like it. But to support generate observation array in a
fast and unified way, which you would see below k_observations = np.array([k_previous_obs(...]]), let's bear it for now.
"""
self.last_observation = np.array([-1, -1, -1, -1, -1]) # placeholder
self.observations = dict()
self.history_observations = []
self.velocity = None
self.delta_t = delta_t
def update(self, bbox):
"""
Updates the state vector with observed bbox.
"""
if bbox is not None:
if self.last_observation.sum() >= 0: # there is a previous observation
previous_box = None
for i in range(self.delta_t):
dt = self.delta_t - i
if self.age - dt in self.observations:
previous_box = self.observations[self.age - dt]
break
if previous_box is None:
previous_box = self.last_observation
"""
Estimate the track speed direction with observations delta_t steps away
"""
self.velocity = speed_direction(previous_box, bbox)
"""
Insert new observations. This is a ugly way to maintain both self.observations
and self.history_observations. Bear it for the moment.
"""
self.last_observation = bbox
self.observations[self.age] = bbox
self.history_observations.append(bbox)
self.time_since_update = 0
self.history = []
self.hits += 1
self.hit_streak += 1
self.kf.update(convert_bbox_to_z(bbox))
else:
self.kf.update(bbox)
def predict(self):
"""
Advances the state vector and returns the predicted bounding box estimate.
"""
if ((self.kf.x[6] + self.kf.x[2]) <= 0):
self.kf.x[6] *= 0.0
self.kf.predict()
self.age += 1
if (self.time_since_update > 0):
self.hit_streak = 0
self.time_since_update += 1
self.history.append(convert_x_to_bbox(self.kf.x, score=self.score))
return self.history[-1]
def get_state(self):
return convert_x_to_bbox(self.kf.x, score=self.score)
class OCSORTTracker(object):
"""
OCSORT tracker, supports single class only
Args:
det_thresh (float): threshold of detection score
max_age (int): maximum number of consecutive missed frames before a track is deleted
min_hits (int): minimum number of hits before a track is output
iou_threshold (float): iou threshold for association
delta_t (int): delta_t of previous observation
inertia (float): vdc_weight of angle_diff_cost for association
vertical_ratio (float): w/h ratio of the bbox used to filter bad
results. If set <= 0 no bboxes are filtered; usually set to
1.6 for pedestrian tracking.
min_box_area (int): min box area to filter out low quality boxes
use_byte (bool): whether to use BYTE association, default False
"""
def __init__(self,
det_thresh=0.6,
max_age=30,
min_hits=3,
iou_threshold=0.3,
delta_t=3,
inertia=0.2,
vertical_ratio=-1,
min_box_area=0,
use_byte=False):
self.det_thresh = det_thresh
self.max_age = max_age
self.min_hits = min_hits
self.iou_threshold = iou_threshold
self.delta_t = delta_t
self.inertia = inertia
self.vertical_ratio = vertical_ratio
self.min_box_area = min_box_area
self.use_byte = use_byte
self.trackers = []
self.frame_count = 0
KalmanBoxTracker.count = 0
def update(self, pred_dets, pred_embs=None):
"""
Args:
pred_dets (np.array): Detection results of the image, the shape is
[N, 6], means 'cls_id, score, x0, y0, x1, y1'.
pred_embs (np.array): Embedding results of the image, the shape is
[N, 128] or [N, 512], default None.
Return:
tracking boxes (np.array): [M, 6], means 'x0, y0, x1, y1, score, id'.
"""
if pred_dets is None:
return np.empty((0, 6))
self.frame_count += 1
bboxes = pred_dets[:, 2:]
scores = pred_dets[:, 1:2]
dets = np.concatenate((bboxes, scores), axis=1)
scores = scores.squeeze(-1)
inds_low = scores > 0.1
inds_high = scores < self.det_thresh
inds_second = np.logical_and(inds_low, inds_high)
# self.det_thresh > score > 0.1, for second matching
dets_second = dets[inds_second] # detections for second matching
remain_inds = scores > self.det_thresh
dets = dets[remain_inds]
# get predicted locations from existing trackers.
trks = np.zeros((len(self.trackers), 5))
to_del = []
ret = []
for t, trk in enumerate(trks):
pos = self.trackers[t].predict()[0]
trk[:] = [pos[0], pos[1], pos[2], pos[3], 0]
if np.any(np.isnan(pos)):
to_del.append(t)
trks = np.ma.compress_rows(np.ma.masked_invalid(trks))
for t in reversed(to_del):
self.trackers.pop(t)
velocities = np.array([
trk.velocity if trk.velocity is not None else np.array((0, 0))
for trk in self.trackers
])
last_boxes = np.array([trk.last_observation for trk in self.trackers])
k_observations = np.array([
k_previous_obs(trk.observations, trk.age, self.delta_t)
for trk in self.trackers
])
"""
First round of association
"""
matched, unmatched_dets, unmatched_trks = associate(
dets, trks, self.iou_threshold, velocities, k_observations,
self.inertia)
for m in matched:
self.trackers[m[1]].update(dets[m[0], :])
"""
Second round of associaton by OCR
"""
# BYTE association
if self.use_byte and len(dets_second) > 0 and unmatched_trks.shape[
0] > 0:
u_trks = trks[unmatched_trks]
iou_left = iou_batch(
dets_second,
u_trks) # iou between low score detections and unmatched tracks
iou_left = np.array(iou_left)
if iou_left.max() > self.iou_threshold:
"""
NOTE: by using a lower threshold, e.g., self.iou_threshold - 0.1, you may
get a higher performance especially on MOT17/MOT20 datasets. But we keep it
uniform here for simplicity
"""
matched_indices = linear_assignment(-iou_left)
to_remove_trk_indices = []
for m in matched_indices:
det_ind, trk_ind = m[0], unmatched_trks[m[1]]
if iou_left[m[0], m[1]] < self.iou_threshold:
continue
self.trackers[trk_ind].update(dets_second[det_ind, :])
to_remove_trk_indices.append(trk_ind)
unmatched_trks = np.setdiff1d(unmatched_trks,
np.array(to_remove_trk_indices))
if unmatched_dets.shape[0] > 0 and unmatched_trks.shape[0] > 0:
left_dets = dets[unmatched_dets]
left_trks = last_boxes[unmatched_trks]
iou_left = iou_batch(left_dets, left_trks)
iou_left = np.array(iou_left)
if iou_left.max() > self.iou_threshold:
"""
NOTE: by using a lower threshold, e.g., self.iou_threshold - 0.1, you may
get a higher performance especially on MOT17/MOT20 datasets. But we keep it
uniform here for simplicity
"""
rematched_indices = linear_assignment(-iou_left)
to_remove_det_indices = []
to_remove_trk_indices = []
for m in rematched_indices:
det_ind, trk_ind = unmatched_dets[m[0]], unmatched_trks[m[
1]]
if iou_left[m[0], m[1]] < self.iou_threshold:
continue
self.trackers[trk_ind].update(dets[det_ind, :])
to_remove_det_indices.append(det_ind)
to_remove_trk_indices.append(trk_ind)
unmatched_dets = np.setdiff1d(unmatched_dets,
np.array(to_remove_det_indices))
unmatched_trks = np.setdiff1d(unmatched_trks,
np.array(to_remove_trk_indices))
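# tracks still unmatched after all association rounds are marked as missed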
for m in unmatched_trks:
self.trackers[m].update(None)
# create and initialise new trackers for unmatched detections
for i in unmatched_dets:
trk = KalmanBoxTracker(dets[i, :], delta_t=self.delta_t)
self.trackers.append(trk)
i = len(self.trackers)
for trk in reversed(self.trackers):
if trk.last_observation.sum() < 0:
d = trk.get_state()[0]
else:
d = trk.last_observation # tlbr + score
if (trk.time_since_update < 1) and (
trk.hit_streak >= self.min_hits or
self.frame_count <= self.min_hits):
# +1 as MOT benchmark requires positive
ret.append(np.concatenate((d, [trk.id + 1])).reshape(1, -1))
i -= 1
# remove dead tracklet
if (trk.time_since_update > self.max_age):
self.trackers.pop(i)
if (len(ret) > 0):
return np.concatenate(ret)
return np.empty((0, 6))
......@@ -32,7 +32,7 @@ sys.path.insert(0, parent_path)
from det_infer import Detector, get_test_images, print_arguments, bench_log, PredictConfig, load_predictor
from mot_utils import argsparser, Timer, get_current_memory_mb, video2frames, _is_valid_video
from mot.tracker import JDETracker, DeepSORTTracker
from mot.tracker import JDETracker, DeepSORTTracker, OCSORTTracker
from mot.utils import MOTTimer, write_mot_results, get_crops, clip_box, flow_statistic
from mot.visualize import plot_tracking, plot_tracking_dict
......@@ -142,8 +142,10 @@ class SDE_Detector(Detector):
# tracker config
self.use_deepsort_tracker = tracker_cfg['type'] == 'DeepSORTTracker'
self.use_ocsort_tracker = tracker_cfg['type'] == 'OCSORTTracker'
if self.use_deepsort_tracker:
# use DeepSORTTracker
if self.reid_pred_config is not None and hasattr(
self.reid_pred_config, 'tracker'):
cfg = self.reid_pred_config.tracker
......@@ -161,6 +163,28 @@ class SDE_Detector(Detector):
matching_threshold=matching_threshold,
min_box_area=min_box_area,
vertical_ratio=vertical_ratio, )
elif self.use_ocsort_tracker:
det_thresh = cfg.get('det_thresh', 0.4)
max_age = cfg.get('max_age', 30)
min_hits = cfg.get('min_hits', 3)
iou_threshold = cfg.get('iou_threshold', 0.3)
delta_t = cfg.get('delta_t', 3)
inertia = cfg.get('inertia', 0.2)
min_box_area = cfg.get('min_box_area', 0)
vertical_ratio = cfg.get('vertical_ratio', 0)
use_byte = cfg.get('use_byte', False)
self.tracker = OCSORTTracker(
det_thresh=det_thresh,
max_age=max_age,
min_hits=min_hits,
iou_threshold=iou_threshold,
delta_t=delta_t,
inertia=inertia,
min_box_area=min_box_area,
vertical_ratio=vertical_ratio,
use_byte=use_byte)
else:
# use ByteTracker
use_byte = cfg.get('use_byte', False)
......@@ -283,6 +307,32 @@ class SDE_Detector(Detector):
feat_data['feat'] = _feat
tracking_outs['feat_data'].update({_imgname: feat_data})
return tracking_outs
elif self.use_ocsort_tracker:
# use OCSORTTracker, which only supports single class
online_targets = self.tracker.update(pred_dets, pred_embs)
online_tlwhs = defaultdict(list)
online_scores = defaultdict(list)
online_ids = defaultdict(list)
for t in online_targets:
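# each t is [x0, y0, x1, y1, score, id] from OCSORTTracker.update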
tlwh = [t[0], t[1], t[2] - t[0], t[3] - t[1]]
tscore = float(t[4])
tid = int(t[5])
if tlwh[2] * tlwh[3] <= self.tracker.min_box_area: continue
if self.tracker.vertical_ratio > 0 and tlwh[2] / tlwh[
3] > self.tracker.vertical_ratio:
continue
if tlwh[2] * tlwh[3] > 0:
online_tlwhs[0].append(tlwh)
online_ids[0].append(tid)
online_scores[0].append(tscore)
tracking_outs = {
'online_tlwhs': online_tlwhs,
'online_scores': online_scores,
'online_ids': online_ids,
}
return tracking_outs
else:
# use ByteTracker, support multiple class
online_tlwhs = defaultdict(list)
......@@ -523,7 +573,7 @@ class SDE_Detector(Detector):
online_tlwhs, online_scores, online_ids = mot_results[0]
# flow statistic for one class, and only for bytetracker
if num_classes == 1 and not self.use_deepsort_tracker:
if num_classes == 1 and not self.use_deepsort_tracker and not self.use_ocsort_tracker:
result = (frame_id + 1, online_tlwhs[0], online_scores[0],
online_ids[0])
statistic = flow_statistic(
......@@ -533,8 +583,8 @@ class SDE_Detector(Detector):
records = statistic['records']
fps = 1. / timer.duration
if self.use_deepsort_tracker:
# use DeepSORTTracker, which only supports single class
if self.use_deepsort_tracker or self.use_ocsort_tracker:
# use DeepSORTTracker or OCSORTTracker, which only support single class
results[0].append(
(frame_id + 1, online_tlwhs, online_scores, online_ids))
im = plot_tracking(
......
......@@ -2,7 +2,7 @@
# The tracker of MOT JDE Detector (such as FairMOT) is exported together with the model.
# Here 'min_box_area' and 'vertical_ratio' are set for pedestrian, you can modify for other objects tracking.
type: JDETracker # 'JDETracker' or 'DeepSORTTracker'
type: OCSORTTracker # choose one tracker in ['JDETracker', 'OCSORTTracker', 'DeepSORTTracker']
# BYTETracker
JDETracker:
......@@ -14,6 +14,19 @@ JDETracker:
min_box_area: 0
vertical_ratio: 0 # 1.6 for pedestrian
OCSORTTracker:
det_thresh: 0.4
max_age: 30
min_hits: 3
iou_threshold: 0.3
delta_t: 3
inertia: 0.2
min_box_area: 0
vertical_ratio: 0
use_byte: False
DeepSORTTracker:
input_size: [64, 192]
min_box_area: 0
......
......@@ -29,7 +29,7 @@ from ppdet.core.workspace import create
from ppdet.utils.checkpoint import load_weight, load_pretrain_weight
from ppdet.modeling.mot.utils import Detection, get_crops, scale_coords, clip_box
from ppdet.modeling.mot.utils import MOTTimer, load_det_results, write_mot_results, save_vis_results
from ppdet.modeling.mot.tracker import JDETracker, DeepSORTTracker
from ppdet.modeling.mot.tracker import JDETracker, DeepSORTTracker, OCSORTTracker
from ppdet.modeling.architectures import YOLOX
from ppdet.metrics import Metric, MOTMetric, KITTIMOTMetric, MCMOTMetric
import ppdet.utils.stats as stats
......@@ -370,7 +370,29 @@ class Tracker(object):
save_vis_results(data, frame_id, online_ids, online_tlwhs,
online_scores, timer.average_time, show_image,
save_dir, self.cfg.num_classes)
elif isinstance(tracker, OCSORTTracker):
# OC_SORT Tracker
online_targets = tracker.update(pred_dets_old, pred_embs)
online_tlwhs = []
online_ids = []
online_scores = []
for t in online_targets:
tlwh = [t[0], t[1], t[2] - t[0], t[3] - t[1]]
tscore = float(t[4])
tid = int(t[5])
if tlwh[2] * tlwh[3] > 0:
online_tlwhs.append(tlwh)
online_ids.append(tid)
online_scores.append(tscore)
timer.toc()
# save results
results[0].append(
(frame_id + 1, online_tlwhs, online_scores, online_ids))
save_vis_results(data, frame_id, online_ids, online_tlwhs,
online_scores, timer.average_time, show_image,
save_dir, self.cfg.num_classes)
else:
raise ValueError(tracker)
frame_id += 1
return results, frame_id, timer.average_time, timer.calls
......
......@@ -14,6 +14,8 @@
from . import jde_matching
from . import deepsort_matching
from . import ocsort_matching
from .jde_matching import *
from .deepsort_matching import *
from .ocsort_matching import *
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is based on https://github.com/noahcao/OC_SORT/blob/master/trackers/ocsort_tracker/association.py
"""
import os
import numpy as np
def iou_batch(bboxes1, bboxes2):
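"""
From SORT: Computes IOU between two bboxes in the form [x1,y1,x2,y2]
"""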
bboxes2 = np.expand_dims(bboxes2, 0)
bboxes1 = np.expand_dims(bboxes1, 1)
xx1 = np.maximum(bboxes1[..., 0], bboxes2[..., 0])
yy1 = np.maximum(bboxes1[..., 1], bboxes2[..., 1])
xx2 = np.minimum(bboxes1[..., 2], bboxes2[..., 2])
yy2 = np.minimum(bboxes1[..., 3], bboxes2[..., 3])
w = np.maximum(0., xx2 - xx1)
h = np.maximum(0., yy2 - yy1)
area = w * h
iou_matrix = area / ((bboxes1[..., 2] - bboxes1[..., 0]) *
(bboxes1[..., 3] - bboxes1[..., 1]) +
(bboxes2[..., 2] - bboxes2[..., 0]) *
(bboxes2[..., 3] - bboxes2[..., 1]) - area)
return iou_matrix
def speed_direction_batch(dets, tracks):
tracks = tracks[..., np.newaxis]
CX1, CY1 = (dets[:, 0] + dets[:, 2]) / 2.0, (dets[:, 1] + dets[:, 3]) / 2.0
CX2, CY2 = (tracks[:, 0] + tracks[:, 2]) / 2.0, (
tracks[:, 1] + tracks[:, 3]) / 2.0
dx = CX1 - CX2
dy = CY1 - CY2
norm = np.sqrt(dx**2 + dy**2) + 1e-6
dx = dx / norm
dy = dy / norm
return dy, dx
def linear_assignment(cost_matrix):
try:
import lap
_, x, y = lap.lapjv(cost_matrix, extend_cost=True)
return np.array([[y[i], i] for i in x if i >= 0])
except ImportError:
from scipy.optimize import linear_sum_assignment
x, y = linear_sum_assignment(cost_matrix)
return np.array(list(zip(x, y)))
def associate(detections, trackers, iou_threshold, velocities, previous_obs,
vdc_weight):
if (len(trackers) == 0):
return np.empty(
(0, 2), dtype=int), np.arange(len(detections)), np.empty(
(0, 5), dtype=int)
Y, X = speed_direction_batch(detections, previous_obs)
inertia_Y, inertia_X = velocities[:, 0], velocities[:, 1]
inertia_Y = np.repeat(inertia_Y[:, np.newaxis], Y.shape[1], axis=1)
inertia_X = np.repeat(inertia_X[:, np.newaxis], X.shape[1], axis=1)
diff_angle_cos = inertia_X * X + inertia_Y * Y
diff_angle_cos = np.clip(diff_angle_cos, a_min=-1, a_max=1)
diff_angle = np.arccos(diff_angle_cos)
diff_angle = (np.pi / 2.0 - np.abs(diff_angle)) / np.pi
valid_mask = np.ones(previous_obs.shape[0])
valid_mask[np.where(previous_obs[:, 4] < 0)] = 0
iou_matrix = iou_batch(detections, trackers)
scores = np.repeat(
detections[:, -1][:, np.newaxis], trackers.shape[0], axis=1)
# iou_matrix = iou_matrix * scores # a trick that sometimes works; we don't encourage this
valid_mask = np.repeat(valid_mask[:, np.newaxis], X.shape[1], axis=1)
angle_diff_cost = (valid_mask * diff_angle) * vdc_weight
angle_diff_cost = angle_diff_cost.T
angle_diff_cost = angle_diff_cost * scores
if min(iou_matrix.shape) > 0:
a = (iou_matrix > iou_threshold).astype(np.int32)
if a.sum(1).max() == 1 and a.sum(0).max() == 1:
matched_indices = np.stack(np.where(a), axis=1)
else:
matched_indices = linear_assignment(-(iou_matrix + angle_diff_cost))
else:
matched_indices = np.empty(shape=(0, 2))
unmatched_detections = []
for d, det in enumerate(detections):
if (d not in matched_indices[:, 0]):
unmatched_detections.append(d)
unmatched_trackers = []
for t, trk in enumerate(trackers):
if (t not in matched_indices[:, 1]):
unmatched_trackers.append(t)
# filter out matched with low IOU
matches = []
for m in matched_indices:
if (iou_matrix[m[0], m[1]] < iou_threshold):
unmatched_detections.append(m[0])
unmatched_trackers.append(m[1])
else:
matches.append(m.reshape(1, 2))
if (len(matches) == 0):
matches = np.empty((0, 2), dtype=int)
else:
matches = np.concatenate(matches, axis=0)
return matches, np.array(unmatched_detections), np.array(unmatched_trackers)
......@@ -16,8 +16,10 @@ from . import base_jde_tracker
from . import base_sde_tracker
from . import jde_tracker
from . import deepsort_tracker
from . import ocsort_tracker
from .base_jde_tracker import *
from .base_sde_tracker import *
from .jde_tracker import *
from .deepsort_tracker import *
from .ocsort_tracker import *
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is based on https://github.com/noahcao/OC_SORT/blob/master/trackers/ocsort_tracker/ocsort.py
"""
import numpy as np
from filterpy.kalman import KalmanFilter
from ..matching.ocsort_matching import associate, linear_assignment, iou_batch
from ppdet.core.workspace import register, serializable
def k_previous_obs(observations, cur_age, k):
if len(observations) == 0:
return [-1, -1, -1, -1, -1]
for i in range(k):
dt = k - i
if cur_age - dt in observations:
return observations[cur_age - dt]
max_age = max(observations.keys())
return observations[max_age]
def convert_bbox_to_z(bbox):
"""
Takes a bounding box in the form [x1,y1,x2,y2] and returns z in the form
[x,y,s,r] where x,y is the centre of the box and s is the scale/area and r is
the aspect ratio
"""
w = bbox[2] - bbox[0]
h = bbox[3] - bbox[1]
x = bbox[0] + w / 2.
y = bbox[1] + h / 2.
s = w * h # scale is just area
r = w / float(h + 1e-6)
return np.array([x, y, s, r]).reshape((4, 1))
def convert_x_to_bbox(x, score=None):
"""
Takes a bounding box in the centre form [x,y,s,r] and returns it in the form
[x1,y1,x2,y2] where x1,y1 is the top left and x2,y2 is the bottom right
"""
w = np.sqrt(x[2] * x[3])
h = x[2] / w
if score is None:
return np.array(
[x[0] - w / 2., x[1] - h / 2., x[0] + w / 2.,
x[1] + h / 2.]).reshape((1, 4))
else:
score = np.array([score])
return np.array([
x[0] - w / 2., x[1] - h / 2., x[0] + w / 2., x[1] + h / 2., score
]).reshape((1, 5))
def speed_direction(bbox1, bbox2):
cx1, cy1 = (bbox1[0] + bbox1[2]) / 2.0, (bbox1[1] + bbox1[3]) / 2.0
cx2, cy2 = (bbox2[0] + bbox2[2]) / 2.0, (bbox2[1] + bbox2[3]) / 2.0
speed = np.array([cy2 - cy1, cx2 - cx1])
norm = np.sqrt((cy2 - cy1)**2 + (cx2 - cx1)**2) + 1e-6
return speed / norm
class KalmanBoxTracker(object):
"""
This class represents the internal state of individual tracked objects observed as bbox.
Args:
bbox (np.array): bbox in [x1,y1,x2,y2,score] format.
delta_t (int): delta_t of previous observation
"""
count = 0
def __init__(self, bbox, delta_t=3):
self.kf = KalmanFilter(dim_x=7, dim_z=4)
self.kf.F = np.array([[1, 0, 0, 0, 1, 0, 0], [0, 1, 0, 0, 0, 1, 0],
[0, 0, 1, 0, 0, 0, 1], [0, 0, 0, 1, 0, 0, 0],
[0, 0, 0, 0, 1, 0, 0], [0, 0, 0, 0, 0, 1, 0],
[0, 0, 0, 0, 0, 0, 1]])
self.kf.H = np.array([[1, 0, 0, 0, 0, 0, 0], [0, 1, 0, 0, 0, 0, 0],
[0, 0, 1, 0, 0, 0, 0], [0, 0, 0, 1, 0, 0, 0]])
self.kf.R[2:, 2:] *= 10.
self.kf.P[4:, 4:] *= 1000.
# give high uncertainty to the unobservable initial velocities
self.kf.P *= 10.
self.kf.Q[-1, -1] *= 0.01
self.kf.Q[4:, 4:] *= 0.01
self.score = bbox[4]
self.kf.x[:4] = convert_bbox_to_z(bbox)
self.time_since_update = 0
self.id = KalmanBoxTracker.count
KalmanBoxTracker.count += 1
self.history = []
self.hits = 0
self.hit_streak = 0
self.age = 0
"""
NOTE: [-1,-1,-1,-1,-1] is a compromising placeholder for non-observation status, the same for the return of
function k_previous_obs. It is ugly and I do not like it. But to support generate observation array in a
fast and unified way, which you would see below k_observations = np.array([k_previous_obs(...]]), let's bear it for now.
"""
self.last_observation = np.array([-1, -1, -1, -1, -1]) # placeholder
self.observations = dict()
self.history_observations = []
self.velocity = None
self.delta_t = delta_t
def update(self, bbox):
"""
Updates the state vector with observed bbox.
"""
if bbox is not None:
if self.last_observation.sum() >= 0: # there is a previous observation
previous_box = None
for i in range(self.delta_t):
dt = self.delta_t - i
if self.age - dt in self.observations:
previous_box = self.observations[self.age - dt]
break
if previous_box is None:
previous_box = self.last_observation
"""
Estimate the track speed direction with observations delta_t steps away
"""
self.velocity = speed_direction(previous_box, bbox)
"""
Insert new observations. This is a ugly way to maintain both self.observations
and self.history_observations. Bear it for the moment.
"""
self.last_observation = bbox
self.observations[self.age] = bbox
self.history_observations.append(bbox)
self.time_since_update = 0
self.history = []
self.hits += 1
self.hit_streak += 1
self.kf.update(convert_bbox_to_z(bbox))
else:
self.kf.update(bbox)
def predict(self):
"""
Advances the state vector and returns the predicted bounding box estimate.
"""
if ((self.kf.x[6] + self.kf.x[2]) <= 0):
self.kf.x[6] *= 0.0
self.kf.predict()
self.age += 1
if (self.time_since_update > 0):
self.hit_streak = 0
self.time_since_update += 1
self.history.append(convert_x_to_bbox(self.kf.x, score=self.score))
return self.history[-1]
def get_state(self):
return convert_x_to_bbox(self.kf.x, score=self.score)
@register
@serializable
class OCSORTTracker(object):
"""
OCSORT tracker, supports single class only
Args:
det_thresh (float): threshold of detection score
max_age (int): maximum number of consecutive missed frames before a track is deleted
min_hits (int): minimum number of hits before a track is output
iou_threshold (float): iou threshold for association
delta_t (int): delta_t of previous observation
inertia (float): vdc_weight of angle_diff_cost for association
vertical_ratio (float): w/h ratio of the bbox used to filter bad
results. If set <= 0 no bboxes are filtered; usually set to
1.6 for pedestrian tracking.
min_box_area (int): min box area to filter out low quality boxes
use_byte (bool): whether to use BYTE association, default False
"""
def __init__(self,
det_thresh=0.6,
max_age=30,
min_hits=3,
iou_threshold=0.3,
delta_t=3,
inertia=0.2,
vertical_ratio=-1,
min_box_area=0,
use_byte=False):
self.det_thresh = det_thresh
self.max_age = max_age
self.min_hits = min_hits
self.iou_threshold = iou_threshold
self.delta_t = delta_t
self.inertia = inertia
self.vertical_ratio = vertical_ratio
self.min_box_area = min_box_area
self.use_byte = use_byte
self.trackers = []
self.frame_count = 0
KalmanBoxTracker.count = 0
def update(self, pred_dets, pred_embs=None):
"""
Args:
pred_dets (np.array): Detection results of the image, the shape is
[N, 6], means 'cls_id, score, x0, y0, x1, y1'.
pred_embs (np.array): Embedding results of the image, the shape is
[N, 128] or [N, 512], default None.
Return:
tracking boxes (np.array): [M, 6], means 'x0, y0, x1, y1, score, id'.
"""
if pred_dets is None:
return np.empty((0, 6))
self.frame_count += 1
bboxes = pred_dets[:, 2:]
scores = pred_dets[:, 1:2]
dets = np.concatenate((bboxes, scores), axis=1)
scores = scores.squeeze(-1)
inds_low = scores > 0.1
inds_high = scores < self.det_thresh
inds_second = np.logical_and(inds_low, inds_high)
# self.det_thresh > score > 0.1, for second matching
dets_second = dets[inds_second] # detections for second matching
remain_inds = scores > self.det_thresh
dets = dets[remain_inds]
# get predicted locations from existing trackers.
trks = np.zeros((len(self.trackers), 5))
to_del = []
ret = []
for t, trk in enumerate(trks):
pos = self.trackers[t].predict()[0]
trk[:] = [pos[0], pos[1], pos[2], pos[3], 0]
if np.any(np.isnan(pos)):
to_del.append(t)
trks = np.ma.compress_rows(np.ma.masked_invalid(trks))
for t in reversed(to_del):
self.trackers.pop(t)
velocities = np.array([
trk.velocity if trk.velocity is not None else np.array((0, 0))
for trk in self.trackers
])
last_boxes = np.array([trk.last_observation for trk in self.trackers])
k_observations = np.array([
k_previous_obs(trk.observations, trk.age, self.delta_t)
for trk in self.trackers
])
"""
First round of association
"""
matched, unmatched_dets, unmatched_trks = associate(
dets, trks, self.iou_threshold, velocities, k_observations,
self.inertia)
for m in matched:
self.trackers[m[1]].update(dets[m[0], :])
"""
Second round of associaton by OCR
"""
# BYTE association
if self.use_byte and len(dets_second) > 0 and unmatched_trks.shape[
0] > 0:
u_trks = trks[unmatched_trks]
iou_left = iou_batch(
dets_second,
u_trks) # iou between low score detections and unmatched tracks
iou_left = np.array(iou_left)
if iou_left.max() > self.iou_threshold:
"""
NOTE: by using a lower threshold, e.g., self.iou_threshold - 0.1, you may
get a higher performance especially on MOT17/MOT20 datasets. But we keep it
uniform here for simplicity
"""
matched_indices = linear_assignment(-iou_left)
to_remove_trk_indices = []
for m in matched_indices:
det_ind, trk_ind = m[0], unmatched_trks[m[1]]
if iou_left[m[0], m[1]] < self.iou_threshold:
continue
self.trackers[trk_ind].update(dets_second[det_ind, :])
to_remove_trk_indices.append(trk_ind)
unmatched_trks = np.setdiff1d(unmatched_trks,
np.array(to_remove_trk_indices))
if unmatched_dets.shape[0] > 0 and unmatched_trks.shape[0] > 0:
left_dets = dets[unmatched_dets]
left_trks = last_boxes[unmatched_trks]
iou_left = iou_batch(left_dets, left_trks)
iou_left = np.array(iou_left)
if iou_left.max() > self.iou_threshold:
"""
NOTE: by using a lower threshold, e.g., self.iou_threshold - 0.1, you may
get a higher performance especially on MOT17/MOT20 datasets. But we keep it
uniform here for simplicity
"""
rematched_indices = linear_assignment(-iou_left)
to_remove_det_indices = []
to_remove_trk_indices = []
for m in rematched_indices:
det_ind, trk_ind = unmatched_dets[m[0]], unmatched_trks[m[
1]]
if iou_left[m[0], m[1]] < self.iou_threshold:
continue
self.trackers[trk_ind].update(dets[det_ind, :])
to_remove_det_indices.append(det_ind)
to_remove_trk_indices.append(trk_ind)
unmatched_dets = np.setdiff1d(unmatched_dets,
np.array(to_remove_det_indices))
unmatched_trks = np.setdiff1d(unmatched_trks,
np.array(to_remove_trk_indices))
for m in unmatched_trks:
self.trackers[m].update(None)
# create and initialise new trackers for unmatched detections
for i in unmatched_dets:
trk = KalmanBoxTracker(dets[i, :], delta_t=self.delta_t)
self.trackers.append(trk)
i = len(self.trackers)
for trk in reversed(self.trackers):
if trk.last_observation.sum() < 0:
d = trk.get_state()[0]
else:
d = trk.last_observation # tlbr + score
if (trk.time_since_update < 1) and (
trk.hit_streak >= self.min_hits or
self.frame_count <= self.min_hits):
# +1 as MOT benchmark requires positive
ret.append(np.concatenate((d, [trk.id + 1])).reshape(1, -1))
i -= 1
# remove dead tracklet
if (trk.time_since_update > self.max_age):
self.trackers.pop(i)
if (len(ret) > 0):
return np.concatenate(ret)
return np.empty((0, 6))
......@@ -10,8 +10,11 @@ Cython
pycocotools
#xtcocotools==1.6 #only for crowdpose
setuptools>=42.0.0
pyclipper
# for mot
lap
sklearn
motmetrics
openpyxl
pyclipper
filterpy