未验证 提交 a86bf8b9 编写于 作者: F Feng Ni 提交者: GitHub

[MOT] add python deploy for pptracking (#4507)

* add python deploy for pptracking

* fix plot tracking classes

* fix picodet deepsort

* fix format

* rename to det_infer
上级 7441fba7
# Python端预测部署
在PaddlePaddle中预测引擎和训练引擎底层有着不同的优化方法, 预测引擎使用了AnalysisPredictor,专门针对推理进行了优化,是基于[C++预测库](https://www.paddlepaddle.org.cn/documentation/docs/zh/advanced_guide/inference_deployment/inference/native_infer.html)的Python接口,该引擎可以对模型进行多项图优化,减少不必要的内存拷贝。如果用户在部署已训练模型的过程中对性能有较高的要求,我们提供了独立于PaddleDetection的预测脚本,方便用户直接集成部署。
主要包含两个步骤:
- 导出预测模型
- 基于Python进行预测
PaddleDetection在训练过程包括网络的前向和优化器相关参数,而在部署过程中,我们只需要前向参数,具体参考:[导出模型](https://github.com/PaddlePaddle/PaddleDetection/blob/develop/deploy/EXPORT_MODEL.md)
导出后目录下,包括`infer_cfg.yml`, `model.pdiparams`, `model.pdiparams.info`, `model.pdmodel`四个文件。
## 1. 对FairMOT模型的导出和预测
### 1.1 导出预测模型
```bash
CUDA_VISIBLE_DEVICES=0 python tools/export_model.py -c configs/mot/fairmot/fairmot_hrnetv2_w18_dlafpn_30e_576x320.yml -o weights=https://paddledet.bj.bcebos.com/models/mot/fairmot_hrnetv2_w18_dlafpn_30e_576x320.pdparams
```
### 1.2 用导出的模型基于Python去预测
```bash
python deploy/pptracking/python/mot_jde_infer.py --model_dir=output_inference/fairmot_hrnetv2_w18_dlafpn_30e_576x320 --video_file={your video name}.mp4 --device=GPU --save_mot_txts
```
**注意:**
- 跟踪模型是对视频进行预测,不支持单张图的预测,默认保存跟踪结果可视化后的视频,可添加`--save_mot_txts`表示保存跟踪结果的txt文件,或`--save_images`表示保存跟踪结果可视化图片。
- 对于多类别或车辆的FairMOT模型的导出和Python预测只需更改相应的config和模型权重即可。如:
```
job_name=mcfairmot_hrnetv2_w18_dlafpn_30e_576x320_visdrone
model_type=mot/mcfairmot
config=configs/${model_type}/${job_name}.yml
CUDA_VISIBLE_DEVICES=0 python tools/export_model.py -c ${config} -o weights=https://paddledet.bj.bcebos.com/models/mot/${job_name}.pdparams
python deploy/pptracking/python/mot_jde_infer.py --model_dir=output_inference/${job_name} --video_file={your video name}.mp4 --device=GPU --save_mot_txts
```
## 2. 对DeepSORT模型的导出和预测
### 2.1 导出预测模型
Step 1:导出检测模型
```bash
# 导出JDE YOLOv3行人检测模型
CUDA_VISIBLE_DEVICES=0 python tools/export_model.py -c configs/mot/deepsort/detector/jde_yolov3_darknet53_30e_1088x608_mix.yml -o weights=https://paddledet.bj.bcebos.com/models/mot/deepsort/jde_yolov3_darknet53_30e_1088x608_mix.pdparams
# 或导出PPYOLOv2行人检测模型
CUDA_VISIBLE_DEVICES=0 python tools/export_model.py -c configs/mot/deepsort/detector/ppyolov2_r50vd_dcn_365e_640x640_mot17half.yml -o weights=https://paddledet.bj.bcebos.com/mot/deepsort/ppyolov2_r50vd_dcn_365e_640x640_mot17half.pdparams
```
Step 2:导出ReID模型
```bash
# 导出PCB Pyramid ReID模型
CUDA_VISIBLE_DEVICES=0 python tools/export_model.py -c configs/mot/deepsort/reid/deepsort_pcb_pyramid_r101.yml -o reid_weights=https://paddledet.bj.bcebos.com/models/mot/deepsort/deepsort_pcb_pyramid_r101.pdparams
# 或者导出PPLCNet ReID模型
CUDA_VISIBLE_DEVICES=0 python tools/export_model.py -c configs/mot/deepsort/reid/deepsort_pplcnet.yml -o reid_weights=https://paddledet.bj.bcebos.com/models/mot/deepsort/deepsort_pplcnet.pdparams
```
### 2.2 用导出的模型基于Python去预测
```bash
# 用导出JDE YOLOv3行人检测模型和PCB Pyramid ReID模型
python deploy/pptracking/python/mot_sde_infer.py --model_dir=output_inference/jde_yolov3_darknet53_30e_1088x608_mix/ --reid_model_dir=output_inference/deepsort_pcb_pyramid_r101/ --video_file={your video name}.mp4 --device=GPU --save_mot_txts
# 或用导出的PPYOLOv2行人检测模型和PPLCNet ReID模型
python deploy/pptracking/python/mot_sde_infer.py --model_dir=output_inference/ppyolov2_r50vd_dcn_365e_640x640_mot17half/ --reid_model_dir=output_inference/deepsort_pplcnet/ --video_file={your video name}.mp4 --device=GPU --scaled=True --save_mot_txts
```
**注意:**
- 跟踪模型是对视频进行预测,不支持单张图的预测,默认保存跟踪结果可视化后的视频,可添加`--save_mot_txts`(对每个视频保存一个txt)或`--save_images`表示保存跟踪结果可视化图片。
- `--scaled`表示在模型输出结果的坐标是否已经是缩放回原图的,如果使用的检测模型是JDE的YOLOv3则为False,如果使用通用检测模型则为True。
## 参数说明:
| 参数 | 是否必须|含义 |
|-------|-------|----------|
| --model_dir | Yes| 上述导出的模型路径 |
| --image_file | Option | 需要预测的图片 |
| --image_dir | Option | 要预测的图片文件夹路径 |
| --video_file | Option | 需要预测的视频 |
| --camera_id | Option | 用来预测的摄像头ID,默认为-1(表示不使用摄像头预测,可设置为:0 - (摄像头数目-1) ),预测过程中在可视化界面按`q`退出输出预测结果到:output/output.mp4|
| --device | Option | 运行时的设备,可选择`CPU/GPU/XPU`,默认为`CPU`|
| --run_mode | Option |使用GPU时,默认为fluid, 可选(fluid/trt_fp32/trt_fp16/trt_int8)|
| --batch_size | Option |预测时的batch size,在指定`image_dir`时有效,默认为1 |
| --threshold | Option|预测得分的阈值,默认为0.5|
| --output_dir | Option|可视化结果保存的根目录,默认为output/|
| --run_benchmark | Option| 是否运行benchmark,同时需指定`--image_file``--image_dir`,默认为False |
| --enable_mkldnn | Option | CPU预测中是否开启MKLDNN加速,默认为False |
| --cpu_threads | Option| 设置cpu线程数,默认为1 |
| --trt_calib_mode | Option| TensorRT是否使用校准功能,默认为False。使用TensorRT的int8功能时,需设置为True,使用PaddleSlim量化后的模型时需要设置为False |
说明:
- 参数优先级顺序:`camera_id` > `video_file` > `image_dir` > `image_file`
- run_mode:fluid代表使用AnalysisPredictor,精度float32来推理,其他参数指用AnalysisPredictor,TensorRT不同精度来推理。
- 如果安装的PaddlePaddle不支持基于TensorRT进行预测,需要自行编译,详细可参考[预测库编译教程](https://paddleinference.paddlepaddle.org.cn/user_guides/source_compile.html)
- --run_benchmark如果设置为True,则需要安装依赖`pip install pynvml psutil GPUtil`
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import logging
import paddle
import paddle.inference as paddle_infer
from pathlib import Path
CUR_DIR = os.path.dirname(os.path.abspath(__file__))
LOG_PATH_ROOT = f"{CUR_DIR}/../../output"
class PaddleInferBenchmark(object):
def __init__(self,
config,
model_info: dict={},
data_info: dict={},
perf_info: dict={},
resource_info: dict={},
**kwargs):
"""
Construct PaddleInferBenchmark Class to format logs.
args:
config(paddle.inference.Config): paddle inference config
model_info(dict): basic model info
{'model_name': 'resnet50'
'precision': 'fp32'}
data_info(dict): input data info
{'batch_size': 1
'shape': '3,224,224'
'data_num': 1000}
perf_info(dict): performance result
{'preprocess_time_s': 1.0
'inference_time_s': 2.0
'postprocess_time_s': 1.0
'total_time_s': 4.0}
resource_info(dict):
cpu and gpu resources
{'cpu_rss': 100
'gpu_rss': 100
'gpu_util': 60}
"""
# PaddleInferBenchmark Log Version
self.log_version = "1.0.3"
# Paddle Version
self.paddle_version = paddle.__version__
self.paddle_commit = paddle.__git_commit__
paddle_infer_info = paddle_infer.get_version()
self.paddle_branch = paddle_infer_info.strip().split(': ')[-1]
# model info
self.model_info = model_info
# data info
self.data_info = data_info
# perf info
self.perf_info = perf_info
try:
# required value
self.model_name = model_info['model_name']
self.precision = model_info['precision']
self.batch_size = data_info['batch_size']
self.shape = data_info['shape']
self.data_num = data_info['data_num']
self.inference_time_s = round(perf_info['inference_time_s'], 4)
except:
self.print_help()
raise ValueError(
"Set argument wrong, please check input argument and its type")
self.preprocess_time_s = perf_info.get('preprocess_time_s', 0)
self.postprocess_time_s = perf_info.get('postprocess_time_s', 0)
self.total_time_s = perf_info.get('total_time_s', 0)
self.inference_time_s_90 = perf_info.get("inference_time_s_90", "")
self.inference_time_s_99 = perf_info.get("inference_time_s_99", "")
self.succ_rate = perf_info.get("succ_rate", "")
self.qps = perf_info.get("qps", "")
# conf info
self.config_status = self.parse_config(config)
# mem info
if isinstance(resource_info, dict):
self.cpu_rss_mb = int(resource_info.get('cpu_rss_mb', 0))
self.cpu_vms_mb = int(resource_info.get('cpu_vms_mb', 0))
self.cpu_shared_mb = int(resource_info.get('cpu_shared_mb', 0))
self.cpu_dirty_mb = int(resource_info.get('cpu_dirty_mb', 0))
self.cpu_util = round(resource_info.get('cpu_util', 0), 2)
self.gpu_rss_mb = int(resource_info.get('gpu_rss_mb', 0))
self.gpu_util = round(resource_info.get('gpu_util', 0), 2)
self.gpu_mem_util = round(resource_info.get('gpu_mem_util', 0), 2)
else:
self.cpu_rss_mb = 0
self.cpu_vms_mb = 0
self.cpu_shared_mb = 0
self.cpu_dirty_mb = 0
self.cpu_util = 0
self.gpu_rss_mb = 0
self.gpu_util = 0
self.gpu_mem_util = 0
# init benchmark logger
self.benchmark_logger()
def benchmark_logger(self):
"""
benchmark logger
"""
# remove other logging handler
for handler in logging.root.handlers[:]:
logging.root.removeHandler(handler)
# Init logger
FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
log_output = f"{LOG_PATH_ROOT}/{self.model_name}.log"
Path(f"{LOG_PATH_ROOT}").mkdir(parents=True, exist_ok=True)
logging.basicConfig(
level=logging.INFO,
format=FORMAT,
handlers=[
logging.FileHandler(
filename=log_output, mode='w'),
logging.StreamHandler(),
])
self.logger = logging.getLogger(__name__)
self.logger.info(
f"Paddle Inference benchmark log will be saved to {log_output}")
def parse_config(self, config) -> dict:
"""
parse paddle predictor config
args:
config(paddle.inference.Config): paddle inference config
return:
config_status(dict): dict style config info
"""
if isinstance(config, paddle_infer.Config):
config_status = {}
config_status['runtime_device'] = "gpu" if config.use_gpu(
) else "cpu"
config_status['ir_optim'] = config.ir_optim()
config_status['enable_tensorrt'] = config.tensorrt_engine_enabled()
config_status['precision'] = self.precision
config_status['enable_mkldnn'] = config.mkldnn_enabled()
config_status[
'cpu_math_library_num_threads'] = config.cpu_math_library_num_threads(
)
elif isinstance(config, dict):
config_status['runtime_device'] = config.get('runtime_device', "")
config_status['ir_optim'] = config.get('ir_optim', "")
config_status['enable_tensorrt'] = config.get('enable_tensorrt', "")
config_status['precision'] = config.get('precision', "")
config_status['enable_mkldnn'] = config.get('enable_mkldnn', "")
config_status['cpu_math_library_num_threads'] = config.get(
'cpu_math_library_num_threads', "")
else:
self.print_help()
raise ValueError(
"Set argument config wrong, please check input argument and its type"
)
return config_status
def report(self, identifier=None):
"""
print log report
args:
identifier(string): identify log
"""
if identifier:
identifier = f"[{identifier}]"
else:
identifier = ""
self.logger.info("\n")
self.logger.info(
"---------------------- Paddle info ----------------------")
self.logger.info(f"{identifier} paddle_version: {self.paddle_version}")
self.logger.info(f"{identifier} paddle_commit: {self.paddle_commit}")
self.logger.info(f"{identifier} paddle_branch: {self.paddle_branch}")
self.logger.info(f"{identifier} log_api_version: {self.log_version}")
self.logger.info(
"----------------------- Conf info -----------------------")
self.logger.info(
f"{identifier} runtime_device: {self.config_status['runtime_device']}"
)
self.logger.info(
f"{identifier} ir_optim: {self.config_status['ir_optim']}")
self.logger.info(f"{identifier} enable_memory_optim: {True}")
self.logger.info(
f"{identifier} enable_tensorrt: {self.config_status['enable_tensorrt']}"
)
self.logger.info(
f"{identifier} enable_mkldnn: {self.config_status['enable_mkldnn']}")
self.logger.info(
f"{identifier} cpu_math_library_num_threads: {self.config_status['cpu_math_library_num_threads']}"
)
self.logger.info(
"----------------------- Model info ----------------------")
self.logger.info(f"{identifier} model_name: {self.model_name}")
self.logger.info(f"{identifier} precision: {self.precision}")
self.logger.info(
"----------------------- Data info -----------------------")
self.logger.info(f"{identifier} batch_size: {self.batch_size}")
self.logger.info(f"{identifier} input_shape: {self.shape}")
self.logger.info(f"{identifier} data_num: {self.data_num}")
self.logger.info(
"----------------------- Perf info -----------------------")
self.logger.info(
f"{identifier} cpu_rss(MB): {self.cpu_rss_mb}, cpu_vms: {self.cpu_vms_mb}, cpu_shared_mb: {self.cpu_shared_mb}, cpu_dirty_mb: {self.cpu_dirty_mb}, cpu_util: {self.cpu_util}%"
)
self.logger.info(
f"{identifier} gpu_rss(MB): {self.gpu_rss_mb}, gpu_util: {self.gpu_util}%, gpu_mem_util: {self.gpu_mem_util}%"
)
self.logger.info(
f"{identifier} total time spent(s): {self.total_time_s}")
self.logger.info(
f"{identifier} preprocess_time(ms): {round(self.preprocess_time_s*1000, 1)}, inference_time(ms): {round(self.inference_time_s*1000, 1)}, postprocess_time(ms): {round(self.postprocess_time_s*1000, 1)}"
)
if self.inference_time_s_90:
self.looger.info(
f"{identifier} 90%_cost: {self.inference_time_s_90}, 99%_cost: {self.inference_time_s_99}, succ_rate: {self.succ_rate}"
)
if self.qps:
self.logger.info(f"{identifier} QPS: {self.qps}")
def print_help(self):
"""
print function help
"""
print("""Usage:
==== Print inference benchmark logs. ====
config = paddle.inference.Config()
model_info = {'model_name': 'resnet50'
'precision': 'fp32'}
data_info = {'batch_size': 1
'shape': '3,224,224'
'data_num': 1000}
perf_info = {'preprocess_time_s': 1.0
'inference_time_s': 2.0
'postprocess_time_s': 1.0
'total_time_s': 4.0}
resource_info = {'cpu_rss_mb': 100
'gpu_rss_mb': 100
'gpu_util': 60}
log = PaddleInferBenchmark(config, model_info, data_info, perf_info, resource_info)
log('Test')
""")
def __call__(self, identifier=None):
"""
__call__
args:
identifier(string): identify log
"""
self.report(identifier)
此差异已折叠。
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import matching
from . import tracker
from . import motion
from . import utils
from .matching import *
from .tracker import *
from .motion import *
from .utils import *
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import jde_matching
from . import deepsort_matching
from .jde_matching import *
from .deepsort_matching import *
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is based on https://github.com/nwojke/deep_sort/tree/master/deep_sort
"""
import numpy as np
from scipy.optimize import linear_sum_assignment
from ..motion import kalman_filter
INFTY_COST = 1e+5
__all__ = [
'iou_1toN',
'iou_cost',
'_nn_euclidean_distance',
'_nn_cosine_distance',
'NearestNeighborDistanceMetric',
'min_cost_matching',
'matching_cascade',
'gate_cost_matrix',
]
def iou_1toN(bbox, candidates):
"""
Computer intersection over union (IoU) by one box to N candidates.
Args:
bbox (ndarray): A bounding box in format `(top left x, top left y, width, height)`.
candidates (ndarray): A matrix of candidate bounding boxes (one per row) in the
same format as `bbox`.
Returns:
ious (ndarray): The intersection over union in [0, 1] between the `bbox`
and each candidate. A higher score means a larger fraction of the
`bbox` is occluded by the candidate.
"""
bbox_tl = bbox[:2]
bbox_br = bbox[:2] + bbox[2:]
candidates_tl = candidates[:, :2]
candidates_br = candidates[:, :2] + candidates[:, 2:]
tl = np.c_[np.maximum(bbox_tl[0], candidates_tl[:, 0])[:, np.newaxis],
np.maximum(bbox_tl[1], candidates_tl[:, 1])[:, np.newaxis]]
br = np.c_[np.minimum(bbox_br[0], candidates_br[:, 0])[:, np.newaxis],
np.minimum(bbox_br[1], candidates_br[:, 1])[:, np.newaxis]]
wh = np.maximum(0., br - tl)
area_intersection = wh.prod(axis=1)
area_bbox = bbox[2:].prod()
area_candidates = candidates[:, 2:].prod(axis=1)
ious = area_intersection / (area_bbox + area_candidates - area_intersection)
return ious
def iou_cost(tracks, detections, track_indices=None, detection_indices=None):
"""
IoU distance metric.
Args:
tracks (list[Track]): A list of tracks.
detections (list[Detection]): A list of detections.
track_indices (Optional[list[int]]): A list of indices to tracks that
should be matched. Defaults to all `tracks`.
detection_indices (Optional[list[int]]): A list of indices to detections
that should be matched. Defaults to all `detections`.
Returns:
cost_matrix (ndarray): A cost matrix of shape len(track_indices),
len(detection_indices) where entry (i, j) is
`1 - iou(tracks[track_indices[i]], detections[detection_indices[j]])`.
"""
if track_indices is None:
track_indices = np.arange(len(tracks))
if detection_indices is None:
detection_indices = np.arange(len(detections))
cost_matrix = np.zeros((len(track_indices), len(detection_indices)))
for row, track_idx in enumerate(track_indices):
if tracks[track_idx].time_since_update > 1:
cost_matrix[row, :] = 1e+5
continue
bbox = tracks[track_idx].to_tlwh()
candidates = np.asarray([detections[i].tlwh for i in detection_indices])
cost_matrix[row, :] = 1. - iou_1toN(bbox, candidates)
return cost_matrix
def _nn_euclidean_distance(s, q):
"""
Compute pair-wise squared (Euclidean) distance between points in `s` and `q`.
Args:
s (ndarray): Sample points: an NxM matrix of N samples of dimensionality M.
q (ndarray): Query points: an LxM matrix of L samples of dimensionality M.
Returns:
distances (ndarray): A vector of length M that contains for each entry in `q` the
smallest Euclidean distance to a sample in `s`.
"""
s, q = np.asarray(s), np.asarray(q)
if len(s) == 0 or len(q) == 0:
return np.zeros((len(s), len(q)))
s2, q2 = np.square(s).sum(axis=1), np.square(q).sum(axis=1)
distances = -2. * np.dot(s, q.T) + s2[:, None] + q2[None, :]
distances = np.clip(distances, 0., float(np.inf))
return np.maximum(0.0, distances.min(axis=0))
def _nn_cosine_distance(s, q):
"""
Compute pair-wise cosine distance between points in `s` and `q`.
Args:
s (ndarray): Sample points: an NxM matrix of N samples of dimensionality M.
q (ndarray): Query points: an LxM matrix of L samples of dimensionality M.
Returns:
distances (ndarray): A vector of length M that contains for each entry in `q` the
smallest Euclidean distance to a sample in `s`.
"""
s = np.asarray(s) / np.linalg.norm(s, axis=1, keepdims=True)
q = np.asarray(q) / np.linalg.norm(q, axis=1, keepdims=True)
distances = 1. - np.dot(s, q.T)
return distances.min(axis=0)
class NearestNeighborDistanceMetric(object):
"""
A nearest neighbor distance metric that, for each target, returns
the closest distance to any sample that has been observed so far.
Args:
metric (str): Either "euclidean" or "cosine".
matching_threshold (float): The matching threshold. Samples with larger
distance are considered an invalid match.
budget (Optional[int]): If not None, fix samples per class to at most
this number. Removes the oldest samples when the budget is reached.
Attributes:
samples (Dict[int -> List[ndarray]]): A dictionary that maps from target
identities to the list of samples that have been observed so far.
"""
def __init__(self, metric, matching_threshold, budget=None):
if metric == "euclidean":
self._metric = _nn_euclidean_distance
elif metric == "cosine":
self._metric = _nn_cosine_distance
else:
raise ValueError(
"Invalid metric; must be either 'euclidean' or 'cosine'")
self.matching_threshold = matching_threshold
self.budget = budget
self.samples = {}
def partial_fit(self, features, targets, active_targets):
"""
Update the distance metric with new data.
Args:
features (ndarray): An NxM matrix of N features of dimensionality M.
targets (ndarray): An integer array of associated target identities.
active_targets (List[int]): A list of targets that are currently
present in the scene.
"""
for feature, target in zip(features, targets):
self.samples.setdefault(target, []).append(feature)
if self.budget is not None:
self.samples[target] = self.samples[target][-self.budget:]
self.samples = {k: self.samples[k] for k in active_targets}
def distance(self, features, targets):
"""
Compute distance between features and targets.
Args:
features (ndarray): An NxM matrix of N features of dimensionality M.
targets (list[int]): A list of targets to match the given `features` against.
Returns:
cost_matrix (ndarray): a cost matrix of shape len(targets), len(features),
where element (i, j) contains the closest squared distance between
`targets[i]` and `features[j]`.
"""
cost_matrix = np.zeros((len(targets), len(features)))
for i, target in enumerate(targets):
cost_matrix[i, :] = self._metric(self.samples[target], features)
return cost_matrix
def min_cost_matching(distance_metric,
max_distance,
tracks,
detections,
track_indices=None,
detection_indices=None):
"""
Solve linear assignment problem.
Args:
distance_metric :
Callable[List[Track], List[Detection], List[int], List[int]) -> ndarray
The distance metric is given a list of tracks and detections as
well as a list of N track indices and M detection indices. The
metric should return the NxM dimensional cost matrix, where element
(i, j) is the association cost between the i-th track in the given
track indices and the j-th detection in the given detection_indices.
max_distance (float): Gating threshold. Associations with cost larger
than this value are disregarded.
tracks (list[Track]): A list of predicted tracks at the current time
step.
detections (list[Detection]): A list of detections at the current time
step.
track_indices (list[int]): List of track indices that maps rows in
`cost_matrix` to tracks in `tracks`.
detection_indices (List[int]): List of detection indices that maps
columns in `cost_matrix` to detections in `detections`.
Returns:
A tuple (List[(int, int)], List[int], List[int]) with the following
three entries:
* A list of matched track and detection indices.
* A list of unmatched track indices.
* A list of unmatched detection indices.
"""
if track_indices is None:
track_indices = np.arange(len(tracks))
if detection_indices is None:
detection_indices = np.arange(len(detections))
if len(detection_indices) == 0 or len(track_indices) == 0:
return [], track_indices, detection_indices # Nothing to match.
cost_matrix = distance_metric(tracks, detections, track_indices,
detection_indices)
cost_matrix[cost_matrix > max_distance] = max_distance + 1e-5
indices = linear_sum_assignment(cost_matrix)
matches, unmatched_tracks, unmatched_detections = [], [], []
for col, detection_idx in enumerate(detection_indices):
if col not in indices[1]:
unmatched_detections.append(detection_idx)
for row, track_idx in enumerate(track_indices):
if row not in indices[0]:
unmatched_tracks.append(track_idx)
for row, col in zip(indices[0], indices[1]):
track_idx = track_indices[row]
detection_idx = detection_indices[col]
if cost_matrix[row, col] > max_distance:
unmatched_tracks.append(track_idx)
unmatched_detections.append(detection_idx)
else:
matches.append((track_idx, detection_idx))
return matches, unmatched_tracks, unmatched_detections
def matching_cascade(distance_metric,
max_distance,
cascade_depth,
tracks,
detections,
track_indices=None,
detection_indices=None):
"""
Run matching cascade.
Args:
distance_metric :
Callable[List[Track], List[Detection], List[int], List[int]) -> ndarray
The distance metric is given a list of tracks and detections as
well as a list of N track indices and M detection indices. The
metric should return the NxM dimensional cost matrix, where element
(i, j) is the association cost between the i-th track in the given
track indices and the j-th detection in the given detection_indices.
max_distance (float): Gating threshold. Associations with cost larger
than this value are disregarded.
cascade_depth (int): The cascade depth, should be se to the maximum
track age.
tracks (list[Track]): A list of predicted tracks at the current time
step.
detections (list[Detection]): A list of detections at the current time
step.
track_indices (list[int]): List of track indices that maps rows in
`cost_matrix` to tracks in `tracks`.
detection_indices (List[int]): List of detection indices that maps
columns in `cost_matrix` to detections in `detections`.
Returns:
A tuple (List[(int, int)], List[int], List[int]) with the following
three entries:
* A list of matched track and detection indices.
* A list of unmatched track indices.
* A list of unmatched detection indices.
"""
if track_indices is None:
track_indices = list(range(len(tracks)))
if detection_indices is None:
detection_indices = list(range(len(detections)))
unmatched_detections = detection_indices
matches = []
for level in range(cascade_depth):
if len(unmatched_detections) == 0: # No detections left
break
track_indices_l = [
k for k in track_indices if tracks[k].time_since_update == 1 + level
]
if len(track_indices_l) == 0: # Nothing to match at this level
continue
matches_l, _, unmatched_detections = \
min_cost_matching(
distance_metric, max_distance, tracks, detections,
track_indices_l, unmatched_detections)
matches += matches_l
unmatched_tracks = list(set(track_indices) - set(k for k, _ in matches))
return matches, unmatched_tracks, unmatched_detections
def gate_cost_matrix(kf,
cost_matrix,
tracks,
detections,
track_indices,
detection_indices,
gated_cost=INFTY_COST,
only_position=False):
"""
Invalidate infeasible entries in cost matrix based on the state
distributions obtained by Kalman filtering.
Args:
kf (object): The Kalman filter.
cost_matrix (ndarray): The NxM dimensional cost matrix, where N is the
number of track indices and M is the number of detection indices,
such that entry (i, j) is the association cost between
`tracks[track_indices[i]]` and `detections[detection_indices[j]]`.
tracks (list[Track]): A list of predicted tracks at the current time
step.
detections (list[Detection]): A list of detections at the current time
step.
track_indices (List[int]): List of track indices that maps rows in
`cost_matrix` to tracks in `tracks`.
detection_indices (List[int]): List of detection indices that maps
columns in `cost_matrix` to detections in `detections`.
gated_cost (Optional[float]): Entries in the cost matrix corresponding
to infeasible associations are set this value. Defaults to a very
large value.
only_position (Optional[bool]): If True, only the x, y position of the
state distribution is considered during gating. Default False.
"""
gating_dim = 2 if only_position else 4
gating_threshold = kalman_filter.chi2inv95[gating_dim]
measurements = np.asarray(
[detections[i].to_xyah() for i in detection_indices])
for row, track_idx in enumerate(track_indices):
track = tracks[track_idx]
gating_distance = kf.gating_distance(track.mean, track.covariance,
measurements, only_position)
cost_matrix[row, gating_distance > gating_threshold] = gated_cost
return cost_matrix
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is based on https://github.com/Zhongdao/Towards-Realtime-MOT/blob/master/tracker/matching.py
"""
import lap
import scipy
import numpy as np
from scipy.spatial.distance import cdist
from ..motion import kalman_filter
__all__ = [
'merge_matches',
'linear_assignment',
'cython_bbox_ious',
'iou_distance',
'embedding_distance',
'fuse_motion',
]
def merge_matches(m1, m2, shape):
O, P, Q = shape
m1 = np.asarray(m1)
m2 = np.asarray(m2)
M1 = scipy.sparse.coo_matrix(
(np.ones(len(m1)), (m1[:, 0], m1[:, 1])), shape=(O, P))
M2 = scipy.sparse.coo_matrix(
(np.ones(len(m2)), (m2[:, 0], m2[:, 1])), shape=(P, Q))
mask = M1 * M2
match = mask.nonzero()
match = list(zip(match[0], match[1]))
unmatched_O = tuple(set(range(O)) - set([i for i, j in match]))
unmatched_Q = tuple(set(range(Q)) - set([j for i, j in match]))
return match, unmatched_O, unmatched_Q
def linear_assignment(cost_matrix, thresh):
if cost_matrix.size == 0:
return np.empty(
(0, 2), dtype=int), tuple(range(cost_matrix.shape[0])), tuple(
range(cost_matrix.shape[1]))
matches, unmatched_a, unmatched_b = [], [], []
cost, x, y = lap.lapjv(cost_matrix, extend_cost=True, cost_limit=thresh)
for ix, mx in enumerate(x):
if mx >= 0:
matches.append([ix, mx])
unmatched_a = np.where(x < 0)[0]
unmatched_b = np.where(y < 0)[0]
matches = np.asarray(matches)
return matches, unmatched_a, unmatched_b
def cython_bbox_ious(atlbrs, btlbrs):
ious = np.zeros((len(atlbrs), len(btlbrs)), dtype=np.float)
if ious.size == 0:
return ious
try:
import cython_bbox
except Exception as e:
print('cython_bbox not found, please install cython_bbox.'
'for example: `pip install cython_bbox`.')
exit()
ious = cython_bbox.bbox_overlaps(
np.ascontiguousarray(
atlbrs, dtype=np.float),
np.ascontiguousarray(
btlbrs, dtype=np.float))
return ious
def iou_distance(atracks, btracks):
"""
Compute cost based on IoU between two list[STrack].
"""
if (len(atracks) > 0 and isinstance(atracks[0], np.ndarray)) or (
len(btracks) > 0 and isinstance(btracks[0], np.ndarray)):
atlbrs = atracks
btlbrs = btracks
else:
atlbrs = [track.tlbr for track in atracks]
btlbrs = [track.tlbr for track in btracks]
_ious = cython_bbox_ious(atlbrs, btlbrs)
cost_matrix = 1 - _ious
return cost_matrix
def embedding_distance(tracks, detections, metric='euclidean'):
"""
Compute cost based on features between two list[STrack].
"""
cost_matrix = np.zeros((len(tracks), len(detections)), dtype=np.float)
if cost_matrix.size == 0:
return cost_matrix
det_features = np.asarray(
[track.curr_feat for track in detections], dtype=np.float)
track_features = np.asarray(
[track.smooth_feat for track in tracks], dtype=np.float)
cost_matrix = np.maximum(0.0, cdist(track_features, det_features,
metric)) # Nomalized features
return cost_matrix
def fuse_motion(kf,
cost_matrix,
tracks,
detections,
only_position=False,
lambda_=0.98):
if cost_matrix.size == 0:
return cost_matrix
gating_dim = 2 if only_position else 4
gating_threshold = kalman_filter.chi2inv95[gating_dim]
measurements = np.asarray([det.to_xyah() for det in detections])
for row, track in enumerate(tracks):
gating_distance = kf.gating_distance(
track.mean,
track.covariance,
measurements,
only_position,
metric='maha')
cost_matrix[row, gating_distance > gating_threshold] = np.inf
cost_matrix[row] = lambda_ * cost_matrix[row] + (1 - lambda_
) * gating_distance
return cost_matrix
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import kalman_filter
from .kalman_filter import *
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is based on https://github.com/nwojke/deep_sort/blob/master/deep_sort/kalman_filter.py
"""
import numpy as np
import scipy.linalg
__all__ = ['KalmanFilter']
"""
Table for the 0.95 quantile of the chi-square distribution with N degrees of
freedom (contains values for N=1, ..., 9). Taken from MATLAB/Octave's chi2inv
function and used as Mahalanobis gating threshold.
"""
chi2inv95 = {
1: 3.8415,
2: 5.9915,
3: 7.8147,
4: 9.4877,
5: 11.070,
6: 12.592,
7: 14.067,
8: 15.507,
9: 16.919
}
class KalmanFilter(object):
"""
A simple Kalman filter for tracking bounding boxes in image space.
The 8-dimensional state space
x, y, a, h, vx, vy, va, vh
contains the bounding box center position (x, y), aspect ratio a, height h,
and their respective velocities.
Object motion follows a constant velocity model. The bounding box location
(x, y, a, h) is taken as direct observation of the state space (linear
observation model).
"""
def __init__(self):
ndim, dt = 4, 1.
# Create Kalman filter model matrices.
self._motion_mat = np.eye(2 * ndim, 2 * ndim)
for i in range(ndim):
self._motion_mat[i, ndim + i] = dt
self._update_mat = np.eye(ndim, 2 * ndim)
# Motion and observation uncertainty are chosen relative to the current
# state estimate. These weights control the amount of uncertainty in
# the model. This is a bit hacky.
self._std_weight_position = 1. / 20
self._std_weight_velocity = 1. / 160
def initiate(self, measurement):
"""
Create track from unassociated measurement.
Args:
measurement (ndarray): Bounding box coordinates (x, y, a, h) with
center position (x, y), aspect ratio a, and height h.
Returns:
The mean vector (8 dimensional) and covariance matrix (8x8
dimensional) of the new track. Unobserved velocities are
initialized to 0 mean.
"""
mean_pos = measurement
mean_vel = np.zeros_like(mean_pos)
mean = np.r_[mean_pos, mean_vel]
std = [
2 * self._std_weight_position * measurement[3],
2 * self._std_weight_position * measurement[3], 1e-2,
2 * self._std_weight_position * measurement[3],
10 * self._std_weight_velocity * measurement[3],
10 * self._std_weight_velocity * measurement[3], 1e-5,
10 * self._std_weight_velocity * measurement[3]
]
covariance = np.diag(np.square(std))
return mean, covariance
def predict(self, mean, covariance):
"""
Run Kalman filter prediction step.
Args:
mean (ndarray): The 8 dimensional mean vector of the object state
at the previous time step.
covariance (ndarray): The 8x8 dimensional covariance matrix of the
object state at the previous time step.
Returns:
The mean vector and covariance matrix of the predicted state.
Unobserved velocities are initialized to 0 mean.
"""
std_pos = [
self._std_weight_position * mean[3], self._std_weight_position *
mean[3], 1e-2, self._std_weight_position * mean[3]
]
std_vel = [
self._std_weight_velocity * mean[3], self._std_weight_velocity *
mean[3], 1e-5, self._std_weight_velocity * mean[3]
]
motion_cov = np.diag(np.square(np.r_[std_pos, std_vel]))
#mean = np.dot(self._motion_mat, mean)
mean = np.dot(mean, self._motion_mat.T)
covariance = np.linalg.multi_dot(
(self._motion_mat, covariance, self._motion_mat.T)) + motion_cov
return mean, covariance
def project(self, mean, covariance):
"""
Project state distribution to measurement space.
Args
mean (ndarray): The state's mean vector (8 dimensional array).
covariance (ndarray): The state's covariance matrix (8x8 dimensional).
Returns:
The projected mean and covariance matrix of the given state estimate.
"""
std = [
self._std_weight_position * mean[3], self._std_weight_position *
mean[3], 1e-1, self._std_weight_position * mean[3]
]
innovation_cov = np.diag(np.square(std))
mean = np.dot(self._update_mat, mean)
covariance = np.linalg.multi_dot((self._update_mat, covariance,
self._update_mat.T))
return mean, covariance + innovation_cov
def multi_predict(self, mean, covariance):
"""
Run Kalman filter prediction step (Vectorized version).
Args:
mean (ndarray): The Nx8 dimensional mean matrix of the object states
at the previous time step.
covariance (ndarray): The Nx8x8 dimensional covariance matrics of the
object states at the previous time step.
Returns:
The mean vector and covariance matrix of the predicted state.
Unobserved velocities are initialized to 0 mean.
"""
std_pos = [
self._std_weight_position * mean[:, 3], self._std_weight_position *
mean[:, 3], 1e-2 * np.ones_like(mean[:, 3]),
self._std_weight_position * mean[:, 3]
]
std_vel = [
self._std_weight_velocity * mean[:, 3], self._std_weight_velocity *
mean[:, 3], 1e-5 * np.ones_like(mean[:, 3]),
self._std_weight_velocity * mean[:, 3]
]
sqr = np.square(np.r_[std_pos, std_vel]).T
motion_cov = []
for i in range(len(mean)):
motion_cov.append(np.diag(sqr[i]))
motion_cov = np.asarray(motion_cov)
mean = np.dot(mean, self._motion_mat.T)
left = np.dot(self._motion_mat, covariance).transpose((1, 0, 2))
covariance = np.dot(left, self._motion_mat.T) + motion_cov
return mean, covariance
def update(self, mean, covariance, measurement):
"""
Run Kalman filter correction step.
Args:
mean (ndarray): The predicted state's mean vector (8 dimensional).
covariance (ndarray): The state's covariance matrix (8x8 dimensional).
measurement (ndarray): The 4 dimensional measurement vector
(x, y, a, h), where (x, y) is the center position, a the aspect
ratio, and h the height of the bounding box.
Returns:
The measurement-corrected state distribution.
"""
projected_mean, projected_cov = self.project(mean, covariance)
chol_factor, lower = scipy.linalg.cho_factor(
projected_cov, lower=True, check_finite=False)
kalman_gain = scipy.linalg.cho_solve(
(chol_factor, lower),
np.dot(covariance, self._update_mat.T).T,
check_finite=False).T
innovation = measurement - projected_mean
new_mean = mean + np.dot(innovation, kalman_gain.T)
new_covariance = covariance - np.linalg.multi_dot(
(kalman_gain, projected_cov, kalman_gain.T))
return new_mean, new_covariance
def gating_distance(self,
mean,
covariance,
measurements,
only_position=False,
metric='maha'):
"""
Compute gating distance between state distribution and measurements.
A suitable distance threshold can be obtained from `chi2inv95`. If
`only_position` is False, the chi-square distribution has 4 degrees of
freedom, otherwise 2.
Args:
mean (ndarray): Mean vector over the state distribution (8
dimensional).
covariance (ndarray): Covariance of the state distribution (8x8
dimensional).
measurements (ndarray): An Nx4 dimensional matrix of N measurements,
each in format (x, y, a, h) where (x, y) is the bounding box center
position, a the aspect ratio, and h the height.
only_position (Optional[bool]): If True, distance computation is
done with respect to the bounding box center position only.
metric (str): Metric type, 'gaussian' or 'maha'.
Returns
An array of length N, where the i-th element contains the squared
Mahalanobis distance between (mean, covariance) and `measurements[i]`.
"""
mean, covariance = self.project(mean, covariance)
if only_position:
mean, covariance = mean[:2], covariance[:2, :2]
measurements = measurements[:, :2]
d = measurements - mean
if metric == 'gaussian':
return np.sum(d * d, axis=1)
elif metric == 'maha':
cholesky_factor = np.linalg.cholesky(covariance)
z = scipy.linalg.solve_triangular(
cholesky_factor,
d.T,
lower=True,
check_finite=False,
overwrite_b=True)
squared_maha = np.sum(z * z, axis=0)
return squared_maha
else:
raise ValueError('invalid distance metric')
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import base_jde_tracker
from . import base_sde_tracker
from . import jde_tracker
from . import deepsort_tracker
from .base_jde_tracker import *
from .base_sde_tracker import *
from .jde_tracker import *
from .deepsort_tracker import *
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is based on https://github.com/Zhongdao/Towards-Realtime-MOT/blob/master/tracker/multitracker.py
"""
import numpy as np
from collections import defaultdict
from collections import deque, OrderedDict
from ..matching import jde_matching as matching
__all__ = [
'TrackState',
'BaseTrack',
'STrack',
'joint_stracks',
'sub_stracks',
'remove_duplicate_stracks',
]
class TrackState(object):
New = 0
Tracked = 1
Lost = 2
Removed = 3
class BaseTrack(object):
_count_dict = defaultdict(int) # support single class and multi classes
track_id = 0
is_activated = False
state = TrackState.New
history = OrderedDict()
features = []
curr_feature = None
score = 0
start_frame = 0
frame_id = 0
time_since_update = 0
# multi-camera
location = (np.inf, np.inf)
@property
def end_frame(self):
return self.frame_id
@staticmethod
def next_id(cls_id):
BaseTrack._count_dict[cls_id] += 1
return BaseTrack._count_dict[cls_id]
# @even: reset track id
@staticmethod
def init_count(num_classes):
"""
Initiate _count for all object classes
:param num_classes:
"""
for cls_id in range(num_classes):
BaseTrack._count_dict[cls_id] = 0
@staticmethod
def reset_track_count(cls_id):
BaseTrack._count_dict[cls_id] = 0
def activate(self, *args):
raise NotImplementedError
def predict(self):
raise NotImplementedError
def update(self, *args, **kwargs):
raise NotImplementedError
def mark_lost(self):
self.state = TrackState.Lost
def mark_removed(self):
self.state = TrackState.Removed
class STrack(BaseTrack):
def __init__(self,
tlwh,
score,
temp_feat,
num_classes,
cls_id,
buff_size=30):
# object class id
self.cls_id = cls_id
# wait activate
self._tlwh = np.asarray(tlwh, dtype=np.float)
self.kalman_filter = None
self.mean, self.covariance = None, None
self.is_activated = False
self.score = score
self.track_len = 0
self.smooth_feat = None
self.update_features(temp_feat)
self.features = deque([], maxlen=buff_size)
self.alpha = 0.9
def update_features(self, feat):
# L2 normalizing
feat /= np.linalg.norm(feat)
self.curr_feat = feat
if self.smooth_feat is None:
self.smooth_feat = feat
else:
self.smooth_feat = self.alpha * self.smooth_feat + (1.0 - self.alpha
) * feat
self.features.append(feat)
self.smooth_feat /= np.linalg.norm(self.smooth_feat)
def predict(self):
mean_state = self.mean.copy()
if self.state != TrackState.Tracked:
mean_state[7] = 0
self.mean, self.covariance = self.kalman_filter.predict(mean_state,
self.covariance)
@staticmethod
def multi_predict(tracks, kalman_filter):
if len(tracks) > 0:
multi_mean = np.asarray([track.mean.copy() for track in tracks])
multi_covariance = np.asarray(
[track.covariance for track in tracks])
for i, st in enumerate(tracks):
if st.state != TrackState.Tracked:
multi_mean[i][7] = 0
multi_mean, multi_covariance = kalman_filter.multi_predict(
multi_mean, multi_covariance)
for i, (mean, cov) in enumerate(zip(multi_mean, multi_covariance)):
tracks[i].mean = mean
tracks[i].covariance = cov
def reset_track_id(self):
self.reset_track_count(self.cls_id)
def activate(self, kalman_filter, frame_id):
"""Start a new track"""
self.kalman_filter = kalman_filter
# update track id for the object class
self.track_id = self.next_id(self.cls_id)
self.mean, self.covariance = self.kalman_filter.initiate(
self.tlwh_to_xyah(self._tlwh))
self.track_len = 0
self.state = TrackState.Tracked # set flag 'tracked'
if frame_id == 1: # to record the first frame's detection result
self.is_activated = True
self.frame_id = frame_id
self.start_frame = frame_id
def re_activate(self, new_track, frame_id, new_id=False):
self.mean, self.covariance = self.kalman_filter.update(
self.mean, self.covariance, self.tlwh_to_xyah(new_track.tlwh))
self.update_features(new_track.curr_feat)
self.track_len = 0
self.state = TrackState.Tracked
self.is_activated = True
self.frame_id = frame_id
if new_id: # update track id for the object class
self.track_id = self.next_id(self.cls_id)
def update(self, new_track, frame_id, update_feature=True):
self.frame_id = frame_id
self.track_len += 1
new_tlwh = new_track.tlwh
self.mean, self.covariance = self.kalman_filter.update(
self.mean, self.covariance, self.tlwh_to_xyah(new_tlwh))
self.state = TrackState.Tracked # set flag 'tracked'
self.is_activated = True # set flag 'activated'
self.score = new_track.score
if update_feature:
self.update_features(new_track.curr_feat)
@property
def tlwh(self):
"""Get current position in bounding box format `(top left x, top left y,
width, height)`.
"""
if self.mean is None:
return self._tlwh.copy()
ret = self.mean[:4].copy()
ret[2] *= ret[3]
ret[:2] -= ret[2:] / 2
return ret
@property
def tlbr(self):
"""Convert bounding box to format `(min x, min y, max x, max y)`, i.e.,
`(top left, bottom right)`.
"""
ret = self.tlwh.copy()
ret[2:] += ret[:2]
return ret
@staticmethod
def tlwh_to_xyah(tlwh):
"""Convert bounding box to format `(center x, center y, aspect ratio,
height)`, where the aspect ratio is `width / height`.
"""
ret = np.asarray(tlwh).copy()
ret[:2] += ret[2:] / 2
ret[2] /= ret[3]
return ret
def to_xyah(self):
return self.tlwh_to_xyah(self.tlwh)
@staticmethod
def tlbr_to_tlwh(tlbr):
ret = np.asarray(tlbr).copy()
ret[2:] -= ret[:2]
return ret
@staticmethod
def tlwh_to_tlbr(tlwh):
ret = np.asarray(tlwh).copy()
ret[2:] += ret[:2]
return ret
def __repr__(self):
return 'OT_({}-{})_({}-{})'.format(self.cls_id, self.track_id,
self.start_frame, self.end_frame)
def joint_stracks(tlista, tlistb):
exists = {}
res = []
for t in tlista:
exists[t.track_id] = 1
res.append(t)
for t in tlistb:
tid = t.track_id
if not exists.get(tid, 0):
exists[tid] = 1
res.append(t)
return res
def sub_stracks(tlista, tlistb):
stracks = {}
for t in tlista:
stracks[t.track_id] = t
for t in tlistb:
tid = t.track_id
if stracks.get(tid, 0):
del stracks[tid]
return list(stracks.values())
def remove_duplicate_stracks(stracksa, stracksb):
pdist = matching.iou_distance(stracksa, stracksb)
pairs = np.where(pdist < 0.15)
dupa, dupb = list(), list()
for p, q in zip(*pairs):
timep = stracksa[p].frame_id - stracksa[p].start_frame
timeq = stracksb[q].frame_id - stracksb[q].start_frame
if timep > timeq:
dupb.append(q)
else:
dupa.append(p)
resa = [t for i, t in enumerate(stracksa) if not i in dupa]
resb = [t for i, t in enumerate(stracksb) if not i in dupb]
return resa, resb
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is based on https://github.com/nwojke/deep_sort/blob/master/deep_sort/track.py
"""
import datetime
__all__ = ['TrackState', 'Track']
class TrackState(object):
"""
Enumeration type for the single target track state. Newly created tracks are
classified as `tentative` until enough evidence has been collected. Then,
the track state is changed to `confirmed`. Tracks that are no longer alive
are classified as `deleted` to mark them for removal from the set of active
tracks.
"""
Tentative = 1
Confirmed = 2
Deleted = 3
class Track(object):
"""
A single target track with state space `(x, y, a, h)` and associated
velocities, where `(x, y)` is the center of the bounding box, `a` is the
aspect ratio and `h` is the height.
Args:
mean (ndarray): Mean vector of the initial state distribution.
covariance (ndarray): Covariance matrix of the initial state distribution.
track_id (int): A unique track identifier.
n_init (int): Number of consecutive detections before the track is confirmed.
The track state is set to `Deleted` if a miss occurs within the first
`n_init` frames.
max_age (int): The maximum number of consecutive misses before the track
state is set to `Deleted`.
cls_id (int): The category id of the tracked box.
score (float): The confidence score of the tracked box.
feature (Optional[ndarray]): Feature vector of the detection this track
originates from. If not None, this feature is added to the `features` cache.
Attributes:
hits (int): Total number of measurement updates.
age (int): Total number of frames since first occurance.
time_since_update (int): Total number of frames since last measurement
update.
state (TrackState): The current track state.
features (List[ndarray]): A cache of features. On each measurement update,
the associated feature vector is added to this list.
"""
def __init__(self,
mean,
covariance,
track_id,
n_init,
max_age,
cls_id,
score,
feature=None):
self.mean = mean
self.covariance = covariance
self.track_id = track_id
self.hits = 1
self.age = 1
self.time_since_update = 0
self.cls_id = cls_id
self.score = score
self.start_time = datetime.datetime.now()
self.state = TrackState.Tentative
self.features = []
if feature is not None:
self.features.append(feature)
self._n_init = n_init
self._max_age = max_age
def to_tlwh(self):
"""Get position in format `(top left x, top left y, width, height)`."""
ret = self.mean[:4].copy()
ret[2] *= ret[3]
ret[:2] -= ret[2:] / 2
return ret
def to_tlbr(self):
"""Get position in bounding box format `(min x, miny, max x, max y)`."""
ret = self.to_tlwh()
ret[2:] = ret[:2] + ret[2:]
return ret
def predict(self, kalman_filter):
"""
Propagate the state distribution to the current time step using a Kalman
filter prediction step.
"""
self.mean, self.covariance = kalman_filter.predict(self.mean,
self.covariance)
self.age += 1
self.time_since_update += 1
def update(self, kalman_filter, detection):
"""
Perform Kalman filter measurement update step and update the associated
detection feature cache.
"""
self.mean, self.covariance = kalman_filter.update(self.mean,
self.covariance,
detection.to_xyah())
self.features.append(detection.feature)
self.cls_id = detection.cls_id
self.score = detection.score
self.hits += 1
self.time_since_update = 0
if self.state == TrackState.Tentative and self.hits >= self._n_init:
self.state = TrackState.Confirmed
def mark_missed(self):
"""Mark this track as missed (no association at the current time step).
"""
if self.state == TrackState.Tentative:
self.state = TrackState.Deleted
elif self.time_since_update > self._max_age:
self.state = TrackState.Deleted
def is_tentative(self):
"""Returns True if this track is tentative (unconfirmed)."""
return self.state == TrackState.Tentative
def is_confirmed(self):
"""Returns True if this track is confirmed."""
return self.state == TrackState.Confirmed
def is_deleted(self):
"""Returns True if this track is dead and should be deleted."""
return self.state == TrackState.Deleted
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is based on https://github.com/nwojke/deep_sort/blob/master/deep_sort/tracker.py
"""
import numpy as np
from ..motion import KalmanFilter
from ..matching.deepsort_matching import NearestNeighborDistanceMetric
from ..matching.deepsort_matching import iou_cost, min_cost_matching, matching_cascade, gate_cost_matrix
from .base_sde_tracker import Track
from ..utils import Detection
__all__ = ['DeepSORTTracker']
class DeepSORTTracker(object):
"""
DeepSORT tracker
Args:
input_size (list): input feature map size to reid model, [h, w] format,
[64, 192] as default.
min_box_area (int): min box area to filter out low quality boxes
vertical_ratio (float): w/h, the vertical ratio of the bbox to filter
bad results, set 1.6 default for pedestrian tracking. If set <=0
means no need to filter bboxes.
budget (int): If not None, fix samples per class to at most this number.
Removes the oldest samples when the budget is reached.
max_age (int): maximum number of missed misses before a track is deleted
n_init (float): Number of frames that a track remains in initialization
phase. Number of consecutive detections before the track is confirmed.
The track state is set to `Deleted` if a miss occurs within the first
`n_init` frames.
metric_type (str): either "euclidean" or "cosine", the distance metric
used for measurement to track association.
matching_threshold (float): samples with larger distance are
considered an invalid match.
max_iou_distance (float): max iou distance threshold
motion (object): KalmanFilter instance
"""
def __init__(self,
input_size=[64, 192],
min_box_area=0,
vertical_ratio=-1,
budget=100,
max_age=70,
n_init=3,
metric_type='cosine',
matching_threshold=0.2,
max_iou_distance=0.9,
motion='KalmanFilter'):
self.input_size = input_size
self.min_box_area = min_box_area
self.vertical_ratio = vertical_ratio
self.max_age = max_age
self.n_init = n_init
self.metric = NearestNeighborDistanceMetric(metric_type,
matching_threshold, budget)
self.max_iou_distance = max_iou_distance
if motion == 'KalmanFilter':
self.motion = KalmanFilter()
self.tracks = []
self._next_id = 1
def predict(self):
"""
Propagate track state distributions one time step forward.
This function should be called once every time step, before `update`.
"""
for track in self.tracks:
track.predict(self.motion)
def update(self, pred_dets, pred_embs):
"""
Perform measurement update and track management.
Args:
pred_dets (np.array): Detection results of the image, the shape is
[N, 6], means 'x0, y0, x1, y1, score, cls_id'.
pred_embs (np.array): Embedding results of the image, the shape is
[N, 128], usually pred_embs.shape[1] is a multiple of 128.
"""
pred_tlwhs = pred_dets[:, :4]
pred_scores = pred_dets[:, 4:5]
pred_cls_ids = pred_dets[:, 5:]
detections = [
Detection(tlwh, score, feat, cls_id)
for tlwh, score, feat, cls_id in zip(pred_tlwhs, pred_scores,
pred_embs, pred_cls_ids)
]
# Run matching cascade.
matches, unmatched_tracks, unmatched_detections = \
self._match(detections)
# Update track set.
for track_idx, detection_idx in matches:
self.tracks[track_idx].update(self.motion,
detections[detection_idx])
for track_idx in unmatched_tracks:
self.tracks[track_idx].mark_missed()
for detection_idx in unmatched_detections:
self._initiate_track(detections[detection_idx])
self.tracks = [t for t in self.tracks if not t.is_deleted()]
# Update distance metric.
active_targets = [t.track_id for t in self.tracks if t.is_confirmed()]
features, targets = [], []
for track in self.tracks:
if not track.is_confirmed():
continue
features += track.features
targets += [track.track_id for _ in track.features]
track.features = []
self.metric.partial_fit(
np.asarray(features), np.asarray(targets), active_targets)
output_stracks = self.tracks
return output_stracks
def _match(self, detections):
def gated_metric(tracks, dets, track_indices, detection_indices):
features = np.array([dets[i].feature for i in detection_indices])
targets = np.array([tracks[i].track_id for i in track_indices])
cost_matrix = self.metric.distance(features, targets)
cost_matrix = gate_cost_matrix(self.motion, cost_matrix, tracks,
dets, track_indices,
detection_indices)
return cost_matrix
# Split track set into confirmed and unconfirmed tracks.
confirmed_tracks = [
i for i, t in enumerate(self.tracks) if t.is_confirmed()
]
unconfirmed_tracks = [
i for i, t in enumerate(self.tracks) if not t.is_confirmed()
]
# Associate confirmed tracks using appearance features.
matches_a, unmatched_tracks_a, unmatched_detections = \
matching_cascade(
gated_metric, self.metric.matching_threshold, self.max_age,
self.tracks, detections, confirmed_tracks)
# Associate remaining tracks together with unconfirmed tracks using IOU.
iou_track_candidates = unconfirmed_tracks + [
k for k in unmatched_tracks_a
if self.tracks[k].time_since_update == 1
]
unmatched_tracks_a = [
k for k in unmatched_tracks_a
if self.tracks[k].time_since_update != 1
]
matches_b, unmatched_tracks_b, unmatched_detections = \
min_cost_matching(
iou_cost, self.max_iou_distance, self.tracks,
detections, iou_track_candidates, unmatched_detections)
matches = matches_a + matches_b
unmatched_tracks = list(set(unmatched_tracks_a + unmatched_tracks_b))
return matches, unmatched_tracks, unmatched_detections
def _initiate_track(self, detection):
mean, covariance = self.motion.initiate(detection.to_xyah())
self.tracks.append(
Track(mean, covariance, self._next_id, self.n_init, self.max_age,
detection.cls_id, detection.score, detection.feature))
self._next_id += 1
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is based on https://github.com/Zhongdao/Towards-Realtime-MOT/blob/master/tracker/multitracker.py
"""
import numpy as np
from collections import defaultdict
from ..matching import jde_matching as matching
from ..motion import KalmanFilter
from .base_jde_tracker import TrackState, STrack
from .base_jde_tracker import joint_stracks, sub_stracks, remove_duplicate_stracks
__all__ = ['JDETracker']
class JDETracker(object):
__shared__ = ['num_classes']
"""
JDE tracker, support single class and multi classes
Args:
num_classes (int): the number of classes
det_thresh (float): threshold of detection score
track_buffer (int): buffer for tracker
min_box_area (int): min box area to filter out low quality boxes
vertical_ratio (float): w/h, the vertical ratio of the bbox to filter
bad results. If set <0 means no need to filter bboxes,usually set
1.6 for pedestrian tracking.
tracked_thresh (float): linear assignment threshold of tracked
stracks and detections
r_tracked_thresh (float): linear assignment threshold of
tracked stracks and unmatched detections
unconfirmed_thresh (float): linear assignment threshold of
unconfirmed stracks and unmatched detections
motion (str): motion model, KalmanFilter as default
conf_thres (float): confidence threshold for tracking
metric_type (str): either "euclidean" or "cosine", the distance metric
used for measurement to track association.
"""
def __init__(self,
num_classes=1,
det_thresh=0.3,
track_buffer=30,
min_box_area=200,
vertical_ratio=1.6,
tracked_thresh=0.7,
r_tracked_thresh=0.5,
unconfirmed_thresh=0.7,
motion='KalmanFilter',
conf_thres=0,
metric_type='euclidean'):
self.num_classes = num_classes
self.det_thresh = det_thresh
self.track_buffer = track_buffer
self.min_box_area = min_box_area
self.vertical_ratio = vertical_ratio
self.tracked_thresh = tracked_thresh
self.r_tracked_thresh = r_tracked_thresh
self.unconfirmed_thresh = unconfirmed_thresh
if motion == 'KalmanFilter':
self.motion = KalmanFilter()
self.conf_thres = conf_thres
self.metric_type = metric_type
self.frame_id = 0
self.tracked_tracks_dict = defaultdict(list) # dict(list[STrack])
self.lost_tracks_dict = defaultdict(list) # dict(list[STrack])
self.removed_tracks_dict = defaultdict(list) # dict(list[STrack])
self.max_time_lost = 0
# max_time_lost will be calculated: int(frame_rate / 30.0 * track_buffer)
def update(self, pred_dets, pred_embs):
"""
Processes the image frame and finds bounding box(detections).
Associates the detection with corresponding tracklets and also handles
lost, removed, refound and active tracklets.
Args:
pred_dets (np.array): Detection results of the image, the shape is
[N, 6], means 'x0, y0, x1, y1, score, cls_id'.
pred_embs (np.array): Embedding results of the image, the shape is
[N, 128] or [N, 512].
Return:
output_stracks_dict (dict(list)): The list contains information
regarding the online_tracklets for the recieved image tensor.
"""
self.frame_id += 1
if self.frame_id == 1:
STrack.init_count(self.num_classes)
activated_tracks_dict = defaultdict(list)
refined_tracks_dict = defaultdict(list)
lost_tracks_dict = defaultdict(list)
removed_tracks_dict = defaultdict(list)
output_tracks_dict = defaultdict(list)
pred_dets_dict = defaultdict(list)
pred_embs_dict = defaultdict(list)
# unify single and multi classes detection and embedding results
for cls_id in range(self.num_classes):
cls_idx = (pred_dets[:, 5:] == cls_id).squeeze(-1)
pred_dets_dict[cls_id] = pred_dets[cls_idx]
pred_embs_dict[cls_id] = pred_embs[cls_idx]
for cls_id in range(self.num_classes):
""" Step 1: Get detections by class"""
pred_dets_cls = pred_dets_dict[cls_id]
pred_embs_cls = pred_embs_dict[cls_id]
remain_inds = (pred_dets_cls[:, 4:5] > self.conf_thres).squeeze(-1)
if remain_inds.sum() > 0:
pred_dets_cls = pred_dets_cls[remain_inds]
pred_embs_cls = pred_embs_cls[remain_inds]
detections = [
STrack(
STrack.tlbr_to_tlwh(tlbrs[:4]), tlbrs[4], f,
self.num_classes, cls_id, 30)
for (tlbrs, f) in zip(pred_dets_cls, pred_embs_cls)
]
else:
detections = []
''' Add newly detected tracklets to tracked_stracks'''
unconfirmed_dict = defaultdict(list)
tracked_tracks_dict = defaultdict(list)
for track in self.tracked_tracks_dict[cls_id]:
if not track.is_activated:
# previous tracks which are not active in the current frame are added in unconfirmed list
unconfirmed_dict[cls_id].append(track)
else:
# Active tracks are added to the local list 'tracked_stracks'
tracked_tracks_dict[cls_id].append(track)
""" Step 2: First association, with embedding"""
# building tracking pool for the current frame
track_pool_dict = defaultdict(list)
track_pool_dict[cls_id] = joint_stracks(
tracked_tracks_dict[cls_id], self.lost_tracks_dict[cls_id])
# Predict the current location with KalmanFilter
STrack.multi_predict(track_pool_dict[cls_id], self.motion)
dists = matching.embedding_distance(
track_pool_dict[cls_id], detections, metric=self.metric_type)
dists = matching.fuse_motion(self.motion, dists,
track_pool_dict[cls_id], detections)
matches, u_track, u_detection = matching.linear_assignment(
dists, thresh=self.tracked_thresh)
for i_tracked, idet in matches:
# i_tracked is the id of the track and idet is the detection
track = track_pool_dict[cls_id][i_tracked]
det = detections[idet]
if track.state == TrackState.Tracked:
# If the track is active, add the detection to the track
track.update(detections[idet], self.frame_id)
activated_tracks_dict[cls_id].append(track)
else:
# We have obtained a detection from a track which is not active,
# hence put the track in refind_stracks list
track.re_activate(det, self.frame_id, new_id=False)
refined_tracks_dict[cls_id].append(track)
# None of the steps below happen if there are no undetected tracks.
""" Step 3: Second association, with IOU"""
detections = [detections[i] for i in u_detection]
r_tracked_stracks = []
for i in u_track:
if track_pool_dict[cls_id][i].state == TrackState.Tracked:
r_tracked_stracks.append(track_pool_dict[cls_id][i])
dists = matching.iou_distance(r_tracked_stracks, detections)
matches, u_track, u_detection = matching.linear_assignment(
dists, thresh=self.r_tracked_thresh)
for i_tracked, idet in matches:
track = r_tracked_stracks[i_tracked]
det = detections[idet]
if track.state == TrackState.Tracked:
track.update(det, self.frame_id)
activated_tracks_dict[cls_id].append(track)
else:
track.re_activate(det, self.frame_id, new_id=False)
refined_tracks_dict[cls_id].append(track)
for it in u_track:
track = r_tracked_stracks[it]
if not track.state == TrackState.Lost:
track.mark_lost()
lost_tracks_dict[cls_id].append(track)
'''Deal with unconfirmed tracks, usually tracks with only one beginning frame'''
detections = [detections[i] for i in u_detection]
dists = matching.iou_distance(unconfirmed_dict[cls_id], detections)
matches, u_unconfirmed, u_detection = matching.linear_assignment(
dists, thresh=self.unconfirmed_thresh)
for i_tracked, idet in matches:
unconfirmed_dict[cls_id][i_tracked].update(detections[idet],
self.frame_id)
activated_tracks_dict[cls_id].append(unconfirmed_dict[cls_id][
i_tracked])
for it in u_unconfirmed:
track = unconfirmed_dict[cls_id][it]
track.mark_removed()
removed_tracks_dict[cls_id].append(track)
""" Step 4: Init new stracks"""
for inew in u_detection:
track = detections[inew]
if track.score < self.det_thresh:
continue
track.activate(self.motion, self.frame_id)
activated_tracks_dict[cls_id].append(track)
""" Step 5: Update state"""
for track in self.lost_tracks_dict[cls_id]:
if self.frame_id - track.end_frame > self.max_time_lost:
track.mark_removed()
removed_tracks_dict[cls_id].append(track)
self.tracked_tracks_dict[cls_id] = [
t for t in self.tracked_tracks_dict[cls_id]
if t.state == TrackState.Tracked
]
self.tracked_tracks_dict[cls_id] = joint_stracks(
self.tracked_tracks_dict[cls_id], activated_tracks_dict[cls_id])
self.tracked_tracks_dict[cls_id] = joint_stracks(
self.tracked_tracks_dict[cls_id], refined_tracks_dict[cls_id])
self.lost_tracks_dict[cls_id] = sub_stracks(
self.lost_tracks_dict[cls_id], self.tracked_tracks_dict[cls_id])
self.lost_tracks_dict[cls_id].extend(lost_tracks_dict[cls_id])
self.lost_tracks_dict[cls_id] = sub_stracks(
self.lost_tracks_dict[cls_id], self.removed_tracks_dict[cls_id])
self.removed_tracks_dict[cls_id].extend(removed_tracks_dict[cls_id])
self.tracked_tracks_dict[cls_id], self.lost_tracks_dict[
cls_id] = remove_duplicate_stracks(
self.tracked_tracks_dict[cls_id],
self.lost_tracks_dict[cls_id])
# get scores of lost tracks
output_tracks_dict[cls_id] = [
track for track in self.tracked_tracks_dict[cls_id]
if track.is_activated
]
return output_tracks_dict
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import cv2
import time
import paddle
import numpy as np
__all__ = [
'MOTTimer',
'Detection',
'write_mot_results',
'load_det_results',
'preprocess_reid',
'get_crops',
'clip_box',
'scale_coords',
]
class MOTTimer(object):
"""
This class used to compute and print the current FPS while evaling.
"""
def __init__(self):
self.total_time = 0.
self.calls = 0
self.start_time = 0.
self.diff = 0.
self.average_time = 0.
self.duration = 0.
def tic(self):
# using time.time instead of time.clock because time time.clock
# does not normalize for multithreading
self.start_time = time.time()
def toc(self, average=True):
self.diff = time.time() - self.start_time
self.total_time += self.diff
self.calls += 1
self.average_time = self.total_time / self.calls
if average:
self.duration = self.average_time
else:
self.duration = self.diff
return self.duration
def clear(self):
self.total_time = 0.
self.calls = 0
self.start_time = 0.
self.diff = 0.
self.average_time = 0.
self.duration = 0.
class Detection(object):
"""
This class represents a bounding box detection in a single image.
Args:
tlwh (Tensor): Bounding box in format `(top left x, top left y,
width, height)`.
score (Tensor): Bounding box confidence score.
feature (Tensor): A feature vector that describes the object
contained in this image.
cls_id (Tensor): Bounding box category id.
"""
def __init__(self, tlwh, score, feature, cls_id):
self.tlwh = np.asarray(tlwh, dtype=np.float32)
self.score = float(score)
self.feature = np.asarray(feature, dtype=np.float32)
self.cls_id = int(cls_id)
def to_tlbr(self):
"""
Convert bounding box to format `(min x, min y, max x, max y)`, i.e.,
`(top left, bottom right)`.
"""
ret = self.tlwh.copy()
ret[2:] += ret[:2]
return ret
def to_xyah(self):
"""
Convert bounding box to format `(center x, center y, aspect ratio,
height)`, where the aspect ratio is `width / height`.
"""
ret = self.tlwh.copy()
ret[:2] += ret[2:] / 2
ret[2] /= ret[3]
return ret
def write_mot_results(filename, results, data_type='mot', num_classes=1):
# support single and multi classes
if data_type in ['mot', 'mcmot']:
save_format = '{frame},{id},{x1},{y1},{w},{h},{score},{cls_id},-1,-1\n'
elif data_type == 'kitti':
save_format = '{frame} {id} car 0 0 -10 {x1} {y1} {x2} {y2} -10 -10 -10 -1000 -1000 -1000 -10\n'
else:
raise ValueError(data_type)
f = open(filename, 'w')
for cls_id in range(num_classes):
for frame_id, tlwhs, tscores, track_ids in results[cls_id]:
for tlwh, score, track_id in zip(tlwhs, tscores, track_ids):
if track_id < 0: continue
if data_type == 'kitti':
frame_id -= 1
elif data_type == 'mot':
cls_id = -1
elif data_type == 'mcmot':
cls_id = cls_id
x1, y1, w, h = tlwh
line = save_format.format(
frame=frame_id,
id=track_id,
x1=x1,
y1=y1,
w=w,
h=h,
score=score,
cls_id=cls_id)
f.write(line)
print('MOT results save in {}'.format(filename))
def load_det_results(det_file, num_frames):
assert os.path.exists(det_file) and os.path.isfile(det_file), \
'{} is not exist or not a file.'.format(det_file)
labels = np.loadtxt(det_file, dtype='float32', delimiter=',')
assert labels.shape[1] == 7, \
"Each line of {} should have 7 items: '[frame_id],[x0],[y0],[w],[h],[score],[class_id]'.".format(det_file)
results_list = []
for frame_i in range(num_frames):
results = {'bbox': [], 'score': [], 'cls_id': []}
lables_with_frame = labels[labels[:, 0] == frame_i + 1]
# each line of lables_with_frame:
# [frame_id],[x0],[y0],[w],[h],[score],[class_id]
for l in lables_with_frame:
results['bbox'].append(l[1:5])
results['score'].append(l[5])
results['cls_id'].append(l[6])
results_list.append(results)
return results_list
def scale_coords(coords, input_shape, im_shape, scale_factor):
im_shape = im_shape.numpy()[0]
ratio = scale_factor[0][0]
pad_w = (input_shape[1] - int(im_shape[1])) / 2
pad_h = (input_shape[0] - int(im_shape[0])) / 2
coords = paddle.cast(coords, 'float32')
coords[:, 0::2] -= pad_w
coords[:, 1::2] -= pad_h
coords[:, 0:4] /= ratio
coords[:, :4] = paddle.clip(coords[:, :4], min=0, max=coords[:, :4].max())
return coords.round()
def clip_box(xyxy, input_shape, im_shape, scale_factor):
im_shape = im_shape.numpy()[0]
ratio = scale_factor.numpy()[0][0]
img0_shape = [int(im_shape[0] / ratio), int(im_shape[1] / ratio)]
xyxy[:, 0::2] = paddle.clip(xyxy[:, 0::2], min=0, max=img0_shape[1])
xyxy[:, 1::2] = paddle.clip(xyxy[:, 1::2], min=0, max=img0_shape[0])
w = xyxy[:, 2:3] - xyxy[:, 0:1]
h = xyxy[:, 3:4] - xyxy[:, 1:2]
mask = paddle.logical_and(h > 0, w > 0)
keep_idx = paddle.nonzero(mask)
xyxy = paddle.gather_nd(xyxy, keep_idx[:, :1])
return xyxy, keep_idx
def get_crops(xyxy, ori_img, w, h):
crops = []
xyxy = xyxy.numpy().astype(np.int64)
ori_img = ori_img.numpy()
ori_img = np.squeeze(ori_img, axis=0).transpose(1, 0, 2)
for i, bbox in enumerate(xyxy):
crop = ori_img[bbox[0]:bbox[2], bbox[1]:bbox[3], :]
crops.append(crop)
crops = preprocess_reid(crops, w, h)
return crops
def preprocess_reid(imgs,
w=64,
h=192,
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]):
im_batch = []
for img in imgs:
img = cv2.resize(img, (w, h))
img = img[:, :, ::-1].astype('float32').transpose((2, 0, 1)) / 255
img_mean = np.array(mean).reshape((3, 1, 1))
img_std = np.array(std).reshape((3, 1, 1))
img -= img_mean
img /= img_std
img = np.expand_dims(img, axis=0)
im_batch.append(img)
im_batch = np.concatenate(im_batch, 0)
return im_batch
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import time
import yaml
import cv2
import numpy as np
from collections import defaultdict
import paddle
from paddle.inference import Config
from paddle.inference import create_predictor
from utils import argsparser, Timer, get_current_memory_mb
from det_infer import Detector, get_test_images, print_arguments, PredictConfig
from benchmark_utils import PaddleInferBenchmark
from visualize import plot_tracking_dict
from mot.tracker import JDETracker
from mot.utils import MOTTimer, write_mot_results
# Global dictionary
MOT_SUPPORT_MODELS = {
'JDE',
'FairMOT',
}
class JDE_Detector(Detector):
"""
Args:
pred_config (object): config of model, defined by `Config(model_dir)`
model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml
device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU
run_mode (str): mode of running(fluid/trt_fp32/trt_fp16)
batch_size (int): size of pre batch in inference
trt_min_shape (int): min shape for dynamic shape in trt
trt_max_shape (int): max shape for dynamic shape in trt
trt_opt_shape (int): opt shape for dynamic shape in trt
trt_calib_mode (bool): If the model is produced by TRT offline quantitative
calibration, trt_calib_mode need to set True
cpu_threads (int): cpu threads
enable_mkldnn (bool): whether to open MKLDNN
"""
def __init__(self,
pred_config,
model_dir,
device='CPU',
run_mode='fluid',
batch_size=1,
trt_min_shape=1,
trt_max_shape=1088,
trt_opt_shape=608,
trt_calib_mode=False,
cpu_threads=1,
enable_mkldnn=False):
super(JDE_Detector, self).__init__(
pred_config=pred_config,
model_dir=model_dir,
device=device,
run_mode=run_mode,
batch_size=batch_size,
trt_min_shape=trt_min_shape,
trt_max_shape=trt_max_shape,
trt_opt_shape=trt_opt_shape,
trt_calib_mode=trt_calib_mode,
cpu_threads=cpu_threads,
enable_mkldnn=enable_mkldnn)
assert batch_size == 1, "The JDE Detector only supports batch size=1 now"
assert pred_config.tracker, "Tracking model should have tracker"
self.num_classes = len(pred_config.labels)
tp = pred_config.tracker
min_box_area = tp['min_box_area'] if 'min_box_area' in tp else 200
vertical_ratio = tp['vertical_ratio'] if 'vertical_ratio' in tp else 1.6
conf_thres = tp['conf_thres'] if 'conf_thres' in tp else 0.
tracked_thresh = tp['tracked_thresh'] if 'tracked_thresh' in tp else 0.7
metric_type = tp['metric_type'] if 'metric_type' in tp else 'euclidean'
self.tracker = JDETracker(
num_classes=self.num_classes,
min_box_area=min_box_area,
vertical_ratio=vertical_ratio,
conf_thres=conf_thres,
tracked_thresh=tracked_thresh,
metric_type=metric_type)
def postprocess(self, pred_dets, pred_embs, threshold):
online_targets_dict = self.tracker.update(pred_dets, pred_embs)
online_tlwhs = defaultdict(list)
online_scores = defaultdict(list)
online_ids = defaultdict(list)
for cls_id in range(self.num_classes):
online_targets = online_targets_dict[cls_id]
for t in online_targets:
tlwh = t.tlwh
tid = t.track_id
tscore = t.score
if tscore < threshold: continue
if tlwh[2] * tlwh[3] <= self.tracker.min_box_area: continue
if self.tracker.vertical_ratio > 0 and tlwh[2] / tlwh[
3] > self.tracker.vertical_ratio:
continue
online_tlwhs[cls_id].append(tlwh)
online_ids[cls_id].append(tid)
online_scores[cls_id].append(tscore)
return online_tlwhs, online_scores, online_ids
def predict(self, image_list, threshold=0.5, warmup=0, repeats=1):
'''
Args:
image_list (list): list of image
threshold (float): threshold of predicted box' score
Returns:
online_tlwhs, online_scores, online_ids (dict[np.array])
'''
self.det_times.preprocess_time_s.start()
inputs = self.preprocess(image_list)
self.det_times.preprocess_time_s.end()
pred_dets, pred_embs = None, None
input_names = self.predictor.get_input_names()
for i in range(len(input_names)):
input_tensor = self.predictor.get_input_handle(input_names[i])
input_tensor.copy_from_cpu(inputs[input_names[i]])
for i in range(warmup):
self.predictor.run()
output_names = self.predictor.get_output_names()
boxes_tensor = self.predictor.get_output_handle(output_names[0])
pred_dets = boxes_tensor.copy_to_cpu()
self.det_times.inference_time_s.start()
for i in range(repeats):
self.predictor.run()
output_names = self.predictor.get_output_names()
boxes_tensor = self.predictor.get_output_handle(output_names[0])
pred_dets = boxes_tensor.copy_to_cpu()
embs_tensor = self.predictor.get_output_handle(output_names[1])
pred_embs = embs_tensor.copy_to_cpu()
self.det_times.inference_time_s.end(repeats=repeats)
self.det_times.postprocess_time_s.start()
online_tlwhs, online_scores, online_ids = self.postprocess(
pred_dets, pred_embs, threshold)
self.det_times.postprocess_time_s.end()
self.det_times.img_num += 1
return online_tlwhs, online_scores, online_ids
def predict_image(detector, image_list):
results = []
num_classes = detector.num_classes
data_type = 'mcmot' if num_classes > 1 else 'mot'
ids2names = detector.pred_config.labels
image_list.sort()
for frame_id, img_file in enumerate(image_list):
frame = cv2.imread(img_file)
if FLAGS.run_benchmark:
detector.predict([frame], FLAGS.threshold, warmup=10, repeats=10)
cm, gm, gu = get_current_memory_mb()
detector.cpu_mem += cm
detector.gpu_mem += gm
detector.gpu_util += gu
print('Test iter {}, file name:{}'.format(frame_id, img_file))
else:
online_tlwhs, online_scores, online_ids = detector.predict(
[frame], FLAGS.threshold)
online_im = plot_tracking_dict(frame, num_classes, online_tlwhs,
online_ids, online_scores, frame_id,
ids2names)
if FLAGS.save_images:
if not os.path.exists(FLAGS.output_dir):
os.makedirs(FLAGS.output_dir)
img_name = os.path.split(img_file)[-1]
out_path = os.path.join(FLAGS.output_dir, img_name)
cv2.imwrite(out_path, online_im)
print("save result to: " + out_path)
def predict_video(detector, camera_id):
video_name = 'mot_output.mp4'
if camera_id != -1:
capture = cv2.VideoCapture(camera_id)
else:
capture = cv2.VideoCapture(FLAGS.video_file)
video_name = os.path.split(FLAGS.video_file)[-1]
# Get Video info : resolution, fps, frame count
width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
fps = int(capture.get(cv2.CAP_PROP_FPS))
frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
print("fps: %d, frame_count: %d" % (fps, frame_count))
if not os.path.exists(FLAGS.output_dir):
os.makedirs(FLAGS.output_dir)
out_path = os.path.join(FLAGS.output_dir, video_name)
if not FLAGS.save_images:
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
frame_id = 0
timer = MOTTimer()
results = defaultdict(list) # support single class and multi classes
num_classes = detector.num_classes
data_type = 'mcmot' if num_classes > 1 else 'mot'
ids2names = detector.pred_config.labels
while (1):
ret, frame = capture.read()
if not ret:
break
timer.tic()
online_tlwhs, online_scores, online_ids = detector.predict(
[frame], FLAGS.threshold)
timer.toc()
for cls_id in range(num_classes):
results[cls_id].append((frame_id + 1, online_tlwhs[cls_id],
online_scores[cls_id], online_ids[cls_id]))
fps = 1. / timer.average_time
im = plot_tracking_dict(
frame,
num_classes,
online_tlwhs,
online_ids,
online_scores,
frame_id=frame_id,
fps=fps,
ids2names=ids2names)
if FLAGS.save_images:
save_dir = os.path.join(FLAGS.output_dir, video_name.split('.')[-2])
if not os.path.exists(save_dir):
os.makedirs(save_dir)
cv2.imwrite(
os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)), im)
else:
writer.write(im)
frame_id += 1
print('detect frame: %d' % (frame_id))
if camera_id != -1:
cv2.imshow('Tracking Detection', im)
if cv2.waitKey(1) & 0xFF == ord('q'):
break
if FLAGS.save_mot_txts:
result_filename = os.path.join(FLAGS.output_dir,
video_name.split('.')[-2] + '.txt')
write_mot_results(result_filename, results, data_type, num_classes)
if FLAGS.save_images:
save_dir = os.path.join(FLAGS.output_dir, video_name.split('.')[-2])
cmd_str = 'ffmpeg -f image2 -i {}/%05d.jpg {}'.format(save_dir,
out_path)
os.system(cmd_str)
print('Save video in {}.'.format(out_path))
else:
writer.release()
def main():
pred_config = PredictConfig(FLAGS.model_dir)
detector = JDE_Detector(
pred_config,
FLAGS.model_dir,
device=FLAGS.device,
run_mode=FLAGS.run_mode,
trt_min_shape=FLAGS.trt_min_shape,
trt_max_shape=FLAGS.trt_max_shape,
trt_opt_shape=FLAGS.trt_opt_shape,
trt_calib_mode=FLAGS.trt_calib_mode,
cpu_threads=FLAGS.cpu_threads,
enable_mkldnn=FLAGS.enable_mkldnn)
# predict from video file or camera video stream
if FLAGS.video_file is not None or FLAGS.camera_id != -1:
predict_video(detector, FLAGS.camera_id)
else:
# predict from image
img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file)
predict_image(detector, img_list)
if not FLAGS.run_benchmark:
detector.det_times.info(average=True)
else:
mems = {
'cpu_rss_mb': detector.cpu_mem / len(img_list),
'gpu_rss_mb': detector.gpu_mem / len(img_list),
'gpu_util': detector.gpu_util * 100 / len(img_list)
}
perf_info = detector.det_times.report(average=True)
model_dir = FLAGS.model_dir
mode = FLAGS.run_mode
model_info = {
'model_name': model_dir.strip('/').split('/')[-1],
'precision': mode.split('_')[-1]
}
data_info = {
'batch_size': 1,
'shape': "dynamic_shape",
'data_num': perf_info['img_num']
}
det_log = PaddleInferBenchmark(detector.config, model_info,
data_info, perf_info, mems)
det_log('MOT')
if __name__ == '__main__':
paddle.enable_static()
parser = argsparser()
FLAGS = parser.parse_args()
print_arguments(FLAGS)
FLAGS.device = FLAGS.device.upper()
assert FLAGS.device in ['CPU', 'GPU', 'XPU'
], "device should be CPU, GPU or XPU"
main()
此差异已折叠。
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from scipy.special import softmax
def hard_nms(box_scores, iou_threshold, top_k=-1, candidate_size=200):
"""
Args:
box_scores (N, 5): boxes in corner-form and probabilities.
iou_threshold: intersection over union threshold.
top_k: keep top_k results. If k <= 0, keep all the results.
candidate_size: only consider the candidates with the highest scores.
Returns:
picked: a list of indexes of the kept boxes
"""
scores = box_scores[:, -1]
boxes = box_scores[:, :-1]
picked = []
indexes = np.argsort(scores)
indexes = indexes[-candidate_size:]
while len(indexes) > 0:
current = indexes[-1]
picked.append(current)
if 0 < top_k == len(picked) or len(indexes) == 1:
break
current_box = boxes[current, :]
indexes = indexes[:-1]
rest_boxes = boxes[indexes, :]
iou = iou_of(
rest_boxes,
np.expand_dims(
current_box, axis=0), )
indexes = indexes[iou <= iou_threshold]
return box_scores[picked, :]
def iou_of(boxes0, boxes1, eps=1e-5):
"""Return intersection-over-union (Jaccard index) of boxes.
Args:
boxes0 (N, 4): ground truth boxes.
boxes1 (N or 1, 4): predicted boxes.
eps: a small number to avoid 0 as denominator.
Returns:
iou (N): IoU values.
"""
overlap_left_top = np.maximum(boxes0[..., :2], boxes1[..., :2])
overlap_right_bottom = np.minimum(boxes0[..., 2:], boxes1[..., 2:])
overlap_area = area_of(overlap_left_top, overlap_right_bottom)
area0 = area_of(boxes0[..., :2], boxes0[..., 2:])
area1 = area_of(boxes1[..., :2], boxes1[..., 2:])
return overlap_area / (area0 + area1 - overlap_area + eps)
def area_of(left_top, right_bottom):
"""Compute the areas of rectangles given two corners.
Args:
left_top (N, 2): left top corner.
right_bottom (N, 2): right bottom corner.
Returns:
area (N): return the area.
"""
hw = np.clip(right_bottom - left_top, 0.0, None)
return hw[..., 0] * hw[..., 1]
class PicoDetPostProcess(object):
"""
Args:
input_shape (int): network input image size
ori_shape (int): ori image shape of before padding
scale_factor (float): scale factor of ori image
enable_mkldnn (bool): whether to open MKLDNN
"""
def __init__(self,
input_shape,
ori_shape,
scale_factor,
strides=[8, 16, 32, 64],
score_threshold=0.4,
nms_threshold=0.5,
nms_top_k=1000,
keep_top_k=100):
self.ori_shape = ori_shape
self.input_shape = input_shape
self.scale_factor = scale_factor
self.strides = strides
self.score_threshold = score_threshold
self.nms_threshold = nms_threshold
self.nms_top_k = nms_top_k
self.keep_top_k = keep_top_k
def warp_boxes(self, boxes, ori_shape):
"""Apply transform to boxes
"""
width, height = ori_shape[1], ori_shape[0]
n = len(boxes)
if n:
# warp points
xy = np.ones((n * 4, 3))
xy[:, :2] = boxes[:, [0, 1, 2, 3, 0, 3, 2, 1]].reshape(
n * 4, 2) # x1y1, x2y2, x1y2, x2y1
# xy = xy @ M.T # transform
xy = (xy[:, :2] / xy[:, 2:3]).reshape(n, 8) # rescale
# create new boxes
x = xy[:, [0, 2, 4, 6]]
y = xy[:, [1, 3, 5, 7]]
xy = np.concatenate(
(x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T
# clip boxes
xy[:, [0, 2]] = xy[:, [0, 2]].clip(0, width)
xy[:, [1, 3]] = xy[:, [1, 3]].clip(0, height)
return xy.astype(np.float32)
else:
return boxes
def __call__(self, scores, raw_boxes):
batch_size = raw_boxes[0].shape[0]
reg_max = int(raw_boxes[0].shape[-1] / 4 - 1)
out_boxes_num = []
out_boxes_list = []
for batch_id in range(batch_size):
# generate centers
decode_boxes = []
select_scores = []
for stride, box_distribute, score in zip(self.strides, raw_boxes,
scores):
box_distribute = box_distribute[batch_id]
score = score[batch_id]
# centers
fm_h = self.input_shape[0] / stride
fm_w = self.input_shape[1] / stride
h_range = np.arange(fm_h)
w_range = np.arange(fm_w)
ww, hh = np.meshgrid(w_range, h_range)
ct_row = (hh.flatten() + 0.5) * stride
ct_col = (ww.flatten() + 0.5) * stride
center = np.stack((ct_col, ct_row, ct_col, ct_row), axis=1)
# box distribution to distance
reg_range = np.arange(reg_max + 1)
box_distance = box_distribute.reshape((-1, reg_max + 1))
box_distance = softmax(box_distance, axis=1)
box_distance = box_distance * np.expand_dims(reg_range, axis=0)
box_distance = np.sum(box_distance, axis=1).reshape((-1, 4))
box_distance = box_distance * stride
# top K candidate
topk_idx = np.argsort(score.max(axis=1))[::-1]
topk_idx = topk_idx[:self.nms_top_k]
center = center[topk_idx]
score = score[topk_idx]
box_distance = box_distance[topk_idx]
# decode box
decode_box = center + [-1, -1, 1, 1] * box_distance
select_scores.append(score)
decode_boxes.append(decode_box)
# nms
bboxes = np.concatenate(decode_boxes, axis=0)
confidences = np.concatenate(select_scores, axis=0)
picked_box_probs = []
picked_labels = []
for class_index in range(0, confidences.shape[1]):
probs = confidences[:, class_index]
mask = probs > self.score_threshold
probs = probs[mask]
if probs.shape[0] == 0:
continue
subset_boxes = bboxes[mask, :]
box_probs = np.concatenate(
[subset_boxes, probs.reshape(-1, 1)], axis=1)
box_probs = hard_nms(
box_probs,
iou_threshold=self.nms_threshold,
top_k=self.keep_top_k, )
picked_box_probs.append(box_probs)
picked_labels.extend([class_index] * box_probs.shape[0])
if len(picked_box_probs) == 0:
out_boxes_list.append(np.empty((0, 4)))
out_boxes_num.append(0)
else:
picked_box_probs = np.concatenate(picked_box_probs)
# resize output boxes
picked_box_probs[:, :4] = self.warp_boxes(
picked_box_probs[:, :4], self.ori_shape[batch_id])
im_scale = np.concatenate([
self.scale_factor[batch_id][::-1],
self.scale_factor[batch_id][::-1]
])
picked_box_probs[:, :4] /= im_scale
# clas score box
out_boxes_list.append(
np.concatenate(
[
np.expand_dims(
np.array(picked_labels),
axis=-1), np.expand_dims(
picked_box_probs[:, 4], axis=-1),
picked_box_probs[:, :4]
],
axis=1))
out_boxes_num.append(len(picked_labels))
out_boxes_list = np.concatenate(out_boxes_list, axis=0)
out_boxes_num = np.asarray(out_boxes_num).astype(np.int32)
return out_boxes_list, out_boxes_num
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import cv2
import numpy as np
def decode_image(im_file, im_info):
"""read rgb image
Args:
im_file (str|np.ndarray): input can be image path or np.ndarray
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
if isinstance(im_file, str):
with open(im_file, 'rb') as f:
im_read = f.read()
data = np.frombuffer(im_read, dtype='uint8')
im = cv2.imdecode(data, 1) # BGR mode, but need RGB mode
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
else:
im = im_file
im_info['im_shape'] = np.array(im.shape[:2], dtype=np.float32)
im_info['scale_factor'] = np.array([1., 1.], dtype=np.float32)
return im, im_info
class Resize(object):
"""resize image by target_size and max_size
Args:
target_size (int): the target size of image
keep_ratio (bool): whether keep_ratio or not, default true
interp (int): method of resize
"""
def __init__(self, target_size, keep_ratio=True, interp=cv2.INTER_LINEAR):
if isinstance(target_size, int):
target_size = [target_size, target_size]
self.target_size = target_size
self.keep_ratio = keep_ratio
self.interp = interp
def __call__(self, im, im_info):
"""
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
assert len(self.target_size) == 2
assert self.target_size[0] > 0 and self.target_size[1] > 0
im_channel = im.shape[2]
im_scale_y, im_scale_x = self.generate_scale(im)
im = cv2.resize(
im,
None,
None,
fx=im_scale_x,
fy=im_scale_y,
interpolation=self.interp)
im_info['im_shape'] = np.array(im.shape[:2]).astype('float32')
im_info['scale_factor'] = np.array(
[im_scale_y, im_scale_x]).astype('float32')
return im, im_info
def generate_scale(self, im):
"""
Args:
im (np.ndarray): image (np.ndarray)
Returns:
im_scale_x: the resize ratio of X
im_scale_y: the resize ratio of Y
"""
origin_shape = im.shape[:2]
im_c = im.shape[2]
if self.keep_ratio:
im_size_min = np.min(origin_shape)
im_size_max = np.max(origin_shape)
target_size_min = np.min(self.target_size)
target_size_max = np.max(self.target_size)
im_scale = float(target_size_min) / float(im_size_min)
if np.round(im_scale * im_size_max) > target_size_max:
im_scale = float(target_size_max) / float(im_size_max)
im_scale_x = im_scale
im_scale_y = im_scale
else:
resize_h, resize_w = self.target_size
im_scale_y = resize_h / float(origin_shape[0])
im_scale_x = resize_w / float(origin_shape[1])
return im_scale_y, im_scale_x
class NormalizeImage(object):
"""normalize image
Args:
mean (list): im - mean
std (list): im / std
is_scale (bool): whether need im / 255
is_channel_first (bool): if True: image shape is CHW, else: HWC
"""
def __init__(self, mean, std, is_scale=True):
self.mean = mean
self.std = std
self.is_scale = is_scale
def __call__(self, im, im_info):
"""
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
im = im.astype(np.float32, copy=False)
mean = np.array(self.mean)[np.newaxis, np.newaxis, :]
std = np.array(self.std)[np.newaxis, np.newaxis, :]
if self.is_scale:
im = im / 255.0
im -= mean
im /= std
return im, im_info
class Permute(object):
"""permute image
Args:
to_bgr (bool): whether convert RGB to BGR
channel_first (bool): whether convert HWC to CHW
"""
def __init__(self, ):
super(Permute, self).__init__()
def __call__(self, im, im_info):
"""
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
im = im.transpose((2, 0, 1)).copy()
return im, im_info
class PadStride(object):
""" padding image for model with FPN, instead PadBatch(pad_to_stride) in original config
Args:
stride (bool): model with FPN need image shape % stride == 0
"""
def __init__(self, stride=0):
self.coarsest_stride = stride
def __call__(self, im, im_info):
"""
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
coarsest_stride = self.coarsest_stride
if coarsest_stride <= 0:
return im, im_info
im_c, im_h, im_w = im.shape
pad_h = int(np.ceil(float(im_h) / coarsest_stride) * coarsest_stride)
pad_w = int(np.ceil(float(im_w) / coarsest_stride) * coarsest_stride)
padding_im = np.zeros((im_c, pad_h, pad_w), dtype=np.float32)
padding_im[:, :im_h, :im_w] = im
return padding_im, im_info
class LetterBoxResize(object):
def __init__(self, target_size):
"""
Resize image to target size, convert normalized xywh to pixel xyxy
format ([x_center, y_center, width, height] -> [x0, y0, x1, y1]).
Args:
target_size (int|list): image target size.
"""
super(LetterBoxResize, self).__init__()
if isinstance(target_size, int):
target_size = [target_size, target_size]
self.target_size = target_size
def letterbox(self, img, height, width, color=(127.5, 127.5, 127.5)):
# letterbox: resize a rectangular image to a padded rectangular
shape = img.shape[:2] # [height, width]
ratio_h = float(height) / shape[0]
ratio_w = float(width) / shape[1]
ratio = min(ratio_h, ratio_w)
new_shape = (round(shape[1] * ratio),
round(shape[0] * ratio)) # [width, height]
padw = (width - new_shape[0]) / 2
padh = (height - new_shape[1]) / 2
top, bottom = round(padh - 0.1), round(padh + 0.1)
left, right = round(padw - 0.1), round(padw + 0.1)
img = cv2.resize(
img, new_shape, interpolation=cv2.INTER_AREA) # resized, no border
img = cv2.copyMakeBorder(
img, top, bottom, left, right, cv2.BORDER_CONSTANT,
value=color) # padded rectangular
return img, ratio, padw, padh
def __call__(self, im, im_info):
"""
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
assert len(self.target_size) == 2
assert self.target_size[0] > 0 and self.target_size[1] > 0
height, width = self.target_size
h, w = im.shape[:2]
im, ratio, padw, padh = self.letterbox(im, height=height, width=width)
new_shape = [round(h * ratio), round(w * ratio)]
im_info['im_shape'] = np.array(new_shape, dtype=np.float32)
im_info['scale_factor'] = np.array([ratio, ratio], dtype=np.float32)
return im, im_info
def preprocess(im, preprocess_ops):
# process image by preprocess_ops
im_info = {
'scale_factor': np.array(
[1., 1.], dtype=np.float32),
'im_shape': None,
}
im, im_info = decode_image(im, im_info)
for operator in preprocess_ops:
im, im_info = operator(im, im_info)
return im, im_info
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
import os
import ast
import argparse
def argsparser():
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--model_dir",
type=str,
default=None,
help=("Directory include:'model.pdiparams', 'model.pdmodel', "
"'infer_cfg.yml', created by tools/export_model.py."),
required=True)
parser.add_argument(
"--image_file", type=str, default=None, help="Path of image file.")
parser.add_argument(
"--image_dir",
type=str,
default=None,
help="Dir of image file, `image_file` has a higher priority.")
parser.add_argument(
"--batch_size", type=int, default=1, help="batch_size for inference.")
parser.add_argument(
"--video_file",
type=str,
default=None,
help="Path of video file, `video_file` or `camera_id` has a highest priority."
)
parser.add_argument(
"--camera_id",
type=int,
default=-1,
help="device id of camera to predict.")
parser.add_argument(
"--threshold", type=float, default=0.5, help="Threshold of score.")
parser.add_argument(
"--output_dir",
type=str,
default="output",
help="Directory of output visualization files.")
parser.add_argument(
"--run_mode",
type=str,
default='fluid',
help="mode of running(fluid/trt_fp32/trt_fp16/trt_int8)")
parser.add_argument(
"--device",
type=str,
default='cpu',
help="Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU."
)
parser.add_argument(
"--run_benchmark",
type=ast.literal_eval,
default=False,
help="Whether to predict a image_file repeatedly for benchmark")
parser.add_argument(
"--enable_mkldnn",
type=ast.literal_eval,
default=False,
help="Whether use mkldnn with CPU.")
parser.add_argument(
"--cpu_threads", type=int, default=1, help="Num of threads with CPU.")
parser.add_argument(
"--trt_min_shape", type=int, default=1, help="min_shape for TensorRT.")
parser.add_argument(
"--trt_max_shape",
type=int,
default=1280,
help="max_shape for TensorRT.")
parser.add_argument(
"--trt_opt_shape",
type=int,
default=640,
help="opt_shape for TensorRT.")
parser.add_argument(
"--trt_calib_mode",
type=bool,
default=False,
help="If the model is produced by TRT offline quantitative "
"calibration, trt_calib_mode need to set True.")
parser.add_argument(
'--save_images',
action='store_true',
help='Save visualization image results.')
parser.add_argument(
'--save_mot_txts',
action='store_true',
help='Save tracking results (txt).')
parser.add_argument(
'--scaled',
type=bool,
default=False,
help="Whether coords after detector outputs are scaled, False in JDE YOLOv3 "
"True in general detector.")
parser.add_argument(
"--reid_model_dir",
type=str,
default=None,
help=("Directory include:'model.pdiparams', 'model.pdmodel', "
"'infer_cfg.yml', created by tools/export_model.py."))
parser.add_argument(
"--reid_batch_size",
type=int,
default=50,
help="max batch_size for reid model inference.")
return parser
class Times(object):
def __init__(self):
self.time = 0.
# start time
self.st = 0.
# end time
self.et = 0.
def start(self):
self.st = time.time()
def end(self, repeats=1, accumulative=True):
self.et = time.time()
if accumulative:
self.time += (self.et - self.st) / repeats
else:
self.time = (self.et - self.st) / repeats
def reset(self):
self.time = 0.
self.st = 0.
self.et = 0.
def value(self):
return round(self.time, 4)
class Timer(Times):
def __init__(self):
super(Timer, self).__init__()
self.preprocess_time_s = Times()
self.inference_time_s = Times()
self.postprocess_time_s = Times()
self.img_num = 0
def info(self, average=False):
total_time = self.preprocess_time_s.value(
) + self.inference_time_s.value() + self.postprocess_time_s.value()
total_time = round(total_time, 4)
print("------------------ Inference Time Info ----------------------")
print("total_time(ms): {}, img_num: {}".format(total_time * 1000,
self.img_num))
preprocess_time = round(
self.preprocess_time_s.value() / max(1, self.img_num),
4) if average else self.preprocess_time_s.value()
postprocess_time = round(
self.postprocess_time_s.value() / max(1, self.img_num),
4) if average else self.postprocess_time_s.value()
inference_time = round(self.inference_time_s.value() /
max(1, self.img_num),
4) if average else self.inference_time_s.value()
average_latency = total_time / max(1, self.img_num)
qps = 0
if total_time > 0:
qps = 1 / average_latency
print("average latency time(ms): {:.2f}, QPS: {:2f}".format(
average_latency * 1000, qps))
print(
"preprocess_time(ms): {:.2f}, inference_time(ms): {:.2f}, postprocess_time(ms): {:.2f}".
format(preprocess_time * 1000, inference_time * 1000,
postprocess_time * 1000))
def report(self, average=False):
dic = {}
dic['preprocess_time_s'] = round(
self.preprocess_time_s.value() / max(1, self.img_num),
4) if average else self.preprocess_time_s.value()
dic['postprocess_time_s'] = round(
self.postprocess_time_s.value() / max(1, self.img_num),
4) if average else self.postprocess_time_s.value()
dic['inference_time_s'] = round(
self.inference_time_s.value() / max(1, self.img_num),
4) if average else self.inference_time_s.value()
dic['img_num'] = self.img_num
total_time = self.preprocess_time_s.value(
) + self.inference_time_s.value() + self.postprocess_time_s.value()
dic['total_time_s'] = round(total_time, 4)
return dic
def get_current_memory_mb():
"""
It is used to Obtain the memory usage of the CPU and GPU during the running of the program.
And this function Current program is time-consuming.
"""
import pynvml
import psutil
import GPUtil
gpu_id = int(os.environ.get('CUDA_VISIBLE_DEVICES', 0))
pid = os.getpid()
p = psutil.Process(pid)
info = p.memory_full_info()
cpu_mem = info.uss / 1024. / 1024.
gpu_mem = 0
gpu_percent = 0
gpus = GPUtil.getGPUs()
if gpu_id is not None and len(gpus) > 0:
gpu_percent = gpus[gpu_id].load
pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(0)
meminfo = pynvml.nvmlDeviceGetMemoryInfo(handle)
gpu_mem = meminfo.used / 1024. / 1024.
return round(cpu_mem, 4), round(gpu_mem, 4), round(gpu_percent, 4)
# coding: utf-8
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
import os
import cv2
import numpy as np
from PIL import Image, ImageDraw
import math
def visualize_box_mask(im, results, labels, threshold=0.5):
"""
Args:
im (str/np.ndarray): path of image/np.ndarray read by cv2
results (dict): include 'boxes': np.ndarray: shape:[N,6], N: number of box,
matix element:[class, score, x_min, y_min, x_max, y_max]
labels (list): labels:['class1', ..., 'classn']
threshold (float): Threshold of score.
Returns:
im (PIL.Image.Image): visualized image
"""
if isinstance(im, str):
im = Image.open(im).convert('RGB')
else:
im = Image.fromarray(im)
if 'boxes' in results and len(results['boxes']) > 0:
im = draw_box(im, results['boxes'], labels, threshold=threshold)
return im
def get_color_map_list(num_classes):
"""
Args:
num_classes (int): number of class
Returns:
color_map (list): RGB color list
"""
color_map = num_classes * [0, 0, 0]
for i in range(0, num_classes):
j = 0
lab = i
while lab:
color_map[i * 3] |= (((lab >> 0) & 1) << (7 - j))
color_map[i * 3 + 1] |= (((lab >> 1) & 1) << (7 - j))
color_map[i * 3 + 2] |= (((lab >> 2) & 1) << (7 - j))
j += 1
lab >>= 3
color_map = [color_map[i:i + 3] for i in range(0, len(color_map), 3)]
return color_map
def draw_box(im, np_boxes, labels, threshold=0.5):
"""
Args:
im (PIL.Image.Image): PIL image
np_boxes (np.ndarray): shape:[N,6], N: number of box,
matix element:[class, score, x_min, y_min, x_max, y_max]
labels (list): labels:['class1', ..., 'classn']
threshold (float): threshold of box
Returns:
im (PIL.Image.Image): visualized image
"""
draw_thickness = min(im.size) // 320
draw = ImageDraw.Draw(im)
clsid2color = {}
color_list = get_color_map_list(len(labels))
expect_boxes = (np_boxes[:, 1] > threshold) & (np_boxes[:, 0] > -1)
np_boxes = np_boxes[expect_boxes, :]
for dt in np_boxes:
clsid, bbox, score = int(dt[0]), dt[2:], dt[1]
if clsid not in clsid2color:
clsid2color[clsid] = color_list[clsid]
color = tuple(clsid2color[clsid])
if len(bbox) == 4:
xmin, ymin, xmax, ymax = bbox
print('class_id:{:d}, confidence:{:.4f}, left_top:[{:.2f},{:.2f}],'
'right_bottom:[{:.2f},{:.2f}]'.format(
int(clsid), score, xmin, ymin, xmax, ymax))
# draw bbox
draw.line(
[(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin),
(xmin, ymin)],
width=draw_thickness,
fill=color)
elif len(bbox) == 8:
x1, y1, x2, y2, x3, y3, x4, y4 = bbox
draw.line(
[(x1, y1), (x2, y2), (x3, y3), (x4, y4), (x1, y1)],
width=2,
fill=color)
xmin = min(x1, x2, x3, x4)
ymin = min(y1, y2, y3, y4)
# draw label
text = "{} {:.4f}".format(labels[clsid], score)
tw, th = draw.textsize(text)
draw.rectangle(
[(xmin + 1, ymin - th), (xmin + tw + 1, ymin)], fill=color)
draw.text((xmin + 1, ymin - th), text, fill=(255, 255, 255))
return im
def get_color(idx):
idx = idx * 3
color = ((37 * idx) % 255, (17 * idx) % 255, (29 * idx) % 255)
return color
def plot_tracking(image,
tlwhs,
obj_ids,
scores=None,
frame_id=0,
fps=0.,
ids2names=[]):
im = np.ascontiguousarray(np.copy(image))
im_h, im_w = im.shape[:2]
top_view = np.zeros([im_w, im_w, 3], dtype=np.uint8) + 255
text_scale = max(1, image.shape[1] / 1600.)
text_thickness = 2
line_thickness = max(1, int(image.shape[1] / 500.))
radius = max(5, int(im_w / 140.))
cv2.putText(
im,
'frame: %d fps: %.2f num: %d' % (frame_id, fps, len(tlwhs)),
(0, int(15 * text_scale)),
cv2.FONT_HERSHEY_PLAIN,
text_scale, (0, 0, 255),
thickness=2)
for i, tlwh in enumerate(tlwhs):
x1, y1, w, h = tlwh
intbox = tuple(map(int, (x1, y1, x1 + w, y1 + h)))
obj_id = int(obj_ids[i])
id_text = '{}'.format(int(obj_id))
if ids2names != []:
assert len(ids2names) == 1, "plot_tracking only supports single classes."
id_text = '{}_'.format(ids2names[0]) + id_text
_line_thickness = 1 if obj_id <= 0 else line_thickness
color = get_color(abs(obj_id))
cv2.rectangle(
im, intbox[0:2], intbox[2:4], color=color, thickness=line_thickness)
cv2.putText(
im,
id_text, (intbox[0], intbox[1] - 10),
cv2.FONT_HERSHEY_PLAIN,
text_scale, (0, 0, 255),
thickness=text_thickness)
if scores is not None:
text = '{:.2f}'.format(float(scores[i]))
cv2.putText(
im,
text, (intbox[0], intbox[1] + 10),
cv2.FONT_HERSHEY_PLAIN,
text_scale, (0, 255, 255),
thickness=text_thickness)
return im
def plot_tracking_dict(image,
num_classes,
tlwhs_dict,
obj_ids_dict,
scores_dict,
frame_id=0,
fps=0.,
ids2names=[]):
im = np.ascontiguousarray(np.copy(image))
im_h, im_w = im.shape[:2]
top_view = np.zeros([im_w, im_w, 3], dtype=np.uint8) + 255
text_scale = max(1, image.shape[1] / 1600.)
text_thickness = 2
line_thickness = max(1, int(image.shape[1] / 500.))
radius = max(5, int(im_w / 140.))
for cls_id in range(num_classes):
tlwhs = tlwhs_dict[cls_id]
obj_ids = obj_ids_dict[cls_id]
scores = scores_dict[cls_id]
cv2.putText(
im,
'frame: %d fps: %.2f num: %d' % (frame_id, fps, len(tlwhs)),
(0, int(15 * text_scale)),
cv2.FONT_HERSHEY_PLAIN,
text_scale, (0, 0, 255),
thickness=2)
for i, tlwh in enumerate(tlwhs):
x1, y1, w, h = tlwh
intbox = tuple(map(int, (x1, y1, x1 + w, y1 + h)))
obj_id = int(obj_ids[i])
id_text = '{}'.format(int(obj_id))
if ids2names != []:
id_text = '{}_{}'.format(ids2names[cls_id], id_text)
else:
id_text = 'class{}_{}'.format(cls_id, id_text)
_line_thickness = 1 if obj_id <= 0 else line_thickness
color = get_color(abs(obj_id))
cv2.rectangle(
im,
intbox[0:2],
intbox[2:4],
color=color,
thickness=line_thickness)
cv2.putText(
im,
id_text, (intbox[0], intbox[1] - 10),
cv2.FONT_HERSHEY_PLAIN,
text_scale, (0, 0, 255),
thickness=text_thickness)
if scores is not None:
text = '{:.2f}'.format(float(scores[i]))
cv2.putText(
im,
text, (intbox[0], intbox[1] + 10),
cv2.FONT_HERSHEY_PLAIN,
text_scale, (0, 255, 255),
thickness=text_thickness)
return im
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册