Unverified commit 772d1801, authored by LokeZhou, committed by GitHub

[MOT] Add BoT-SORT and MOT speed optimization (#7591)

* Add BOTSORTTracker
* pptracking/python/mot/tracker/botsort_tracker.py
* ppdet/modeling/mot/tracker/botsort_tracker.py

fix OCSORTTracker

* botsort kalman_filter.py float64->float32
* requirements.txt add numba

* [Mot] kalman_filter.py lazy import numba
* add code source and modify comments

* [MOT] fix numba lazy import bug

* [MOT] botsort add README.md

* delete comment code
Parent 2acf9cfc
English | [简体中文](README_cn.md)
# BOT_SORT (BoT-SORT: Robust Associations Multi-Pedestrian Tracking)
## Contents
- [Introduction](#introduction)
- [Model Zoo](#model-zoo)
- [Quick Start](#quick-start)
- [Citation](#citation)
## Introduction
[BOT_SORT](https://arxiv.org/pdf/2206.14651v2.pdf) (BoT-SORT: Robust Associations Multi-Pedestrian Tracking). The configurations of commonly used detectors are provided here for reference. Since differences in training datasets, input scales, number of training epochs, NMS threshold settings, etc. all lead to differences in model accuracy and performance, please adapt them to your own needs.
## Model Zoo
### BOT_SORT results on the MOT-17 half Val Set
| Dataset | Detector | Input Size | Detector mAP | MOTA | IDF1 | Config |
| :-------- | :----- | :----: | :------: | :----: |:-----: |:----: |
| MOT-17 half train | PP-YOLOE-l | 640x640 | 52.7 | 55.5 | 64.2 |[config](./botsort_ppyoloe.yml) |
**Attention:**
- The model weights download link is given by ```det_weights``` in the config file; the weights are downloaded automatically when you run the evaluation command.
- **MOT17-half train** is a dataset composed of the images and labels of the first half of the frames of each video in the MOT17 train sequences (7 in total). To verify accuracy, you can evaluate on the **MOT17-half val** set, which consists of the second half of the frames of each video. Download it from this [link](https://bj.bcebos.com/v1/paddledet/data/mot/MOT17.zip) and extract it into `dataset/mot/`.
- BOT_SORT training means training a standalone detector on the MOT dataset; inference assembles the tracker to evaluate MOT metrics, and the standalone detection model can also be evaluated with detection metrics.
- For export and deployment, BOT_SORT exports the detection model separately and then assembles the tracker at runtime; refer to [PP-Tracking](../../../deploy/pptracking/python).
- BOT_SORT is the main tracking scheme used by pipeline analysis projects such as PP-Human and PP-Vehicle. For specific usage, please refer to [Pipeline](../../../deploy/pipeline) and [MOT](../../../deploy/pipeline/docs/tutorials/pphuman_mot.md).
## Quick Start
### 1. Training
Start training and evaluation with the following commands:
```bash
# single-GPU training
CUDA_VISIBLE_DEVICES=0 python tools/train.py -c configs/mot/bytetrack/detector/ppyoloe_crn_l_36e_640x640_mot17half.yml --eval --amp
# multi-GPU training
python -m paddle.distributed.launch --log_dir=ppyoloe --gpus 0,1,2,3,4,5,6,7 tools/train.py -c configs/mot/bytetrack/detector/ppyoloe_crn_l_36e_640x640_mot17half.yml --eval --amp
```
### 2. Evaluation
#### 2.1 Evaluate detection
```bash
CUDA_VISIBLE_DEVICES=0 python tools/eval.py -c configs/mot/bytetrack/detector/ppyoloe_crn_l_36e_640x640_mot17half.yml
```
**Attention:**
- Use ```tools/eval.py``` to evaluate detection and ```tools/eval_mot.py``` to evaluate MOT.
#### 2.2 Evaluate MOT
```bash
CUDA_VISIBLE_DEVICES=0 python tools/eval_mot.py -c configs/mot/botsort/botsort_ppyoloe.yml --scaled=True
```
**Attention:**
- `--scaled` indicates whether the coordinates output by the model have already been scaled back to the original image. Set it to False if the detection model is JDE YOLOv3 and True if a general detection model is used; the default value is False.
- MOT results are saved in `{output_dir}/mot_results/`. Each video sequence corresponds to one txt file, and each line of a txt file is `frame,id,x1,y1,w,h,score,-1,-1,-1`; `{output_dir}` can be set with `--output_dir`. A small parsing sketch follows this list.
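The saved txt files follow the standard MOTChallenge format, so they are easy to post-process. Below is a minimal sketch for parsing one result file, assuming the default `output` directory and a hypothetical sequence name:
```python
# Minimal sketch: parse a saved MOT result file.
# Each line is `frame,id,x1,y1,w,h,score,-1,-1,-1`; the path below is hypothetical.
from collections import defaultdict


def load_mot_results(txt_path):
    tracks = defaultdict(list)  # track id -> list of (frame, x1, y1, w, h, score)
    with open(txt_path) as f:
        for line in f:
            fields = line.strip().split(',')
            if len(fields) < 7:
                continue
            frame, tid = int(float(fields[0])), int(float(fields[1]))
            x1, y1, w, h, score = map(float, fields[2:7])
            tracks[tid].append((frame, x1, y1, w, h, score))
    return tracks


if __name__ == '__main__':
    tracks = load_mot_results('output/mot_results/MOT17-02-SDP.txt')
    print('found {} tracks'.format(len(tracks)))
```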
### 3. Export the detection model
```bash
python tools/export_model.py -c configs/mot/bytetrack/detector/ppyoloe_crn_l_36e_640x640_mot17half.yml --output_dir=output_inference -o weights=https://bj.bcebos.com/v1/paddledet/models/mot/ppyoloe_crn_l_36e_640x640_mot17half.pdparams
```
### 4. Predict with the exported model
```bash
# download demo video
wget https://bj.bcebos.com/v1/paddledet/data/mot/demo/mot17_demo.mp4
CUDA_VISIBLE_DEVICES=0 python deploy/pptracking/python/mot_sde_infer.py --model_dir=output_inference/ppyoloe_crn_l_36e_640x640_mot17half --tracker_config=deploy/pptracking/python/tracker_config.yml --video_file=mot17_demo.mp4 --device=GPU --threshold=0.5
```
**Attention:**
- You must set the tracker `type: BOTSORTTracker` in `tracker_config.yml` if you want to use BOT_SORT (a scripted example follows this list).
- The tracking model predicts on videos; it does not support predicting a single image. By default, the video with visualized tracking results is saved. You can add `--save_mot_txts` (save one txt per video) or `--save_mot_txt_per_img` (save one txt per image) to save the tracking result txt files, or `--save_images` to save the visualized tracking images.
- Each line of the tracking result txt file has the format `frame,id,x1,y1,w,h,score,-1,-1,-1`.
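If you prefer not to edit `tracker_config.yml` by hand, the switch can also be scripted. The sketch below is only an illustration: it assumes PyYAML is installed and the `tracker_config.yml` shipped with PP-Tracking, and `yaml.safe_dump` does not preserve the comments in the file.
```python
# Minimal sketch: set the tracker type in tracker_config.yml to BOTSORTTracker.
# Assumes PyYAML is available; comments in the YAML file are not preserved.
import yaml

cfg_path = 'deploy/pptracking/python/tracker_config.yml'
with open(cfg_path) as f:
    cfg = yaml.safe_load(f)

cfg['type'] = 'BOTSORTTracker'  # select BOT_SORT instead of the default tracker

with open(cfg_path, 'w') as f:
    yaml.safe_dump(cfg, f, default_flow_style=False, sort_keys=False)
```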
## Citation
```
@article{aharon2022bot,
title={BoT-SORT: Robust Associations Multi-Pedestrian Tracking},
author={Aharon, Nir and Orfaig, Roy and Bobrovsky, Ben-Zion},
journal={arXiv preprint arXiv:2206.14651},
year={2022}
}
```
Simplified Chinese | [English](README.md)
# BOT_SORT (BoT-SORT: Robust Associations Multi-Pedestrian Tracking)
## Contents
- [Introduction](#introduction)
- [Model Zoo](#model-zoo)
- [Quick Start](#quick-start)
- [Citation](#citation)
## Introduction
[BOT_SORT](https://arxiv.org/pdf/2206.14651v2.pdf) (BoT-SORT: Robust Associations Multi-Pedestrian Tracking). The configurations of commonly used detectors are provided here for reference. Since differences in training datasets, input scales, number of training epochs, NMS threshold settings, etc. all lead to differences in model accuracy and performance, please adapt them to your own needs.
## Model Zoo
### BOT_SORT results on the MOT-17 half Val Set
| Detector Training Dataset | Detector | Input Size | Detector mAP | MOTA | IDF1 | Config |
| :-------- | :----- | :----: | :------: | :----: |:-----: |:----: |
| MOT-17 half train | PP-YOLOE-l | 640x640 | 52.7 | 55.5 | 64.2 |[config](./botsort_ppyoloe.yml) |
**Attention:**
- The model weights download link is given by ```det_weights``` in the config file; the weights are downloaded automatically when you run the evaluation command.
- **MOT17-half train** is a dataset composed of the images and labels of the first half of the frames of each video in the MOT17 train sequences (7 in total). To verify accuracy, you can evaluate on the **MOT17-half val** set, which consists of the second half of the frames of each video. Download it from this [link](https://bj.bcebos.com/v1/paddledet/data/mot/MOT17.zip) and extract it into the `dataset/mot/` folder.
- BOT_SORT training means training a standalone detector on the MOT dataset; inference assembles the tracker to evaluate MOT metrics, and the standalone detection model can also be evaluated with detection metrics.
- For export and deployment, BOT_SORT exports the detection model separately and then assembles the tracker at runtime; refer to [PP-Tracking](../../../deploy/pptracking/python).
- BOT_SORT is the main tracking scheme used by pipeline analysis projects such as PP-Human and PP-Vehicle. For specific usage, please refer to [Pipeline](../../../deploy/pipeline) and [MOT](../../../deploy/pipeline/docs/tutorials/pphuman_mot.md).
## Quick Start
### 1. Training
Start training and evaluation with the following commands:
```bash
# single-GPU training
CUDA_VISIBLE_DEVICES=0 python tools/train.py -c configs/mot/bytetrack/detector/ppyoloe_crn_l_36e_640x640_mot17half.yml --eval --amp
# multi-GPU training
python -m paddle.distributed.launch --log_dir=ppyoloe --gpus 0,1,2,3,4,5,6,7 tools/train.py -c configs/mot/bytetrack/detector/ppyoloe_crn_l_36e_640x640_mot17half.yml --eval --amp
```
### 2. Evaluation
#### 2.1 Evaluate detection
```bash
CUDA_VISIBLE_DEVICES=0 python tools/eval.py -c configs/mot/bytetrack/detector/ppyoloe_crn_l_36e_640x640_mot17half.yml
```
**Attention:**
- Use ```tools/eval.py``` to evaluate detection and ```tools/eval_mot.py``` to evaluate tracking.
#### 2.2 Evaluate tracking
```bash
CUDA_VISIBLE_DEVICES=0 python tools/eval_mot.py -c configs/mot/botsort/botsort_ppyoloe.yml --scaled=True
```
**Attention:**
- `--scaled` indicates whether the coordinates output by the model have already been scaled back to the original image. Set it to False if the detection model is JDE YOLOv3 and True if a general detection model is used; the default value is False.
- Tracking results are saved in `{output_dir}/mot_results/`. Each video sequence corresponds to one txt file, and each line of a txt file is `frame,id,x1,y1,w,h,score,-1,-1,-1`; `{output_dir}` can be set with `--output_dir`.
### 3. Export the prediction model
```bash
python tools/export_model.py -c configs/mot/bytetrack/detector/ppyoloe_crn_l_36e_640x640_mot17half.yml --output_dir=output_inference -o weights=https://bj.bcebos.com/v1/paddledet/models/mot/ppyoloe_crn_l_36e_640x640_mot17half.pdparams
```
### 4. Predict with the exported model in Python
```bash
# download the demo video
wget https://bj.bcebos.com/v1/paddledet/data/mot/demo/mot17_demo.mp4
CUDA_VISIBLE_DEVICES=0 python deploy/pptracking/python/mot_sde_infer.py --model_dir=output_inference/ppyoloe_crn_l_36e_640x640_mot17half --tracker_config=deploy/pptracking/python/tracker_config.yml --video_file=mot17_demo.mp4 --device=GPU --threshold=0.5
```
**Attention:**
- Before running, you need to manually change the tracker type in `tracker_config.yml` to `type: BOTSORTTracker`.
- The tracking model predicts on videos and does not support predicting a single image. By default, the video with visualized tracking results is saved. You can add `--save_mot_txts` (save one txt per video) or `--save_mot_txt_per_img` (save one txt per image) to save the tracking result txt files, or `--save_images` to save the visualized tracking images.
- Each line of the tracking result txt file is `frame,id,x1,y1,w,h,score,-1,-1,-1`.
## Citation
```
@article{aharon2022bot,
title={BoT-SORT: Robust Associations Multi-Pedestrian Tracking},
author={Aharon, Nir and Orfaig, Roy and Bobrovsky, Ben-Zion},
journal={arXiv preprint arXiv:2206.14651},
year={2022}
}
```
# This is an assembled config for BoT-SORT MOT (ByteTrack architecture with a BOTSORTTracker), used in eval/infer mode for MOT.
_BASE_: [
'../bytetrack/detector/ppyoloe_crn_l_36e_640x640_mot17half.yml',
'../bytetrack/_base_/mot17.yml',
'../bytetrack/_base_/ppyoloe_mot_reader_640x640.yml'
]
weights: output/botsort_ppyoloe/model_final
log_iter: 20
snapshot_epoch: 2
metric: MOT # eval/infer mode; set 'COCO' for training mode
num_classes: 1
architecture: ByteTrack
pretrain_weights: https://bj.bcebos.com/v1/paddledet/models/ppyoloe_crn_l_300e_coco.pdparams
ByteTrack:
detector: YOLOv3 # PPYOLOe version
reid: None
tracker: BOTSORTTracker
det_weights: https://bj.bcebos.com/v1/paddledet/models/mot/ppyoloe_crn_l_36e_640x640_mot17half.pdparams
reid_weights: None
YOLOv3:
backbone: CSPResNet
neck: CustomCSPPAN
yolo_head: PPYOLOEHead
post_process: ~
# Tracking requires higher quality boxes, so NMS score_threshold will be higher
PPYOLOEHead:
fpn_strides: [32, 16, 8]
grid_cell_scale: 5.0
grid_cell_offset: 0.5
static_assigner_epoch: -1 # 100
use_varifocal_loss: True
loss_weight: {class: 1.0, iou: 2.5, dfl: 0.5}
static_assigner:
name: ATSSAssigner
topk: 9
assigner:
name: TaskAlignedAssigner
topk: 13
alpha: 1.0
beta: 6.0
nms:
name: MultiClassNMS
nms_top_k: 1000
keep_top_k: 100
score_threshold: 0.1 # 0.01 in original detector
nms_threshold: 0.4 # 0.6 in original detector
BOTSORTTracker:
track_high_thresh: 0.3
track_low_thresh: 0.2
new_track_thresh: 0.4
match_thresh: 0.7
track_buffer: 30
min_box_area: 0
camera_motion: False
cmc_method: 'sparseOptFlow' # used only when camera_motion is True;
# sparseOptFlow | files (Vidstab GMC) | orb | ecc
# MOTDataset for MOT evaluation and inference
EvalMOTDataset:
!MOTImageFolder
dataset_dir: dataset/mot
data_root: MOT17/images/half
keep_ori_im: True # set as True in DeepSORT and ByteTrack
TestMOTDataset:
!MOTImageFolder
dataset_dir: dataset/mot
keep_ori_im: True # set True to save visualization images or video
......@@ -60,13 +60,14 @@ OCSORTTracker:
vertical_ratio: 0
min_box_area: 0
use_byte: False
use_angle_cost: False
# MOTDataset for MOT evaluation and inference
EvalMOTDataset:
!MOTImageFolder
dataset_dir: dataset/mot
data_root: MOT17/images/half
data_root: MOT17_test/images/half
keep_ori_im: True # set as True in DeepSORT and ByteTrack
TestMOTDataset:
......
......@@ -57,7 +57,8 @@ def linear_assignment(cost_matrix):
try:
import lap
_, x, y = lap.lapjv(cost_matrix, extend_cost=True)
return np.array([[y[i], i] for i in x if i >= 0]) #
match = np.array([[y[i], i] for i in x if i >= 0])
return match
except ImportError:
from scipy.optimize import linear_sum_assignment
x, y = linear_sum_assignment(cost_matrix)
......@@ -125,3 +126,44 @@ def associate(detections, trackers, iou_threshold, velocities, previous_obs,
matches = np.concatenate(matches, axis=0)
return matches, np.array(unmatched_detections), np.array(unmatched_trackers)
def associate_only_iou(detections, trackers, iou_threshold):
if (len(trackers) == 0):
return np.empty(
(0, 2), dtype=int), np.arange(len(detections)), np.empty(
(0, 5), dtype=int)
iou_matrix = iou_batch(detections, trackers)
if min(iou_matrix.shape) > 0:
a = (iou_matrix > iou_threshold).astype(np.int32)
if a.sum(1).max() == 1 and a.sum(0).max() == 1:
matched_indices = np.stack(np.where(a), axis=1)
else:
matched_indices = linear_assignment(-iou_matrix)
else:
matched_indices = np.empty(shape=(0, 2))
unmatched_detections = []
for d, det in enumerate(detections):
if (d not in matched_indices[:, 0]):
unmatched_detections.append(d)
unmatched_trackers = []
for t, trk in enumerate(trackers):
if (t not in matched_indices[:, 1]):
unmatched_trackers.append(t)
# filter out matched with low IOU
matches = []
for m in matched_indices:
if (iou_matrix[m[0], m[1]] < iou_threshold):
unmatched_detections.append(m[0])
unmatched_trackers.append(m[1])
else:
matches.append(m.reshape(1, 2))
if (len(matches) == 0):
matches = np.empty((0, 2), dtype=int)
else:
matches = np.concatenate(matches, axis=0)
return matches, np.array(unmatched_detections), np.array(unmatched_trackers)
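# --- Hedged usage sketch (added for illustration, not part of the original file) ---
# Exercise associate_only_iou with toy boxes. Rows are assumed to be in
# [x1, y1, x2, y2, score] format, which is what iou_batch operates on.
if __name__ == '__main__':
    toy_dets = np.array([[10., 10., 50., 80., 0.9],
                         [200., 40., 240., 110., 0.8]])
    toy_trks = np.array([[12., 12., 52., 82., 0.],
                         [400., 40., 440., 110., 0.]])
    m, um_dets, um_trks = associate_only_iou(toy_dets, toy_trks, iou_threshold=0.3)
    print(m)        # expected: [[0 0]] -> detection 0 matched to tracker 0
    print(um_dets)  # expected: [1]     -> detection 1 overlaps no tracker
    print(um_trks)  # expected: [1]     -> tracker 1 overlaps no detection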
......@@ -15,3 +15,5 @@
from . import kalman_filter
from .kalman_filter import *
from .gmc import *
from .ocsort_kalman_filter import *
\ No newline at end of file
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is based on https://github.com/WWangYuHsiang/SMILEtrack/blob/main/BoT-SORT/tracker/gmc.py
"""
import cv2
import matplotlib.pyplot as plt
import numpy as np
import copy
import time
class GMC:
def __init__(self, method='sparseOptFlow', downscale=2, verbose=None):
super(GMC, self).__init__()
self.method = method
self.downscale = max(1, int(downscale))
if self.method == 'orb':
self.detector = cv2.FastFeatureDetector_create(20)
self.extractor = cv2.ORB_create()
self.matcher = cv2.BFMatcher(cv2.NORM_HAMMING)
elif self.method == 'sift':
self.detector = cv2.SIFT_create(
nOctaveLayers=3, contrastThreshold=0.02, edgeThreshold=20)
self.extractor = cv2.SIFT_create(
nOctaveLayers=3, contrastThreshold=0.02, edgeThreshold=20)
self.matcher = cv2.BFMatcher(cv2.NORM_L2)
elif self.method == 'ecc':
number_of_iterations = 5000
termination_eps = 1e-6
self.warp_mode = cv2.MOTION_EUCLIDEAN
self.criteria = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT,
number_of_iterations, termination_eps)
elif self.method == 'sparseOptFlow':
self.feature_params = dict(
maxCorners=1000,
qualityLevel=0.01,
minDistance=1,
blockSize=3,
useHarrisDetector=False,
k=0.04)
# self.gmc_file = open('GMC_results.txt', 'w')
elif self.method == 'file' or self.method == 'files':
seqName = verbose[0]
ablation = verbose[1]
if ablation:
filePath = r'tracker/GMC_files/MOT17_ablation'
else:
filePath = r'tracker/GMC_files/MOTChallenge'
if '-FRCNN' in seqName:
seqName = seqName[:-6]
elif '-DPM' in seqName:
seqName = seqName[:-4]
elif '-SDP' in seqName:
seqName = seqName[:-4]
self.gmcFile = open(filePath + "/GMC-" + seqName + ".txt", 'r')
if self.gmcFile is None:
raise ValueError("Error: Unable to open GMC file in directory:"
+ filePath)
elif self.method == 'none' or self.method == 'None':
self.method = 'none'
else:
raise ValueError("Error: Unknown CMC method:" + method)
self.prevFrame = None
self.prevKeyPoints = None
self.prevDescriptors = None
self.initializedFirstFrame = False
def apply(self, raw_frame, detections=None):
if self.method == 'orb' or self.method == 'sift':
return self.applyFeaures(raw_frame, detections)
elif self.method == 'ecc':
return self.applyEcc(raw_frame, detections)
elif self.method == 'sparseOptFlow':
return self.applySparseOptFlow(raw_frame, detections)
elif self.method == 'file':
return self.applyFile(raw_frame, detections)
elif self.method == 'none':
return np.eye(2, 3)
else:
return np.eye(2, 3)
def applyEcc(self, raw_frame, detections=None):
# Initialize
height, width, _ = raw_frame.shape
frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY)
H = np.eye(2, 3, dtype=np.float32)
# Downscale image (TODO: consider using pyramids)
if self.downscale > 1.0:
frame = cv2.GaussianBlur(frame, (3, 3), 1.5)
frame = cv2.resize(frame, (width // self.downscale,
height // self.downscale))
width = width // self.downscale
height = height // self.downscale
# Handle first frame
if not self.initializedFirstFrame:
# Initialize data
self.prevFrame = frame.copy()
# Initialization done
self.initializedFirstFrame = True
return H
# Run the ECC algorithm. The results are stored in warp_matrix.
# (cc, H) = cv2.findTransformECC(self.prevFrame, frame, H, self.warp_mode, self.criteria)
try:
(cc,
H) = cv2.findTransformECC(self.prevFrame, frame, H, self.warp_mode,
self.criteria, None, 1)
except:
print('Warning: find transform failed. Set warp as identity')
return H
def applyFeaures(self, raw_frame, detections=None):
# Initialize
height, width, _ = raw_frame.shape
frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY)
H = np.eye(2, 3)
# Downscale image (TODO: consider using pyramids)
if self.downscale > 1.0:
# frame = cv2.GaussianBlur(frame, (3, 3), 1.5)
frame = cv2.resize(frame, (width // self.downscale,
height // self.downscale))
width = width // self.downscale
height = height // self.downscale
# find the keypoints
mask = np.zeros_like(frame)
# mask[int(0.05 * height): int(0.95 * height), int(0.05 * width): int(0.95 * width)] = 255
mask[int(0.02 * height):int(0.98 * height), int(0.02 * width):int(
0.98 * width)] = 255
if detections is not None:
for det in detections:
tlbr = (det[:4] / self.downscale).astype(np.int_)
mask[tlbr[1]:tlbr[3], tlbr[0]:tlbr[2]] = 0
keypoints = self.detector.detect(frame, mask)
# compute the descriptors
keypoints, descriptors = self.extractor.compute(frame, keypoints)
# Handle first frame
if not self.initializedFirstFrame:
# Initialize data
self.prevFrame = frame.copy()
self.prevKeyPoints = copy.copy(keypoints)
self.prevDescriptors = copy.copy(descriptors)
# Initialization done
self.initializedFirstFrame = True
return H
# Match descriptors.
knnMatches = self.matcher.knnMatch(self.prevDescriptors, descriptors, 2)
# Filtered matches based on smallest spatial distance
matches = []
spatialDistances = []
maxSpatialDistance = 0.25 * np.array([width, height])
# Handle empty matches case
if len(knnMatches) == 0:
# Store to next iteration
self.prevFrame = frame.copy()
self.prevKeyPoints = copy.copy(keypoints)
self.prevDescriptors = copy.copy(descriptors)
return H
for m, n in knnMatches:
if m.distance < 0.9 * n.distance:
prevKeyPointLocation = self.prevKeyPoints[m.queryIdx].pt
currKeyPointLocation = keypoints[m.trainIdx].pt
spatialDistance = (
prevKeyPointLocation[0] - currKeyPointLocation[0],
prevKeyPointLocation[1] - currKeyPointLocation[1])
if (np.abs(spatialDistance[0]) < maxSpatialDistance[0]) and \
(np.abs(spatialDistance[1]) < maxSpatialDistance[1]):
spatialDistances.append(spatialDistance)
matches.append(m)
meanSpatialDistances = np.mean(spatialDistances, 0)
stdSpatialDistances = np.std(spatialDistances, 0)
inliesrs = (spatialDistances - meanSpatialDistances
) < 2.5 * stdSpatialDistances
goodMatches = []
prevPoints = []
currPoints = []
for i in range(len(matches)):
if inliesrs[i, 0] and inliesrs[i, 1]:
goodMatches.append(matches[i])
prevPoints.append(self.prevKeyPoints[matches[i].queryIdx].pt)
currPoints.append(keypoints[matches[i].trainIdx].pt)
prevPoints = np.array(prevPoints)
currPoints = np.array(currPoints)
# Draw the keypoint matches on the output image
if 0:
matches_img = np.hstack((self.prevFrame, frame))
matches_img = cv2.cvtColor(matches_img, cv2.COLOR_GRAY2BGR)
W = np.size(self.prevFrame, 1)
for m in goodMatches:
prev_pt = np.array(
self.prevKeyPoints[m.queryIdx].pt, dtype=np.int_)
curr_pt = np.array(keypoints[m.trainIdx].pt, dtype=np.int_)
curr_pt[0] += W
color = np.random.randint(0, 255, (3, ))
color = (int(color[0]), int(color[1]), int(color[2]))
matches_img = cv2.line(matches_img, prev_pt, curr_pt,
tuple(color), 1, cv2.LINE_AA)
matches_img = cv2.circle(matches_img, prev_pt, 2,
tuple(color), -1)
matches_img = cv2.circle(matches_img, curr_pt, 2,
tuple(color), -1)
plt.figure()
plt.imshow(matches_img)
plt.show()
# Find rigid matrix
if (np.size(prevPoints, 0) > 4) and (
np.size(prevPoints, 0) == np.size(currPoints, 0)):
H, inliesrs = cv2.estimateAffinePartial2D(prevPoints, currPoints,
cv2.RANSAC)
# Handle downscale
if self.downscale > 1.0:
H[0, 2] *= self.downscale
H[1, 2] *= self.downscale
else:
print('Warning: not enough matching points')
# Store to next iteration
self.prevFrame = frame.copy()
self.prevKeyPoints = copy.copy(keypoints)
self.prevDescriptors = copy.copy(descriptors)
return H
def applySparseOptFlow(self, raw_frame, detections=None):
t0 = time.time()
# Initialize
height, width, _ = raw_frame.shape
frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY)
H = np.eye(2, 3)
# Downscale image
if self.downscale > 1.0:
# frame = cv2.GaussianBlur(frame, (3, 3), 1.5)
frame = cv2.resize(frame, (width // self.downscale,
height // self.downscale))
# find the keypoints
keypoints = cv2.goodFeaturesToTrack(
frame, mask=None, **self.feature_params)
# Handle first frame
if not self.initializedFirstFrame:
# Initialize data
self.prevFrame = frame.copy()
self.prevKeyPoints = copy.copy(keypoints)
# Initialization done
self.initializedFirstFrame = True
return H
if self.prevFrame.shape != frame.shape:
self.prevFrame = frame.copy()
self.prevKeyPoints = copy.copy(keypoints)
return H
# find correspondences
matchedKeypoints, status, err = cv2.calcOpticalFlowPyrLK(
self.prevFrame, frame, self.prevKeyPoints, None)
# leave good correspondences only
prevPoints = []
currPoints = []
for i in range(len(status)):
if status[i]:
prevPoints.append(self.prevKeyPoints[i])
currPoints.append(matchedKeypoints[i])
prevPoints = np.array(prevPoints)
currPoints = np.array(currPoints)
# Find rigid matrix
if (np.size(prevPoints, 0) > 4) and (
np.size(prevPoints, 0) == np.size(currPoints, 0)):
H, inliesrs = cv2.estimateAffinePartial2D(prevPoints, currPoints,
cv2.RANSAC)
# Handle downscale
if self.downscale > 1.0:
H[0, 2] *= self.downscale
H[1, 2] *= self.downscale
else:
print('Warning: not enough matching points')
# Store to next iteration
self.prevFrame = frame.copy()
self.prevKeyPoints = copy.copy(keypoints)
t1 = time.time()
# gmc_line = str(1000 * (t1 - t0)) + "\t" + str(H[0, 0]) + "\t" + str(H[0, 1]) + "\t" + str(
# H[0, 2]) + "\t" + str(H[1, 0]) + "\t" + str(H[1, 1]) + "\t" + str(H[1, 2]) + "\n"
# self.gmc_file.write(gmc_line)
return H
def applyFile(self, raw_frame, detections=None):
line = self.gmcFile.readline()
tokens = line.split("\t")
H = np.eye(2, 3, dtype=np.float_)
H[0, 0] = float(tokens[1])
H[0, 1] = float(tokens[2])
H[0, 2] = float(tokens[3])
H[1, 0] = float(tokens[4])
H[1, 1] = float(tokens[5])
H[1, 2] = float(tokens[6])
return H
......@@ -18,6 +18,39 @@ This code is based on https://github.com/nwojke/deep_sort/blob/master/deep_sort/
import numpy as np
import scipy.linalg
use_numba = True
try:
import numba as nb
@nb.njit(fastmath=True, cache=True)
def nb_project(mean, covariance, std, _update_mat):
innovation_cov = np.diag(np.square(std))
mean = np.dot(_update_mat, mean)
covariance = np.dot(np.dot(_update_mat, covariance), _update_mat.T)
return mean, covariance + innovation_cov
@nb.njit(fastmath=True, cache=True)
def nb_multi_predict(mean, covariance, motion_cov, motion_mat):
mean = np.dot(mean, motion_mat.T)
left = np.dot(motion_mat, covariance)
covariance = np.dot(left, motion_mat.T) + motion_cov
return mean, covariance
@nb.njit(fastmath=True, cache=True)
def nb_update(mean, covariance, proj_mean, proj_cov, measurement, meas_mat):
kalman_gain = np.linalg.solve(proj_cov, (covariance @meas_mat.T).T).T
innovation = measurement - proj_mean
mean = mean + innovation @kalman_gain.T
covariance = covariance - kalman_gain @proj_cov @kalman_gain.T
return mean, covariance
except:
use_numba = False
print(
'Warning: Unable to use numba in PP-Tracking, please install numba, for example (python3.7): `pip install numba==0.56.4`'
)
pass
__all__ = ['KalmanFilter']
"""
Table for the 0.95 quantile of the chi-square distribution with N degrees of
......@@ -59,10 +92,10 @@ class KalmanFilter(object):
ndim, dt = 4, 1.
# Create Kalman filter model matrices.
self._motion_mat = np.eye(2 * ndim, 2 * ndim)
self._motion_mat = np.eye(2 * ndim, 2 * ndim, dtype=np.float32)
for i in range(ndim):
self._motion_mat[i, ndim + i] = dt
self._update_mat = np.eye(ndim, 2 * ndim)
self._update_mat = np.eye(ndim, 2 * ndim, dtype=np.float32)
# Motion and observation uncertainty are chosen relative to the current
# state estimate. These weights control the amount of uncertainty in
......@@ -96,7 +129,7 @@ class KalmanFilter(object):
10 * self._std_weight_velocity * measurement[3]
]
covariance = np.diag(np.square(std))
return mean, covariance
return mean, np.float32(covariance)
def predict(self, mean, covariance):
"""
......@@ -140,10 +173,16 @@ class KalmanFilter(object):
Returns:
The projected mean and covariance matrix of the given state estimate.
"""
std = [
self._std_weight_position * mean[3], self._std_weight_position *
mean[3], 1e-1, self._std_weight_position * mean[3]
]
std = np.array(
[
self._std_weight_position * mean[3], self._std_weight_position *
mean[3], 1e-1, self._std_weight_position * mean[3]
],
dtype=np.float32)
if use_numba:
return nb_project(mean, covariance, std, self._update_mat)
innovation_cov = np.diag(np.square(std))
mean = np.dot(self._update_mat, mean)
......@@ -165,18 +204,29 @@ class KalmanFilter(object):
The mean vector and covariance matrix of the predicted state.
Unobserved velocities are initialized to 0 mean.
"""
std_pos = [
std_pos = np.array([
self._std_weight_position * mean[:, 3], self._std_weight_position *
mean[:, 3], 1e-2 * np.ones_like(mean[:, 3]),
self._std_weight_position * mean[:, 3]
]
std_vel = [
])
std_vel = np.array([
self._std_weight_velocity * mean[:, 3], self._std_weight_velocity *
mean[:, 3], 1e-5 * np.ones_like(mean[:, 3]),
self._std_weight_velocity * mean[:, 3]
]
])
sqr = np.square(np.r_[std_pos, std_vel]).T
if use_numba:
means = []
covariances = []
for i in range(len(mean)):
a, b = nb_multi_predict(mean[i], covariance[i],
np.diag(sqr[i]), self._motion_mat)
means.append(a)
covariances.append(b)
return np.asarray(means), np.asarray(covariances)
motion_cov = []
for i in range(len(mean)):
motion_cov.append(np.diag(sqr[i]))
......@@ -204,18 +254,17 @@ class KalmanFilter(object):
"""
projected_mean, projected_cov = self.project(mean, covariance)
chol_factor, lower = scipy.linalg.cho_factor(
projected_cov, lower=True, check_finite=False)
kalman_gain = scipy.linalg.cho_solve(
(chol_factor, lower),
np.dot(covariance, self._update_mat.T).T,
check_finite=False).T
innovation = measurement - projected_mean
if use_numba:
new_mean = mean + np.dot(innovation, kalman_gain.T)
new_covariance = covariance - np.linalg.multi_dot(
(kalman_gain, projected_cov, kalman_gain.T))
return new_mean, new_covariance
return nb_update(mean, covariance, projected_mean, projected_cov,
measurement, self._update_mat)
kalman_gain = np.linalg.solve(projected_cov,
(covariance @self._update_mat.T).T).T
innovation = measurement - projected_mean
mean = mean + innovation @kalman_gain.T
covariance = covariance - kalman_gain @projected_cov @kalman_gain.T
return mean, covariance
def gating_distance(self,
mean,
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is based on https://github.com/danbochman/SORT/blob/danny_opencv/kalman_filter.py
"""
import numpy as np
from numpy import dot, zeros, eye
from numpy.linalg import inv
use_numba = True
try:
import numba as nb
@nb.njit(fastmath=True, cache=True)
def nb_predict(x, F, P, Q):
x = dot(F, x)
P = dot(dot(F, P), F.T) + Q
return x, P
@nb.njit(fastmath=True, cache=True)
def nb_update(x, z, H, P, R, _I):
y = z - np.dot(H, x)
PHT = dot(P, H.T)
S = dot(H, PHT) + R
K = dot(PHT, inv(S))
x = x + dot(K, y)
I_KH = _I - dot(K, H)
P = dot(dot(I_KH, P), I_KH.T) + dot(dot(K, R), K.T)
return x, P
except:
use_numba = False
print(
'Warning: Unable to use numba in PP-Tracking, please install numba, for example (python3.7): `pip install numba==0.56.4`'
)
pass
class OCSORTKalmanFilter:
def __init__(self, dim_x, dim_z):
self.dim_x = dim_x
self.dim_z = dim_z
self.x = zeros((dim_x, 1))
self.P = eye(dim_x)
self.Q = eye(dim_x)
self.F = eye(dim_x)
self.H = zeros((dim_z, dim_x))
self.R = eye(dim_z)
self.M = zeros((dim_z, dim_z))
self._I = eye(dim_x)
def predict(self):
if use_numba:
self.x, self.P = nb_predict(self.x, self.F, self.P, self.Q)
else:
self.x = dot(self.F, self.x)
self.P = dot(dot(self.F, self.P), self.F.T) + self.Q
def update(self, z):
if z is None:
return
if use_numba:
self.x, self.P = nb_update(self.x, z, self.H, self.P, self.R,
self._I)
else:
y = z - np.dot(self.H, self.x)
PHT = dot(self.P, self.H.T)
S = dot(self.H, PHT) + self.R
K = dot(PHT, inv(S))
self.x = self.x + dot(K, y)
I_KH = self._I - dot(K, self.H)
self.P = dot(dot(I_KH, self.P), I_KH.T) + dot(dot(K, self.R), K.T)
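# --- Hedged usage sketch (added for illustration, not part of the original file) ---
# A tiny constant-velocity example of the filter above: the state is
# [position, velocity] and a single noisy position measurement is fed per step.
if __name__ == '__main__':
    kf = OCSORTKalmanFilter(dim_x=2, dim_z=1)
    kf.F = np.array([[1., 1.], [0., 1.]])  # x_{k+1} = x_k + v_k (unit time step)
    kf.H = np.array([[1., 0.]])            # observe position only
    for z in [1.0, 2.1, 2.9, 4.2]:
        kf.predict()
        kf.update(np.array([[z]]))
    print(kf.x)  # estimated [position, velocity] column vector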
......@@ -26,4 +26,5 @@ from . import center_tracker
from .jde_tracker import *
from .deepsort_tracker import *
from .ocsort_tracker import *
from .botsort_tracker import *
from .center_tracker import *
......@@ -147,6 +147,24 @@ class STrack(BaseTrack):
tracks[i].mean = mean
tracks[i].covariance = cov
@staticmethod
def multi_gmc(stracks, H=np.eye(2, 3)):
if len(stracks) > 0:
multi_mean = np.asarray([st.mean.copy() for st in stracks])
multi_covariance = np.asarray([st.covariance for st in stracks])
R = H[:2, :2]
R8x8 = np.kron(np.eye(4, dtype=float), R)
t = H[:2, 2]
for i, (mean, cov) in enumerate(zip(multi_mean, multi_covariance)):
mean = R8x8.dot(mean)
mean[:2] += t
cov = R8x8.dot(cov).dot(R8x8.transpose())
stracks[i].mean = mean
stracks[i].covariance = cov
def reset_track_id(self):
self.reset_track_count(self.cls_id)
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is based on https://github.com/WWangYuHsiang/SMILEtrack/blob/main/BoT-SORT/tracker/bot_sort.py
"""
import cv2
import matplotlib.pyplot as plt
import numpy as np
from collections import deque
from ..matching import jde_matching as matching
from ..motion import GMC
from .base_jde_tracker import TrackState, STrack
from .base_jde_tracker import joint_stracks, sub_stracks, remove_duplicate_stracks
from ..motion import KalmanFilter
class BOTSORTTracker(object):
"""
BOTSORT tracker, supports single class only.
Args:
track_high_thresh (float): score threshold for the first (high-score) association
track_low_thresh (float): score threshold below which detections are discarded
new_track_thresh (float): score threshold for starting a new track
match_thresh (float): IoU threshold used for association
track_buffer (int): number of frames a lost track is kept, default 30
min_box_area (float): minimum box area to keep a track
camera_motion (bool): whether to use camera motion compensation, default False
cmc_method (str): camera motion compensation method, default sparseOptFlow
frame_rate (int): video fps, buffer_size = int(frame_rate / 30.0 * track_buffer)
"""
def __init__(self,
track_high_thresh=0.3,
track_low_thresh=0.2,
new_track_thresh=0.4,
match_thresh=0.7,
track_buffer=30,
min_box_area=0,
camera_motion=False,
cmc_method='sparseOptFlow',
frame_rate=30):
self.tracked_stracks = [] # type: list[STrack]
self.lost_stracks = [] # type: list[STrack]
self.removed_stracks = [] # type: list[STrack]
self.frame_id = 0
self.track_high_thresh = track_high_thresh
self.track_low_thresh = track_low_thresh
self.new_track_thresh = new_track_thresh
self.match_thresh = match_thresh
self.buffer_size = int(frame_rate / 30.0 * track_buffer)
self.max_time_lost = self.buffer_size
self.kalman_filter = KalmanFilter()
self.min_box_area = min_box_area
self.camera_motion = camera_motion
self.gmc = GMC(method=cmc_method)
def update(self, output_results, img=None):
self.frame_id += 1
activated_starcks = []
refind_stracks = []
lost_stracks = []
removed_stracks = []
if len(output_results):
bboxes = output_results[:, 2:6]
scores = output_results[:, 1]
classes = output_results[:, 0]
# Remove bad detections
lowest_inds = scores > self.track_low_thresh
bboxes = bboxes[lowest_inds]
scores = scores[lowest_inds]
classes = classes[lowest_inds]
# Find high threshold detections
remain_inds = scores > self.track_high_thresh
dets = bboxes[remain_inds]
scores_keep = scores[remain_inds]
classes_keep = classes[remain_inds]
else:
bboxes = []
scores = []
classes = []
dets = []
scores_keep = []
classes_keep = []
if len(dets) > 0:
'''Detections'''
detections = [
STrack(STrack.tlbr_to_tlwh(tlbr), s, c)
for (tlbr, s, c) in zip(dets, scores_keep, classes_keep)
]
else:
detections = []
''' Add newly detected tracklets to tracked_stracks'''
unconfirmed = []
tracked_stracks = [] # type: list[STrack]
for track in self.tracked_stracks:
if not track.is_activated:
unconfirmed.append(track)
else:
tracked_stracks.append(track)
''' Step 2: First association, with high score detection boxes'''
strack_pool = joint_stracks(tracked_stracks, self.lost_stracks)
# Predict the current location with KF
STrack.multi_predict(strack_pool, self.kalman_filter)
# Fix camera motion
if self.camera_motion:
warp = self.gmc.apply(img[0], dets)
STrack.multi_gmc(strack_pool, warp)
STrack.multi_gmc(unconfirmed, warp)
# Associate with high score detection boxes
ious_dists = matching.iou_distance(strack_pool, detections)
matches, u_track, u_detection = matching.linear_assignment(
ious_dists, thresh=self.match_thresh)
for itracked, idet in matches:
track = strack_pool[itracked]
det = detections[idet]
if track.state == TrackState.Tracked:
track.update(detections[idet], self.frame_id)
activated_starcks.append(track)
else:
track.re_activate(det, self.frame_id, new_id=False)
refind_stracks.append(track)
''' Step 3: Second association, with low score detection boxes'''
if len(scores):
inds_high = scores < self.track_high_thresh
inds_low = scores > self.track_low_thresh
inds_second = np.logical_and(inds_low, inds_high)
dets_second = bboxes[inds_second]
scores_second = scores[inds_second]
classes_second = classes[inds_second]
else:
dets_second = []
scores_second = []
classes_second = []
# associate the remaining unmatched tracks with the low score detections
if len(dets_second) > 0:
'''Detections'''
detections_second = [
STrack(STrack.tlbr_to_tlwh(tlbr), s, c) for (tlbr, s, c) in
zip(dets_second, scores_second, classes_second)
]
else:
detections_second = []
r_tracked_stracks = [
strack_pool[i] for i in u_track
if strack_pool[i].state == TrackState.Tracked
]
dists = matching.iou_distance(r_tracked_stracks, detections_second)
matches, u_track, u_detection_second = matching.linear_assignment(
dists, thresh=0.5)
for itracked, idet in matches:
track = r_tracked_stracks[itracked]
det = detections_second[idet]
if track.state == TrackState.Tracked:
track.update(det, self.frame_id)
activated_starcks.append(track)
else:
track.re_activate(det, self.frame_id, new_id=False)
refind_stracks.append(track)
for it in u_track:
track = r_tracked_stracks[it]
if not track.state == TrackState.Lost:
track.mark_lost()
lost_stracks.append(track)
'''Deal with unconfirmed tracks, usually tracks with only one beginning frame'''
detections = [detections[i] for i in u_detection]
dists = matching.iou_distance(unconfirmed, detections)
matches, u_unconfirmed, u_detection = matching.linear_assignment(
dists, thresh=0.7)
for itracked, idet in matches:
unconfirmed[itracked].update(detections[idet], self.frame_id)
activated_starcks.append(unconfirmed[itracked])
for it in u_unconfirmed:
track = unconfirmed[it]
track.mark_removed()
removed_stracks.append(track)
""" Step 4: Init new stracks"""
for inew in u_detection:
track = detections[inew]
if track.score < self.new_track_thresh:
continue
track.activate(self.kalman_filter, self.frame_id)
activated_starcks.append(track)
""" Step 5: Update state"""
for track in self.lost_stracks:
if self.frame_id - track.end_frame > self.max_time_lost:
track.mark_removed()
removed_stracks.append(track)
""" Merge """
self.tracked_stracks = [
t for t in self.tracked_stracks if t.state == TrackState.Tracked
]
self.tracked_stracks = joint_stracks(self.tracked_stracks,
activated_starcks)
self.tracked_stracks = joint_stracks(self.tracked_stracks,
refind_stracks)
self.lost_stracks = sub_stracks(self.lost_stracks, self.tracked_stracks)
self.lost_stracks.extend(lost_stracks)
self.lost_stracks = sub_stracks(self.lost_stracks, self.removed_stracks)
self.removed_stracks.extend(removed_stracks)
self.tracked_stracks, self.lost_stracks = remove_duplicate_stracks(
self.tracked_stracks, self.lost_stracks)
# output_stracks = [track for track in self.tracked_stracks if track.is_activated]
output_stracks = [track for track in self.tracked_stracks]
return output_stracks
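# --- Hedged usage sketch (added for illustration, not part of the original file) ---
# Feed two frames of fake detections to the tracker above. Each row is assumed to be
# [cls_id, score, x1, y1, x2, y2], the layout that update() slices internally; camera
# motion compensation stays disabled, so no image has to be passed.
# Run as a module from the repo root, e.g.:
#   python -m ppdet.modeling.mot.tracker.botsort_tracker
if __name__ == '__main__':
    tracker = BOTSORTTracker()
    frame1_dets = np.array([[0., 0.9, 100., 100., 150., 220.],
                            [0., 0.8, 300., 120., 360., 250.]])
    frame2_dets = np.array([[0., 0.9, 104., 102., 154., 222.],
                            [0., 0.8, 305., 121., 365., 251.]])
    for dets in [frame1_dets, frame2_dets]:
        online = tracker.update(dets)
        print([(t.track_id, t.tlwh.tolist()) for t in online])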
......@@ -15,16 +15,10 @@
This code is based on https://github.com/noahcao/OC_SORT/blob/master/trackers/ocsort_tracker/ocsort.py
"""
import time
import numpy as np
try:
from filterpy.kalman import KalmanFilter
except:
print(
'Warning: Unable to use OC-SORT, please install filterpy, for example: `pip install filterpy`, see https://github.com/rlabbe/filterpy'
)
pass
from ..matching.ocsort_matching import associate, linear_assignment, iou_batch
from ..matching.ocsort_matching import associate, linear_assignment, iou_batch, associate_only_iou
from ..motion.ocsort_kalman_filter import OCSORTKalmanFilter
def k_previous_obs(observations, cur_age, k):
......@@ -90,19 +84,14 @@ class KalmanBoxTracker(object):
count = 0
def __init__(self, bbox, delta_t=3):
try:
from filterpy.kalman import KalmanFilter
except Exception as e:
raise RuntimeError(
'Unable to use OC-SORT, please install filterpy, for example: `pip install filterpy`, see https://github.com/rlabbe/filterpy'
)
self.kf = KalmanFilter(dim_x=7, dim_z=4)
self.kf.F = np.array([[1, 0, 0, 0, 1, 0, 0], [0, 1, 0, 0, 0, 1, 0],
[0, 0, 1, 0, 0, 0, 1], [0, 0, 0, 1, 0, 0, 0],
[0, 0, 0, 0, 1, 0, 0], [0, 0, 0, 0, 0, 1, 0],
[0, 0, 0, 0, 0, 0, 1]])
self.kf.H = np.array([[1, 0, 0, 0, 0, 0, 0], [0, 1, 0, 0, 0, 0, 0],
[0, 0, 1, 0, 0, 0, 0], [0, 0, 0, 1, 0, 0, 0]])
self.kf = OCSORTKalmanFilter(dim_x=7, dim_z=4)
self.kf.F = np.array([[1., 0, 0, 0, 1., 0, 0], [0, 1., 0, 0, 0, 1., 0],
[0, 0, 1., 0, 0, 0, 1.], [0, 0, 0, 1., 0, 0, 0],
[0, 0, 0, 0, 1., 0, 0], [0, 0, 0, 0, 0, 1., 0],
[0, 0, 0, 0, 0, 0, 1.]])
self.kf.H = np.array([[1., 0, 0, 0, 0, 0, 0], [0, 1., 0, 0, 0, 0, 0],
[0, 0, 1., 0, 0, 0, 0], [0, 0, 0, 1., 0, 0, 0]])
self.kf.R[2:, 2:] *= 10.
self.kf.P[4:, 4:] *= 1000.
# give high uncertainty to the unobservable initial velocities
......@@ -130,12 +119,13 @@ class KalmanBoxTracker(object):
self.velocity = None
self.delta_t = delta_t
def update(self, bbox):
def update(self, bbox, angle_cost=False):
"""
Updates the state vector with observed bbox.
"""
if bbox is not None:
if self.last_observation.sum() >= 0: # no previous observation
if angle_cost and self.last_observation.sum(
) >= 0: # no previous observation
previous_box = None
for i in range(self.delta_t):
dt = self.delta_t - i
......@@ -144,9 +134,9 @@ class KalmanBoxTracker(object):
break
if previous_box is None:
previous_box = self.last_observation
"""
Estimate the track speed direction with observations \Delta t steps away
"""
# """
# Estimate the track speed direction with observations \Delta t steps away
# """
self.velocity = speed_direction(previous_box, bbox)
"""
Insert new observations. This is a ugly way to maintain both self.observations
......@@ -199,6 +189,7 @@ class OCSORTTracker(object):
1.6 for pedestrian tracking.
min_box_area (int): min box area to filter out low quality boxes
use_byte (bool): Whether use ByteTracker, default False
use_angle_cost (bool): Whether to use angle cost, default False
"""
def __init__(self,
......@@ -210,7 +201,8 @@ class OCSORTTracker(object):
inertia=0.2,
vertical_ratio=-1,
min_box_area=0,
use_byte=False):
use_byte=False,
use_angle_cost=False):
self.det_thresh = det_thresh
self.max_age = max_age
self.min_hits = min_hits
......@@ -220,6 +212,7 @@ class OCSORTTracker(object):
self.vertical_ratio = vertical_ratio
self.min_box_area = min_box_area
self.use_byte = use_byte
self.use_angle_cost = use_angle_cost
self.trackers = []
self.frame_count = 0
......@@ -267,23 +260,31 @@ class OCSORTTracker(object):
for t in reversed(to_del):
self.trackers.pop(t)
velocities = np.array([
trk.velocity if trk.velocity is not None else np.array((0, 0))
for trk in self.trackers
])
if self.use_angle_cost:
velocities = np.array([
trk.velocity if trk.velocity is not None else np.array((0, 0))
for trk in self.trackers
])
k_observations = np.array([
k_previous_obs(trk.observations, trk.age, self.delta_t)
for trk in self.trackers
])
last_boxes = np.array([trk.last_observation for trk in self.trackers])
k_observations = np.array([
k_previous_obs(trk.observations, trk.age, self.delta_t)
for trk in self.trackers
])
"""
First round of association
"""
matched, unmatched_dets, unmatched_trks = associate(
dets, trks, self.iou_threshold, velocities, k_observations,
self.inertia)
if self.use_angle_cost:
matched, unmatched_dets, unmatched_trks = associate(
dets, trks, self.iou_threshold, velocities, k_observations,
self.inertia)
else:
matched, unmatched_dets, unmatched_trks = associate_only_iou(
dets, trks, self.iou_threshold)
for m in matched:
self.trackers[m[1]].update(dets[m[0], :])
self.trackers[m[1]].update(
dets[m[0], :], angle_cost=self.use_angle_cost)
"""
Second round of associaton by OCR
"""
......@@ -307,7 +308,8 @@ class OCSORTTracker(object):
det_ind, trk_ind = m[0], unmatched_trks[m[1]]
if iou_left[m[0], m[1]] < self.iou_threshold:
continue
self.trackers[trk_ind].update(dets_second[det_ind, :])
self.trackers[trk_ind].update(
dets_second[det_ind, :], angle_cost=self.use_angle_cost)
to_remove_trk_indices.append(trk_ind)
unmatched_trks = np.setdiff1d(unmatched_trks,
np.array(to_remove_trk_indices))
......@@ -331,7 +333,8 @@ class OCSORTTracker(object):
1]]
if iou_left[m[0], m[1]] < self.iou_threshold:
continue
self.trackers[trk_ind].update(dets[det_ind, :])
self.trackers[trk_ind].update(
dets[det_ind, :], angle_cost=self.use_angle_cost)
to_remove_det_indices.append(det_ind)
to_remove_trk_indices.append(trk_ind)
unmatched_dets = np.setdiff1d(unmatched_dets,
......@@ -346,6 +349,7 @@ class OCSORTTracker(object):
for i in unmatched_dets:
trk = KalmanBoxTracker(dets[i, :], delta_t=self.delta_t)
self.trackers.append(trk)
i = len(self.trackers)
for trk in reversed(self.trackers):
if trk.last_observation.sum() < 0:
......
......@@ -32,7 +32,7 @@ sys.path.insert(0, parent_path)
from det_infer import Detector, get_test_images, print_arguments, bench_log, PredictConfig, load_predictor
from mot_utils import argsparser, Timer, get_current_memory_mb, video2frames, _is_valid_video
from mot.tracker import JDETracker, DeepSORTTracker, OCSORTTracker
from mot.tracker import JDETracker, DeepSORTTracker, OCSORTTracker, BOTSORTTracker
from mot.utils import MOTTimer, write_mot_results, get_crops, clip_box, flow_statistic
from mot.visualize import plot_tracking, plot_tracking_dict
......@@ -63,6 +63,7 @@ class SDE_Detector(Detector):
draw_center_traj (bool): Whether drawing the trajectory of center, default as False
secs_interval (int): The seconds interval to count after tracking, default as 10
skip_frame_num (int): Skip frame num to get faster MOT results, default as -1
warmup_frame (int): Warmup frame num to test the speed of MOT, default as 50
do_entrance_counting(bool): Whether counting the numbers of identifiers entering
or getting out from the entrance, default as False,only support single class
counting in MOT, and the video should be taken by a static camera.
......@@ -98,6 +99,7 @@ class SDE_Detector(Detector):
draw_center_traj=False,
secs_interval=10,
skip_frame_num=-1,
warmup_frame=50,
do_entrance_counting=False,
do_break_in_counting=False,
region_type='horizontal',
......@@ -122,6 +124,7 @@ class SDE_Detector(Detector):
self.draw_center_traj = draw_center_traj
self.secs_interval = secs_interval
self.skip_frame_num = skip_frame_num
self.warmup_frame = warmup_frame
self.do_entrance_counting = do_entrance_counting
self.do_break_in_counting = do_break_in_counting
self.region_type = region_type
......@@ -168,6 +171,8 @@ class SDE_Detector(Detector):
'type'] == 'DeepSORTTracker' else False
self.use_ocsort_tracker = True if tracker_cfg[
'type'] == 'OCSORTTracker' else False
self.use_botsort_tracker = True if tracker_cfg[
'type'] == 'BOTSORTTracker' else False
if self.use_deepsort_tracker:
if self.reid_pred_config is not None and hasattr(
......@@ -198,6 +203,7 @@ class SDE_Detector(Detector):
min_box_area = cfg.get('min_box_area', 0)
vertical_ratio = cfg.get('vertical_ratio', 0)
use_byte = cfg.get('use_byte', False)
use_angle_cost = cfg.get('use_angle_cost', False)
self.tracker = OCSORTTracker(
det_thresh=det_thresh,
......@@ -208,7 +214,27 @@ class SDE_Detector(Detector):
inertia=inertia,
min_box_area=min_box_area,
vertical_ratio=vertical_ratio,
use_byte=use_byte)
use_byte=use_byte,
use_angle_cost=use_angle_cost)
elif self.use_botsort_tracker:
track_high_thresh = cfg.get('track_high_thresh', 0.3)
track_low_thresh = cfg.get('track_low_thresh', 0.2)
new_track_thresh = cfg.get('new_track_thresh', 0.4)
match_thresh = cfg.get('match_thresh', 0.7)
track_buffer = cfg.get('track_buffer', 30)
camera_motion = cfg.get('camera_motion', False)
cmc_method = cfg.get('cmc_method', 'sparseOptFlow')
self.tracker = BOTSORTTracker(
track_high_thresh=track_high_thresh,
track_low_thresh=track_low_thresh,
new_track_thresh=new_track_thresh,
match_thresh=match_thresh,
track_buffer=track_buffer,
camera_motion=camera_motion,
cmc_method=cmc_method)
else:
# use ByteTracker
use_byte = cfg.get('use_byte', False)
......@@ -283,7 +309,7 @@ class SDE_Detector(Detector):
det_results['embeddings'] = pred_embs
return det_results
def tracking(self, det_results):
def tracking(self, det_results, img=None):
pred_dets = det_results['boxes'] # cls_id, score, x0, y0, x1, y1
pred_embs = det_results.get('embeddings', None)
......@@ -357,6 +383,29 @@ class SDE_Detector(Detector):
}
return tracking_outs
elif self.use_botsort_tracker:
# use BOTSORTTracker, only support single class
online_targets = self.tracker.update(pred_dets, img)
online_tlwhs = defaultdict(list)
online_scores = defaultdict(list)
online_ids = defaultdict(list)
for t in online_targets:
tlwh = t.tlwh
tid = t.track_id
tscore = t.score
if tlwh[2] * tlwh[3] <= self.tracker.min_box_area:
continue
online_tlwhs[0].append(tlwh)
online_ids[0].append(tid)
online_scores[0].append(tscore)
tracking_outs = {
'online_tlwhs': online_tlwhs,
'online_scores': online_scores,
'online_ids': online_ids,
}
return tracking_outs
else:
# use ByteTracker, support multiple class
online_tlwhs = defaultdict(list)
......@@ -420,7 +469,8 @@ class SDE_Detector(Detector):
repeats=1,
visual=True,
seq_name=None,
reuse_det_result=False):
reuse_det_result=False,
frame_count=0):
num_classes = self.num_classes
image_list.sort()
ids2names = self.pred_config.labels
......@@ -459,7 +509,10 @@ class SDE_Detector(Detector):
det_result['seq_name'] = seq_name
det_result['ori_image'] = frame
det_result = self.reidprocess(det_result)
result_warmup = self.tracking(det_result)
if self.use_botsort_tracker:
result_warmup = self.tracking(det_result, batch_image_list)
else:
result_warmup = self.tracking(det_result)
self.det_times.tracking_time_s.start()
if self.use_reid:
det_result = self.reidprocess(det_result)
......@@ -473,35 +526,44 @@ class SDE_Detector(Detector):
self.gpu_util += gu
else:
self.det_times.preprocess_time_s.start()
if frame_count > self.warmup_frame:
self.det_times.preprocess_time_s.start()
if not reuse_det_result:
inputs = self.preprocess(batch_image_list)
self.det_times.preprocess_time_s.end()
self.det_times.inference_time_s.start()
if frame_count > self.warmup_frame:
self.det_times.preprocess_time_s.end()
if frame_count > self.warmup_frame:
self.det_times.inference_time_s.start()
if not reuse_det_result:
result = self.predict()
self.det_times.inference_time_s.end()
self.det_times.postprocess_time_s.start()
if frame_count > self.warmup_frame:
self.det_times.inference_time_s.end()
if frame_count > self.warmup_frame:
self.det_times.postprocess_time_s.start()
if not reuse_det_result:
det_result = self.postprocess(inputs, result)
self.previous_det_result = det_result
else:
assert self.previous_det_result is not None
det_result = self.previous_det_result
self.det_times.postprocess_time_s.end()
if frame_count > self.warmup_frame:
self.det_times.postprocess_time_s.end()
# tracking process
self.det_times.tracking_time_s.start()
if frame_count > self.warmup_frame:
self.det_times.tracking_time_s.start()
if self.use_reid:
det_result['frame_id'] = frame_id
det_result['seq_name'] = seq_name
det_result['ori_image'] = frame
det_result = self.reidprocess(det_result)
tracking_outs = self.tracking(det_result)
self.det_times.tracking_time_s.end()
self.det_times.img_num += 1
if self.use_botsort_tracker:
tracking_outs = self.tracking(det_result, batch_image_list)
else:
tracking_outs = self.tracking(det_result)
if frame_count > self.warmup_frame:
self.det_times.tracking_time_s.end()
self.det_times.img_num += 1
online_tlwhs = tracking_outs['online_tlwhs']
online_scores = tracking_outs['online_scores']
......@@ -623,7 +685,8 @@ class SDE_Detector(Detector):
[frame],
visual=False,
seq_name=seq_name,
reuse_det_result=reuse_det_result)
reuse_det_result=reuse_det_result,
frame_count=frame_id)
timer.toc()
# bs=1 in MOT model
......@@ -652,7 +715,7 @@ class SDE_Detector(Detector):
records = statistic['records']
fps = 1. / timer.duration
if self.use_deepsort_tracker or self.use_ocsort_tracker:
if self.use_deepsort_tracker or self.use_ocsort_tracker or self.use_botsort_tracker:
# use DeepSORTTracker, OCSORTTracker or BOTSORTTracker, only support single class
if isinstance(online_tlwhs, defaultdict):
online_tlwhs = online_tlwhs[0]
......@@ -840,6 +903,7 @@ def main():
draw_center_traj=FLAGS.draw_center_traj,
secs_interval=FLAGS.secs_interval,
skip_frame_num=FLAGS.skip_frame_num,
warmup_frame=FLAGS.warmup_frame,
do_entrance_counting=FLAGS.do_entrance_counting,
do_break_in_counting=FLAGS.do_break_in_counting,
region_type=FLAGS.region_type,
......@@ -850,6 +914,7 @@ def main():
# predict from video file or camera video stream
if FLAGS.video_file is not None or FLAGS.camera_id != -1:
detector.predict_video(FLAGS.video_file, FLAGS.camera_id)
detector.det_times.info(average=True)
elif FLAGS.mtmct_dir is not None:
with open(FLAGS.mtmct_cfg) as f:
mtmct_cfg = yaml.safe_load(f)
......
......@@ -142,6 +142,12 @@ def argsparser():
type=int,
default=-1,
help='Skip frames to speed up the process of getting mot results.')
parser.add_argument(
'--warmup_frame',
type=int,
default=50,
help='Warmup frames to test speed of the process of getting mot results.'
)
parser.add_argument(
"--do_entrance_counting",
action='store_true',
......
......@@ -2,7 +2,7 @@
# The tracker of MOT JDE Detector (such as FairMOT) is exported together with the model.
# Here 'min_box_area' and 'vertical_ratio' are set for pedestrian, you can modify for other objects tracking.
type: OCSORTTracker # choose one tracker in ['JDETracker', 'OCSORTTracker', 'DeepSORTTracker']
type: BOTSORTTracker # choose one tracker in ['JDETracker', 'OCSORTTracker', 'DeepSORTTracker', 'BOTSORTTracker']
# When using for MTMCT(Multi-Target Multi-Camera Tracking), you should modify to 'DeepSORTTracker'
......@@ -28,6 +28,7 @@ OCSORTTracker:
min_box_area: 0
vertical_ratio: 0
use_byte: False
use_angle_cost: False
# used for DeepSORT and MTMCT in PP-Tracking project
......@@ -41,3 +42,14 @@ DeepSORTTracker:
metric_type: cosine
matching_threshold: 0.2
max_iou_distance: 0.9
BOTSORTTracker:
track_high_thresh: 0.3
track_low_thresh: 0.2
new_track_thresh: 0.4
match_thresh: 0.7
track_buffer: 30
min_box_area: 0
camera_motion: False
cmc_method: 'sparseOptFlow' # used only when camera_motion is True;
# sparseOptFlow | files (Vidstab GMC) | orb | ecc
......@@ -280,4 +280,4 @@ def create(cls_or_name, **kwargs):
# prevent modification of global config values of reference types
# (e.g., list, dict) from within the created module instances
#kwargs = copy.deepcopy(kwargs)
return cls(**cls_kwargs)
return cls(**cls_kwargs)
\ No newline at end of file
......@@ -30,7 +30,7 @@ from ppdet.utils.checkpoint import load_weight, load_pretrain_weight
from ppdet.modeling.mot.utils import Detection, get_crops, scale_coords, clip_box
from ppdet.modeling.mot.utils import MOTTimer, load_det_results, write_mot_results, save_vis_results
from ppdet.modeling.mot.tracker import JDETracker, CenterTracker
from ppdet.modeling.mot.tracker import DeepSORTTracker, OCSORTTracker
from ppdet.modeling.mot.tracker import DeepSORTTracker, OCSORTTracker, BOTSORTTracker
from ppdet.modeling.architectures import YOLOX
from ppdet.metrics import Metric, MOTMetric, KITTIMOTMetric, MCMOTMetric
from ppdet.data.source.category import get_categories
......@@ -453,6 +453,29 @@ class Tracker(object):
online_scores, timer.average_time, show_image,
save_dir, self.cfg.num_classes, self.ids2names)
elif isinstance(tracker, BOTSORTTracker):
# BOTSORT Tracker
online_targets = tracker.update(
pred_dets_old, img=ori_image.numpy())
online_tlwhs = []
online_ids = []
online_scores = []
for t in online_targets:
tlwh = t.tlwh
tid = t.track_id
tscore = t.score
if tlwh[2] * tlwh[3] > 0:
online_tlwhs.append(tlwh)
online_ids.append(tid)
online_scores.append(tscore)
timer.toc()
# save results
results[0].append(
(frame_id + 1, online_tlwhs, online_scores, online_ids))
save_vis_results(data, frame_id, online_ids, online_tlwhs,
online_scores, timer.average_time, show_image,
save_dir, self.cfg.num_classes, self.ids2names)
else:
raise ValueError(tracker)
frame_id += 1
......
......@@ -122,3 +122,44 @@ def associate(detections, trackers, iou_threshold, velocities, previous_obs,
matches = np.concatenate(matches, axis=0)
return matches, np.array(unmatched_detections), np.array(unmatched_trackers)
def associate_only_iou(detections, trackers, iou_threshold):
if (len(trackers) == 0):
return np.empty(
(0, 2), dtype=int), np.arange(len(detections)), np.empty(
(0, 5), dtype=int)
iou_matrix = iou_batch(detections, trackers)
if min(iou_matrix.shape) > 0:
a = (iou_matrix > iou_threshold).astype(np.int32)
if a.sum(1).max() == 1 and a.sum(0).max() == 1:
matched_indices = np.stack(np.where(a), axis=1)
else:
matched_indices = linear_assignment(-iou_matrix)
else:
matched_indices = np.empty(shape=(0, 2))
unmatched_detections = []
for d, det in enumerate(detections):
if (d not in matched_indices[:, 0]):
unmatched_detections.append(d)
unmatched_trackers = []
for t, trk in enumerate(trackers):
if (t not in matched_indices[:, 1]):
unmatched_trackers.append(t)
# filter out matched with low IOU
matches = []
for m in matched_indices:
if (iou_matrix[m[0], m[1]] < iou_threshold):
unmatched_detections.append(m[0])
unmatched_trackers.append(m[1])
else:
matches.append(m.reshape(1, 2))
if (len(matches) == 0):
matches = np.empty((0, 2), dtype=int)
else:
matches = np.concatenate(matches, axis=0)
return matches, np.array(unmatched_detections), np.array(unmatched_trackers)
......@@ -15,3 +15,4 @@
from . import kalman_filter
from .kalman_filter import *
from .gmc import *
\ No newline at end of file
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is based on https://github.com/WWangYuHsiang/SMILEtrack/blob/main/BoT-SORT/tracker/gmc.py
"""
import cv2
import matplotlib.pyplot as plt
import numpy as np
import copy
import time
from ppdet.core.workspace import register, serializable
@register
@serializable
class GMC:
def __init__(self, method='sparseOptFlow', downscale=2, verbose=None):
super(GMC, self).__init__()
self.method = method
self.downscale = max(1, int(downscale))
if self.method == 'orb':
self.detector = cv2.FastFeatureDetector_create(20)
self.extractor = cv2.ORB_create()
self.matcher = cv2.BFMatcher(cv2.NORM_HAMMING)
elif self.method == 'sift':
self.detector = cv2.SIFT_create(
nOctaveLayers=3, contrastThreshold=0.02, edgeThreshold=20)
self.extractor = cv2.SIFT_create(
nOctaveLayers=3, contrastThreshold=0.02, edgeThreshold=20)
self.matcher = cv2.BFMatcher(cv2.NORM_L2)
elif self.method == 'ecc':
number_of_iterations = 5000
termination_eps = 1e-6
self.warp_mode = cv2.MOTION_EUCLIDEAN
self.criteria = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT,
number_of_iterations, termination_eps)
elif self.method == 'sparseOptFlow':
self.feature_params = dict(
maxCorners=1000,
qualityLevel=0.01,
minDistance=1,
blockSize=3,
useHarrisDetector=False,
k=0.04)
# self.gmc_file = open('GMC_results.txt', 'w')
elif self.method == 'file' or self.method == 'files':
seqName = verbose[0]
ablation = verbose[1]
if ablation:
filePath = r'tracker/GMC_files/MOT17_ablation'
else:
filePath = r'tracker/GMC_files/MOTChallenge'
if '-FRCNN' in seqName:
seqName = seqName[:-6]
elif '-DPM' in seqName:
seqName = seqName[:-4]
elif '-SDP' in seqName:
seqName = seqName[:-4]
self.gmcFile = open(filePath + "/GMC-" + seqName + ".txt", 'r')
if self.gmcFile is None:
raise ValueError("Error: Unable to open GMC file in directory:"
+ filePath)
elif self.method == 'none' or self.method == 'None':
self.method = 'none'
else:
raise ValueError("Error: Unknown CMC method:" + method)
self.prevFrame = None
self.prevKeyPoints = None
self.prevDescriptors = None
self.initializedFirstFrame = False
def apply(self, raw_frame, detections=None):
if self.method == 'orb' or self.method == 'sift':
return self.applyFeaures(raw_frame, detections)
elif self.method == 'ecc':
return self.applyEcc(raw_frame, detections)
elif self.method == 'sparseOptFlow':
return self.applySparseOptFlow(raw_frame, detections)
elif self.method == 'file':
return self.applyFile(raw_frame, detections)
elif self.method == 'none':
return np.eye(2, 3)
else:
return np.eye(2, 3)
def applyEcc(self, raw_frame, detections=None):
# Initialize
height, width, _ = raw_frame.shape
frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY)
H = np.eye(2, 3, dtype=np.float32)
# Downscale image (TODO: consider using pyramids)
if self.downscale > 1.0:
frame = cv2.GaussianBlur(frame, (3, 3), 1.5)
frame = cv2.resize(frame, (width // self.downscale,
height // self.downscale))
width = width // self.downscale
height = height // self.downscale
# Handle first frame
if not self.initializedFirstFrame:
# Initialize data
self.prevFrame = frame.copy()
# Initialization done
self.initializedFirstFrame = True
return H
# Run the ECC algorithm. The results are stored in warp_matrix.
# (cc, H) = cv2.findTransformECC(self.prevFrame, frame, H, self.warp_mode, self.criteria)
try:
(cc,
H) = cv2.findTransformECC(self.prevFrame, frame, H, self.warp_mode,
self.criteria, None, 1)
except:
print('Warning: find transform failed. Set warp as identity')
return H
def applyFeaures(self, raw_frame, detections=None):
# Initialize
height, width, _ = raw_frame.shape
frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY)
H = np.eye(2, 3)
# Downscale image (TODO: consider using pyramids)
if self.downscale > 1.0:
# frame = cv2.GaussianBlur(frame, (3, 3), 1.5)
frame = cv2.resize(frame, (width // self.downscale,
height // self.downscale))
width = width // self.downscale
height = height // self.downscale
# find the keypoints
mask = np.zeros_like(frame)
# mask[int(0.05 * height): int(0.95 * height), int(0.05 * width): int(0.95 * width)] = 255
mask[int(0.02 * height):int(0.98 * height), int(0.02 * width):int(
0.98 * width)] = 255
if detections is not None:
for det in detections:
tlbr = (det[:4] / self.downscale).astype(np.int_)
mask[tlbr[1]:tlbr[3], tlbr[0]:tlbr[2]] = 0
keypoints = self.detector.detect(frame, mask)
# compute the descriptors
keypoints, descriptors = self.extractor.compute(frame, keypoints)
# Handle first frame
if not self.initializedFirstFrame:
# Initialize data
self.prevFrame = frame.copy()
self.prevKeyPoints = copy.copy(keypoints)
self.prevDescriptors = copy.copy(descriptors)
# Initialization done
self.initializedFirstFrame = True
return H
# Match descriptors.
knnMatches = self.matcher.knnMatch(self.prevDescriptors, descriptors, 2)
# Filtered matches based on smallest spatial distance
matches = []
spatialDistances = []
maxSpatialDistance = 0.25 * np.array([width, height])
# Handle empty matches case
if len(knnMatches) == 0:
# Store to next iteration
self.prevFrame = frame.copy()
self.prevKeyPoints = copy.copy(keypoints)
self.prevDescriptors = copy.copy(descriptors)
return H
for m, n in knnMatches:
if m.distance < 0.9 * n.distance:
prevKeyPointLocation = self.prevKeyPoints[m.queryIdx].pt
currKeyPointLocation = keypoints[m.trainIdx].pt
spatialDistance = (
prevKeyPointLocation[0] - currKeyPointLocation[0],
prevKeyPointLocation[1] - currKeyPointLocation[1])
if (np.abs(spatialDistance[0]) < maxSpatialDistance[0]) and \
(np.abs(spatialDistance[1]) < maxSpatialDistance[1]):
spatialDistances.append(spatialDistance)
matches.append(m)
meanSpatialDistances = np.mean(spatialDistances, 0)
stdSpatialDistances = np.std(spatialDistances, 0)
inliesrs = (spatialDistances - meanSpatialDistances
) < 2.5 * stdSpatialDistances
goodMatches = []
prevPoints = []
currPoints = []
for i in range(len(matches)):
if inliesrs[i, 0] and inliesrs[i, 1]:
goodMatches.append(matches[i])
prevPoints.append(self.prevKeyPoints[matches[i].queryIdx].pt)
currPoints.append(keypoints[matches[i].trainIdx].pt)
prevPoints = np.array(prevPoints)
currPoints = np.array(currPoints)
# Draw the keypoint matches on the output image
if 0:
matches_img = np.hstack((self.prevFrame, frame))
matches_img = cv2.cvtColor(matches_img, cv2.COLOR_GRAY2BGR)
W = np.size(self.prevFrame, 1)
for m in goodMatches:
prev_pt = np.array(
self.prevKeyPoints[m.queryIdx].pt, dtype=np.int_)
curr_pt = np.array(keypoints[m.trainIdx].pt, dtype=np.int_)
curr_pt[0] += W
color = np.random.randint(0, 255, (3, ))
color = (int(color[0]), int(color[1]), int(color[2]))
matches_img = cv2.line(matches_img, prev_pt, curr_pt,
tuple(color), 1, cv2.LINE_AA)
matches_img = cv2.circle(matches_img, prev_pt, 2,
tuple(color), -1)
matches_img = cv2.circle(matches_img, curr_pt, 2,
tuple(color), -1)
plt.figure()
plt.imshow(matches_img)
plt.show()
# Find rigid matrix
if (np.size(prevPoints, 0) > 4) and (
np.size(prevPoints, 0) == np.size(currPoints, 0)):
H, inliesrs = cv2.estimateAffinePartial2D(prevPoints, currPoints,
cv2.RANSAC)
# Handle downscale
if self.downscale > 1.0:
H[0, 2] *= self.downscale
H[1, 2] *= self.downscale
else:
print('Warning: not enough matching points')
# Store to next iteration
self.prevFrame = frame.copy()
self.prevKeyPoints = copy.copy(keypoints)
self.prevDescriptors = copy.copy(descriptors)
return H
def applySparseOptFlow(self, raw_frame, detections=None):
t0 = time.time()
# Initialize
height, width, _ = raw_frame.shape
frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY)
H = np.eye(2, 3)
# Downscale image
if self.downscale > 1.0:
# frame = cv2.GaussianBlur(frame, (3, 3), 1.5)
frame = cv2.resize(frame, (width // self.downscale,
height // self.downscale))
# find the keypoints
keypoints = cv2.goodFeaturesToTrack(
frame, mask=None, **self.feature_params)
# Handle first frame
if not self.initializedFirstFrame:
# Initialize data
self.prevFrame = frame.copy()
self.prevKeyPoints = copy.copy(keypoints)
# Initialization done
self.initializedFirstFrame = True
return H
if self.prevFrame.shape != frame.shape:
self.prevFrame = frame.copy()
self.prevKeyPoints = copy.copy(keypoints)
return H
# find correspondences
matchedKeypoints, status, err = cv2.calcOpticalFlowPyrLK(
self.prevFrame, frame, self.prevKeyPoints, None)
# leave good correspondences only
prevPoints = []
currPoints = []
for i in range(len(status)):
if status[i]:
prevPoints.append(self.prevKeyPoints[i])
currPoints.append(matchedKeypoints[i])
prevPoints = np.array(prevPoints)
currPoints = np.array(currPoints)
# Find rigid matrix
if (np.size(prevPoints, 0) > 4) and (
np.size(prevPoints, 0) == np.size(currPoints, 0)):
H, inliesrs = cv2.estimateAffinePartial2D(prevPoints, currPoints,
cv2.RANSAC)
# Handle downscale
if self.downscale > 1.0:
H[0, 2] *= self.downscale
H[1, 2] *= self.downscale
else:
print('Warning: not enough matching points')
# Store to next iteration
self.prevFrame = frame.copy()
self.prevKeyPoints = copy.copy(keypoints)
t1 = time.time()
# gmc_line = str(1000 * (t1 - t0)) + "\t" + str(H[0, 0]) + "\t" + str(H[0, 1]) + "\t" + str(
# H[0, 2]) + "\t" + str(H[1, 0]) + "\t" + str(H[1, 1]) + "\t" + str(H[1, 2]) + "\n"
# self.gmc_file.write(gmc_line)
return H
def applyFile(self, raw_frame, detections=None):
line = self.gmcFile.readline()
tokens = line.split("\t")
H = np.eye(2, 3, dtype=np.float_)
H[0, 0] = float(tokens[1])
H[0, 1] = float(tokens[2])
H[0, 2] = float(tokens[3])
H[1, 0] = float(tokens[4])
H[1, 1] = float(tokens[5])
H[1, 2] = float(tokens[6])
return H
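A minimal usage sketch of the GMC module above, assuming PaddleDetection is installed, that `GMC` is exported from `ppdet.modeling.mot.motion` (as the `__init__.py` change earlier suggests), and that `video.mp4` exists:
```python
# Hedged sketch: estimate per-frame camera motion with GMC.
import cv2
from ppdet.modeling.mot.motion import GMC

gmc = GMC(method='sparseOptFlow', downscale=2)
cap = cv2.VideoCapture('video.mp4')
while True:
    ok, frame = cap.read()
    if not ok:
        break
    # H is a 2x3 affine warp from the previous frame to the current one;
    # the very first call only initializes state and returns the identity.
    H = gmc.apply(frame)
    print(H)
cap.release()
```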
......@@ -17,7 +17,39 @@ This code is based on https://github.com/nwojke/deep_sort/blob/master/deep_sort/
import numpy as np
import scipy.linalg
from ppdet.core.workspace import register, serializable
use_numba = True
try:
import numba as nb
@nb.njit(fastmath=True, cache=True)
def nb_project(mean, covariance, std, _update_mat):
innovation_cov = np.diag(np.square(std))
mean = np.dot(_update_mat, mean)
covariance = np.dot(np.dot(_update_mat, covariance), _update_mat.T)
return mean, covariance + innovation_cov
@nb.njit(fastmath=True, cache=True)
def nb_multi_predict(mean, covariance, motion_cov, motion_mat):
mean = np.dot(mean, motion_mat.T)
left = np.dot(motion_mat, covariance)
covariance = np.dot(left, motion_mat.T) + motion_cov
return mean, covariance
@nb.njit(fastmath=True, cache=True)
def nb_update(mean, covariance, proj_mean, proj_cov, measurement, meas_mat):
kalman_gain = np.linalg.solve(proj_cov, (covariance @meas_mat.T).T).T
innovation = measurement - proj_mean
mean = mean + innovation @kalman_gain.T
covariance = covariance - kalman_gain @proj_cov @kalman_gain.T
return mean, covariance
except:
use_numba = False
print(
'Warning: Unable to use numba in PP-Tracking, please install numba, for example (python3.7): `pip install numba==0.56.4`'
)
pass
__all__ = ['KalmanFilter']
"""
......@@ -39,8 +71,6 @@ chi2inv95 = {
}
@register
@serializable
class KalmanFilter(object):
"""
A simple Kalman filter for tracking bounding boxes in image space.
......@@ -62,10 +92,10 @@ class KalmanFilter(object):
ndim, dt = 4, 1.
# Create Kalman filter model matrices.
self._motion_mat = np.eye(2 * ndim, 2 * ndim)
self._motion_mat = np.eye(2 * ndim, 2 * ndim, dtype=np.float32)
for i in range(ndim):
self._motion_mat[i, ndim + i] = dt
self._update_mat = np.eye(ndim, 2 * ndim)
self._update_mat = np.eye(ndim, 2 * ndim, dtype=np.float32)
# Motion and observation uncertainty are chosen relative to the current
# state estimate. These weights control the amount of uncertainty in
......@@ -99,7 +129,7 @@ class KalmanFilter(object):
10 * self._std_weight_velocity * measurement[3]
]
covariance = np.diag(np.square(std))
return mean, covariance
return mean, np.float32(covariance)
def predict(self, mean, covariance):
"""
......@@ -143,10 +173,16 @@ class KalmanFilter(object):
Returns:
The projected mean and covariance matrix of the given state estimate.
"""
std = [
self._std_weight_position * mean[3], self._std_weight_position *
mean[3], 1e-1, self._std_weight_position * mean[3]
]
std = np.array(
[
self._std_weight_position * mean[3], self._std_weight_position *
mean[3], 1e-1, self._std_weight_position * mean[3]
],
dtype=np.float32)
if use_numba:
return nb_project(mean, covariance, std, self._update_mat)
innovation_cov = np.diag(np.square(std))
mean = np.dot(self._update_mat, mean)
......@@ -168,18 +204,29 @@ class KalmanFilter(object):
The mean vector and covariance matrix of the predicted state.
Unobserved velocities are initialized to 0 mean.
"""
std_pos = [
std_pos = np.array([
self._std_weight_position * mean[:, 3], self._std_weight_position *
mean[:, 3], 1e-2 * np.ones_like(mean[:, 3]),
self._std_weight_position * mean[:, 3]
]
std_vel = [
])
std_vel = np.array([
self._std_weight_velocity * mean[:, 3], self._std_weight_velocity *
mean[:, 3], 1e-5 * np.ones_like(mean[:, 3]),
self._std_weight_velocity * mean[:, 3]
]
])
sqr = np.square(np.r_[std_pos, std_vel]).T
if use_numba:
means = []
covariances = []
for i in range(len(mean)):
a, b = nb_multi_predict(mean[i], covariance[i],
np.diag(sqr[i]), self._motion_mat)
means.append(a)
covariances.append(b)
return np.asarray(means), np.asarray(covariances)
motion_cov = []
for i in range(len(mean)):
motion_cov.append(np.diag(sqr[i]))
......@@ -207,18 +254,17 @@ class KalmanFilter(object):
"""
projected_mean, projected_cov = self.project(mean, covariance)
chol_factor, lower = scipy.linalg.cho_factor(
projected_cov, lower=True, check_finite=False)
kalman_gain = scipy.linalg.cho_solve(
(chol_factor, lower),
np.dot(covariance, self._update_mat.T).T,
check_finite=False).T
innovation = measurement - projected_mean
if use_numba:
new_mean = mean + np.dot(innovation, kalman_gain.T)
new_covariance = covariance - np.linalg.multi_dot(
(kalman_gain, projected_cov, kalman_gain.T))
return new_mean, new_covariance
return nb_update(mean, covariance, projected_mean, projected_cov,
measurement, self._update_mat)
kalman_gain = np.linalg.solve(projected_cov,
(covariance @self._update_mat.T).T).T
innovation = measurement - projected_mean
mean = mean + innovation @kalman_gain.T
covariance = covariance - kalman_gain @projected_cov @kalman_gain.T
return mean, covariance
def gating_distance(self,
mean,
......
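Both the numba path and the plain numpy path above replace the previous `scipy.linalg.cho_factor`/`cho_solve` route with `np.linalg.solve`; since the projected covariance is symmetric positive definite, both solve the same system for the Kalman gain. An illustrative check on random float32 data:
```python
# Illustrative check that solving S @ K.T = (P @ H.T).T gives the same
# Kalman gain as the scipy Cholesky route used before this change.
import numpy as np
import scipy.linalg

rng = np.random.default_rng(0)
P = rng.standard_normal((8, 8)).astype(np.float32)
P = P @ P.T + 8 * np.eye(8, dtype=np.float32)      # SPD state covariance
H = np.eye(4, 8, dtype=np.float32)                  # update (projection) matrix
R = np.diag(rng.uniform(0.1, 1.0, 4).astype(np.float32))
S = H @ P @ H.T + R                                 # projected covariance

K_solve = np.linalg.solve(S, (P @ H.T).T).T
chol, lower = scipy.linalg.cho_factor(S, lower=True, check_finite=False)
K_chol = scipy.linalg.cho_solve((chol, lower), (P @ H.T).T, check_finite=False).T

print(np.allclose(K_solve, K_chol, atol=1e-4))      # True (up to float32 error)
```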
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is based on https://github.com/danbochman/SORT/blob/danny_opencv/kalman_filter.py
"""
import numpy as np
from numpy import dot, zeros, eye
from numpy.linalg import inv
use_numba = True
try:
import numba as nb
@nb.njit(fastmath=True, cache=True)
def nb_predict(x, F, P, Q):
x = dot(F, x)
P = dot(dot(F, P), F.T) + Q
return x, P
@nb.njit(fastmath=True, cache=True)
def nb_update(x, z, H, P, R, _I):
y = z - np.dot(H, x)
PHT = dot(P, H.T)
S = dot(H, PHT) + R
K = dot(PHT, inv(S))
x = x + dot(K, y)
I_KH = _I - dot(K, H)
P = dot(dot(I_KH, P), I_KH.T) + dot(dot(K, R), K.T)
return x, P
except:
use_numba = False
print(
'Warning: Unable to use numba in PP-Tracking, please install numba, for example (python3.7): `pip install numba==0.56.4`'
)
pass
class OCSORTKalmanFilter:
def __init__(self, dim_x, dim_z):
self.dim_x = dim_x
self.dim_z = dim_z
self.x = zeros((dim_x, 1))
self.P = eye(dim_x)
self.Q = eye(dim_x)
self.F = eye(dim_x)
self.H = zeros((dim_z, dim_x))
self.R = eye(dim_z)
self.M = zeros((dim_z, dim_z))
self._I = eye(dim_x)
def predict(self):
if use_numba:
self.x, self.P = nb_predict(self.x, self.F, self.P, self.Q)
else:
self.x = dot(self.F, self.x)
self.P = dot(dot(self.F, self.P), self.F.T) + self.Q
def update(self, z):
if z is None:
return
if use_numba:
self.x, self.P = nb_update(self.x, z, self.H, self.P, self.R,
self._I)
else:
y = z - np.dot(self.H, self.x)
PHT = dot(self.P, self.H.T)
S = dot(self.H, PHT) + self.R
K = dot(PHT, inv(S))
self.x = self.x + dot(K, y)
I_KH = self._I - dot(K, self.H)
self.P = dot(dot(I_KH, self.P), I_KH.T) + dot(dot(K, self.R), K.T)
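A minimal sketch of `OCSORTKalmanFilter` used as a constant-velocity bbox filter, with `F` and `H` copied from `KalmanBoxTracker` further down this diff; the import path is an assumption based on the relative import in `ocsort_tracker.py`:
```python
# Hedged sketch: one update/predict cycle of the filterpy-free Kalman filter.
import numpy as np
from ppdet.modeling.mot.motion.ocsort_kalman_filter import OCSORTKalmanFilter

kf = OCSORTKalmanFilter(dim_x=7, dim_z=4)
kf.F = np.array([[1., 0, 0, 0, 1., 0, 0], [0, 1., 0, 0, 0, 1., 0],
                 [0, 0, 1., 0, 0, 0, 1.], [0, 0, 0, 1., 0, 0, 0],
                 [0, 0, 0, 0, 1., 0, 0], [0, 0, 0, 0, 0, 1., 0],
                 [0, 0, 0, 0, 0, 0, 1.]])
kf.H = np.array([[1., 0, 0, 0, 0, 0, 0], [0, 1., 0, 0, 0, 0, 0],
                 [0, 0, 1., 0, 0, 0, 0], [0, 0, 0, 1., 0, 0, 0]])

z = np.array([[320.], [240.], [5000.], [1.0]])  # (cx, cy, area, aspect) measurement
kf.update(z)
kf.predict()
print(kf.x[:4].ravel())  # filtered bbox state after one predict step
```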
......@@ -26,4 +26,5 @@ from . import center_tracker
from .jde_tracker import *
from .deepsort_tracker import *
from .ocsort_tracker import *
from .botsort_tracker import *
from .center_tracker import *
......@@ -154,6 +154,24 @@ class STrack(BaseTrack):
tracks[i].mean = mean
tracks[i].covariance = cov
@staticmethod
def multi_gmc(stracks, H=np.eye(2, 3)):
if len(stracks) > 0:
multi_mean = np.asarray([st.mean.copy() for st in stracks])
multi_covariance = np.asarray([st.covariance for st in stracks])
R = H[:2, :2]
R8x8 = np.kron(np.eye(4, dtype=float), R)
t = H[:2, 2]
for i, (mean, cov) in enumerate(zip(multi_mean, multi_covariance)):
mean = R8x8.dot(mean)
mean[:2] += t
cov = R8x8.dot(cov).dot(R8x8.transpose())
stracks[i].mean = mean
stracks[i].covariance = cov
def reset_track_id(self):
self.reset_track_count(self.cls_id)
......
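`multi_gmc` warps every track's 8-dim Kalman state with the 2x3 affine matrix produced by GMC: the 2x2 rotation/scale block is expanded to 8x8 with a Kronecker product and the translation is added to the position entries only. A small numpy illustration of the same math on one state:
```python
# Illustrative numpy version of the per-track warp done in multi_gmc.
import numpy as np

H = np.array([[1.0, 0.0, 5.0],
              [0.0, 1.0, -3.0]])          # e.g. a pure translation from GMC
mean = np.zeros(8)
mean[:2] = [100.0, 50.0]                  # (cx, cy, ...) track state
cov = np.eye(8)

R = H[:2, :2]
R8x8 = np.kron(np.eye(4), R)              # apply R block-wise to the 8-dim state
t = H[:2, 2]

warped_mean = R8x8.dot(mean)
warped_mean[:2] += t                      # position shifted by the translation t
warped_cov = R8x8.dot(cov).dot(R8x8.T)
print(warped_mean[:2])
```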
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is based on https://github.com/WWangYuHsiang/SMILEtrack/blob/main/BoT-SORT/tracker/bot_sort.py
"""
import cv2
import matplotlib.pyplot as plt
import numpy as np
from collections import deque
from ..matching import jde_matching as matching
from ..motion import GMC
from .base_jde_tracker import TrackState, STrack
from .base_jde_tracker import joint_stracks, sub_stracks, remove_duplicate_stracks
from ..motion import KalmanFilter
from ppdet.core.workspace import register, serializable
@register
@serializable
class BOTSORTTracker(object):
"""
BOTSORT tracker, supports single class only.
Args:
track_high_thresh (float): score threshold for the first (high score) association
track_low_thresh (float): detections below this score are discarded
new_track_thresh (float): minimum score for starting a new track
match_thresh (float): IoU threshold for the first association
track_buffer (int): number of frames a lost track is kept, default 30
min_box_area (float): minimum box area to keep, default 0
camera_motion (bool): whether to apply camera motion compensation (GMC), default False
cmc_method (str): camera motion compensation method, default 'sparseOptFlow'
frame_rate (int): video fps, used as buffer_size = int(frame_rate / 30.0 * track_buffer)
"""
def __init__(self,
track_high_thresh=0.3,
track_low_thresh=0.2,
new_track_thresh=0.4,
match_thresh=0.7,
track_buffer=30,
min_box_area=0,
camera_motion=False,
cmc_method='sparseOptFlow',
frame_rate=30):
self.tracked_stracks = [] # type: list[STrack]
self.lost_stracks = [] # type: list[STrack]
self.removed_stracks = [] # type: list[STrack]
self.frame_id = 0
self.track_high_thresh = track_high_thresh
self.track_low_thresh = track_low_thresh
self.new_track_thresh = new_track_thresh
self.match_thresh = match_thresh
self.buffer_size = int(frame_rate / 30.0 * track_buffer)
self.max_time_lost = self.buffer_size
self.kalman_filter = KalmanFilter()
self.min_box_area = min_box_area
self.camera_motion = camera_motion
self.gmc = GMC(method=cmc_method)
def update(self, output_results, img=None):
self.frame_id += 1
activated_starcks = []
refind_stracks = []
lost_stracks = []
removed_stracks = []
if len(output_results):
bboxes = output_results[:, 2:6]
scores = output_results[:, 1]
classes = output_results[:, 0]
# Remove bad detections
lowest_inds = scores > self.track_low_thresh
bboxes = bboxes[lowest_inds]
scores = scores[lowest_inds]
classes = classes[lowest_inds]
# Find high threshold detections
remain_inds = scores > self.track_high_thresh
dets = bboxes[remain_inds]
scores_keep = scores[remain_inds]
classes_keep = classes[remain_inds]
else:
bboxes = []
scores = []
classes = []
dets = []
scores_keep = []
classes_keep = []
if len(dets) > 0:
'''Detections'''
detections = [
STrack(STrack.tlbr_to_tlwh(tlbr), s, c)
for (tlbr, s, c) in zip(dets, scores_keep, classes_keep)
]
else:
detections = []
''' Add newly detected tracklets to tracked_stracks'''
unconfirmed = []
tracked_stracks = [] # type: list[STrack]
for track in self.tracked_stracks:
if not track.is_activated:
unconfirmed.append(track)
else:
tracked_stracks.append(track)
''' Step 2: First association, with high score detection boxes'''
strack_pool = joint_stracks(tracked_stracks, self.lost_stracks)
# Predict the current location with KF
STrack.multi_predict(strack_pool, self.kalman_filter)
# Fix camera motion
if self.camera_motion:
warp = self.gmc.apply(img[0], dets)
STrack.multi_gmc(strack_pool, warp)
STrack.multi_gmc(unconfirmed, warp)
# Associate with high score detection boxes
ious_dists = matching.iou_distance(strack_pool, detections)
matches, u_track, u_detection = matching.linear_assignment(
ious_dists, thresh=self.match_thresh)
for itracked, idet in matches:
track = strack_pool[itracked]
det = detections[idet]
if track.state == TrackState.Tracked:
track.update(detections[idet], self.frame_id)
activated_starcks.append(track)
else:
track.re_activate(det, self.frame_id, new_id=False)
refind_stracks.append(track)
''' Step 3: Second association, with low score detection boxes'''
if len(scores):
inds_high = scores < self.track_high_thresh
inds_low = scores > self.track_low_thresh
inds_second = np.logical_and(inds_low, inds_high)
dets_second = bboxes[inds_second]
scores_second = scores[inds_second]
classes_second = classes[inds_second]
else:
dets_second = []
scores_second = []
classes_second = []
# associate the remaining unmatched tracks with the low score detections
if len(dets_second) > 0:
'''Detections'''
detections_second = [
STrack(STrack.tlbr_to_tlwh(tlbr), s, c) for (tlbr, s, c) in
zip(dets_second, scores_second, classes_second)
]
else:
detections_second = []
r_tracked_stracks = [
strack_pool[i] for i in u_track
if strack_pool[i].state == TrackState.Tracked
]
dists = matching.iou_distance(r_tracked_stracks, detections_second)
matches, u_track, u_detection_second = matching.linear_assignment(
dists, thresh=0.5)
for itracked, idet in matches:
track = r_tracked_stracks[itracked]
det = detections_second[idet]
if track.state == TrackState.Tracked:
track.update(det, self.frame_id)
activated_starcks.append(track)
else:
track.re_activate(det, self.frame_id, new_id=False)
refind_stracks.append(track)
for it in u_track:
track = r_tracked_stracks[it]
if not track.state == TrackState.Lost:
track.mark_lost()
lost_stracks.append(track)
'''Deal with unconfirmed tracks, usually tracks with only one beginning frame'''
detections = [detections[i] for i in u_detection]
dists = matching.iou_distance(unconfirmed, detections)
matches, u_unconfirmed, u_detection = matching.linear_assignment(
dists, thresh=0.7)
for itracked, idet in matches:
unconfirmed[itracked].update(detections[idet], self.frame_id)
activated_starcks.append(unconfirmed[itracked])
for it in u_unconfirmed:
track = unconfirmed[it]
track.mark_removed()
removed_stracks.append(track)
""" Step 4: Init new stracks"""
for inew in u_detection:
track = detections[inew]
if track.score < self.new_track_thresh:
continue
track.activate(self.kalman_filter, self.frame_id)
activated_starcks.append(track)
""" Step 5: Update state"""
for track in self.lost_stracks:
if self.frame_id - track.end_frame > self.max_time_lost:
track.mark_removed()
removed_stracks.append(track)
""" Merge """
self.tracked_stracks = [
t for t in self.tracked_stracks if t.state == TrackState.Tracked
]
self.tracked_stracks = joint_stracks(self.tracked_stracks,
activated_starcks)
self.tracked_stracks = joint_stracks(self.tracked_stracks,
refind_stracks)
self.lost_stracks = sub_stracks(self.lost_stracks, self.tracked_stracks)
self.lost_stracks.extend(lost_stracks)
self.lost_stracks = sub_stracks(self.lost_stracks, self.removed_stracks)
self.removed_stracks.extend(removed_stracks)
self.tracked_stracks, self.lost_stracks = remove_duplicate_stracks(
self.tracked_stracks, self.lost_stracks)
# output_stracks = [track for track in self.tracked_stracks if track.is_activated]
output_stracks = [track for track in self.tracked_stracks]
return output_stracks
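A minimal sketch of driving `BOTSORTTracker` directly; the per-row layout `[cls_id, score, x1, y1, x2, y2]` is inferred from the slicing at the top of `update()`, and the import path follows the `tracker/__init__.py` change earlier in this diff:
```python
# Hedged sketch: feed one frame of detections to BOTSORTTracker.
import numpy as np
from ppdet.modeling.mot.tracker import BOTSORTTracker

tracker = BOTSORTTracker(camera_motion=False)
pred_dets = np.array([
    [0, 0.92, 100., 120., 180., 360.],   # cls_id, score, x1, y1, x2, y2
    [0, 0.35, 400., 150., 460., 330.],
])
# pass img (as in the Tracker code above) when camera_motion=True
online_targets = tracker.update(pred_dets)
for t in online_targets:
    print(t.track_id, t.score, t.tlwh)
```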
......@@ -16,15 +16,8 @@ This code is based on https://github.com/noahcao/OC_SORT/blob/master/trackers/oc
"""
import numpy as np
try:
from filterpy.kalman import KalmanFilter
except:
print(
'Warning: Unable to use OC-SORT, please install filterpy, for example: `pip install filterpy`, see https://github.com/rlabbe/filterpy'
)
pass
from ..matching.ocsort_matching import associate, linear_assignment, iou_batch
from ..matching.ocsort_matching import associate, linear_assignment, iou_batch, associate_only_iou
from ..motion.ocsort_kalman_filter import OCSORTKalmanFilter
from ppdet.core.workspace import register, serializable
......@@ -91,19 +84,14 @@ class KalmanBoxTracker(object):
count = 0
def __init__(self, bbox, delta_t=3):
try:
from filterpy.kalman import KalmanFilter
except Exception as e:
raise RuntimeError(
'Unable to use OC-SORT, please install filterpy, for example: `pip install filterpy`, see https://github.com/rlabbe/filterpy'
)
self.kf = KalmanFilter(dim_x=7, dim_z=4)
self.kf.F = np.array([[1, 0, 0, 0, 1, 0, 0], [0, 1, 0, 0, 0, 1, 0],
[0, 0, 1, 0, 0, 0, 1], [0, 0, 0, 1, 0, 0, 0],
[0, 0, 0, 0, 1, 0, 0], [0, 0, 0, 0, 0, 1, 0],
[0, 0, 0, 0, 0, 0, 1]])
self.kf.H = np.array([[1, 0, 0, 0, 0, 0, 0], [0, 1, 0, 0, 0, 0, 0],
[0, 0, 1, 0, 0, 0, 0], [0, 0, 0, 1, 0, 0, 0]])
self.kf = OCSORTKalmanFilter(dim_x=7, dim_z=4)
self.kf.F = np.array([[1., 0, 0, 0, 1., 0, 0], [0, 1., 0, 0, 0, 1., 0],
[0, 0, 1., 0, 0, 0, 1], [0, 0, 0, 1., 0, 0, 0],
[0, 0, 0, 0, 1., 0, 0], [0, 0, 0, 0, 0, 1., 0],
[0, 0, 0, 0, 0, 0, 1.]])
self.kf.H = np.array([[1., 0, 0, 0, 0, 0, 0], [0, 1., 0, 0, 0, 0, 0],
[0, 0, 1., 0, 0, 0, 0], [0, 0, 0, 1., 0, 0, 0]])
self.kf.R[2:, 2:] *= 10.
self.kf.P[4:, 4:] *= 1000.
# give high uncertainty to the unobservable initial velocities
......@@ -131,12 +119,13 @@ class KalmanBoxTracker(object):
self.velocity = None
self.delta_t = delta_t
def update(self, bbox):
def update(self, bbox, angle_cost=False):
"""
Updates the state vector with observed bbox.
"""
if bbox is not None:
if self.last_observation.sum() >= 0: # no previous observation
if angle_cost and self.last_observation.sum(
) >= 0: # no previous observation
previous_box = None
for i in range(self.delta_t):
dt = self.delta_t - i
......@@ -213,7 +202,8 @@ class OCSORTTracker(object):
inertia=0.2,
vertical_ratio=-1,
min_box_area=0,
use_byte=False):
use_byte=False,
use_angle_cost=False):
self.det_thresh = det_thresh
self.max_age = max_age
self.min_hits = min_hits
......@@ -223,6 +213,7 @@ class OCSORTTracker(object):
self.vertical_ratio = vertical_ratio
self.min_box_area = min_box_area
self.use_byte = use_byte
self.use_angle_cost = use_angle_cost
self.trackers = []
self.frame_count = 0
......@@ -270,23 +261,31 @@ class OCSORTTracker(object):
for t in reversed(to_del):
self.trackers.pop(t)
velocities = np.array([
trk.velocity if trk.velocity is not None else np.array((0, 0))
for trk in self.trackers
])
if self.use_angle_cost:
velocities = np.array([
trk.velocity if trk.velocity is not None else np.array((0, 0))
for trk in self.trackers
])
k_observations = np.array([
k_previous_obs(trk.observations, trk.age, self.delta_t)
for trk in self.trackers
])
last_boxes = np.array([trk.last_observation for trk in self.trackers])
k_observations = np.array([
k_previous_obs(trk.observations, trk.age, self.delta_t)
for trk in self.trackers
])
"""
First round of association
"""
matched, unmatched_dets, unmatched_trks = associate(
dets, trks, self.iou_threshold, velocities, k_observations,
self.inertia)
if self.use_angle_cost:
matched, unmatched_dets, unmatched_trks = associate(
dets, trks, self.iou_threshold, velocities, k_observations,
self.inertia)
else:
matched, unmatched_dets, unmatched_trks = associate_only_iou(
dets, trks, self.iou_threshold)
for m in matched:
self.trackers[m[1]].update(dets[m[0], :])
self.trackers[m[1]].update(
dets[m[0], :], angle_cost=self.use_angle_cost)
"""
Second round of associaton by OCR
"""
......@@ -310,7 +309,8 @@ class OCSORTTracker(object):
det_ind, trk_ind = m[0], unmatched_trks[m[1]]
if iou_left[m[0], m[1]] < self.iou_threshold:
continue
self.trackers[trk_ind].update(dets_second[det_ind, :])
self.trackers[trk_ind].update(
dets_second[det_ind, :], angle_cost=self.use_angle_cost)
to_remove_trk_indices.append(trk_ind)
unmatched_trks = np.setdiff1d(unmatched_trks,
np.array(to_remove_trk_indices))
......@@ -334,7 +334,8 @@ class OCSORTTracker(object):
1]]
if iou_left[m[0], m[1]] < self.iou_threshold:
continue
self.trackers[trk_ind].update(dets[det_ind, :])
self.trackers[trk_ind].update(
dets[det_ind, :], angle_cost=self.use_angle_cost)
to_remove_det_indices.append(det_ind)
to_remove_trk_indices.append(trk_ind)
unmatched_dets = np.setdiff1d(unmatched_dets,
......@@ -349,6 +350,7 @@ class OCSORTTracker(object):
for i in unmatched_dets:
trk = KalmanBoxTracker(dets[i, :], delta_t=self.delta_t)
self.trackers.append(trk)
i = len(self.trackers)
for trk in reversed(self.trackers):
if trk.last_observation.sum() < 0:
......
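The new `use_angle_cost` switch defaults to False, so the first association round calls `associate_only_iou` and skips the velocity/angle cost entirely; setting it to True restores the original OC-SORT behaviour. A hedged construction sketch, using only constructor arguments defined above:
```python
# Hedged sketch: toggling the observation-centric angle cost in OCSORTTracker.
from ppdet.modeling.mot.tracker import OCSORTTracker

# IoU-only association (new default, faster):
tracker_iou_only = OCSORTTracker(det_thresh=0.4, use_angle_cost=False)

# Original OC-SORT behaviour with the velocity/angle cost:
tracker_angle = OCSORTTracker(det_thresh=0.4, use_angle_cost=True)
```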