Unverified · Commit 822c7933 · authored by Feng Ni · committed by GitHub

[cherry-pick][MOT] fix deploy python mot infer (#5630)

* fix deploy python mot infer, fix cfgs

* fix doc, test=document_fix
Parent 20b2f101
......@@ -53,4 +53,4 @@ JDETracker:
det_thresh: 0.3
track_buffer: 30
min_box_area: 200
-motion: KalmanFilter
+vertical_ratio: 1.6 # for pedestrian
......@@ -11,8 +11,8 @@ JDETracker:
conf_thres: 0.6
low_conf_thres: 0.1
match_thres: 0.9
-min_box_area: 100
-vertical_ratio: 1.6 # for pedestrian
+min_box_area: 0
+vertical_ratio: 0 # 1.6 for pedestrian
DeepSORTTracker:
input_size: [64, 192]
......
......@@ -38,7 +38,7 @@ class JDETracker(object):
track_buffer (int): buffer for tracker
min_box_area (int): min box area to filter out low quality boxes
vertical_ratio (float): w/h, the vertical ratio of the bbox to filter
-bad results. If set <0 means no need to filter bboxes,usually set
+bad results. If set <= 0 means no need to filter bboxes, usually set
1.6 for pedestrian tracking.
tracked_thresh (float): linear assignment threshold of tracked
stracks and detections
......@@ -64,8 +64,8 @@ class JDETracker(object):
num_classes=1,
det_thresh=0.3,
track_buffer=30,
-min_box_area=200,
-vertical_ratio=1.6,
+min_box_area=0,
+vertical_ratio=0,
tracked_thresh=0.7,
r_tracked_thresh=0.5,
unconfirmed_thresh=0.7,
......@@ -161,9 +161,8 @@ class JDETracker(object):
detections = [
STrack(
STrack.tlbr_to_tlwh(tlbrs[2:6]), tlbrs[1], cls_id,
-30, temp_feat)
-for (tlbrs, temp_feat
-) in zip(pred_dets_cls, pred_embs_cls)
+30, temp_feat) for (tlbrs, temp_feat) in
+zip(pred_dets_cls, pred_embs_cls)
]
else:
detections = []
......@@ -238,15 +237,13 @@ class JDETracker(object):
for tlbrs in pred_dets_cls_second
]
else:
-pred_embs_cls_second = pred_embs_dict[cls_id][inds_second]
+pred_embs_cls_second = pred_embs_dict[cls_id][
+inds_second]
detections_second = [
STrack(
-STrack.tlbr_to_tlwh(tlbrs[2:6]),
-tlbrs[1],
-cls_id,
-30,
-temp_feat)
-for (tlbrs, temp_feat) in zip(pred_dets_cls_second, pred_embs_cls_second)
+STrack.tlbr_to_tlwh(tlbrs[2:6]), tlbrs[1],
+cls_id, 30, temp_feat) for (tlbrs, temp_feat) in
+zip(pred_dets_cls_second, pred_embs_cls_second)
]
else:
detections_second = []
......
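As a reading aid for the two relaxed defaults above, this is roughly the filtering that `min_box_area` and `vertical_ratio` control inside the tracker's update loop (a minimal standalone sketch, not the tracker code itself; `tlwh` is the usual top-left x, top-left y, width, height box):

```python
# Hypothetical helper mirroring the min_box_area / vertical_ratio filter.
def keep_box(tlwh, min_box_area=0, vertical_ratio=0):
    x, y, w, h = tlwh
    if min_box_area > 0 and w * h <= min_box_area:
        return False  # drop tiny, low-quality boxes
    if vertical_ratio > 0 and w / h > vertical_ratio:
        return False  # drop boxes too wide to be a standing pedestrian
    return True

# With the new defaults (0, 0) nothing is filtered; the old pedestrian
# presets were min_box_area=200 and vertical_ratio=1.6.
```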
......@@ -112,8 +112,8 @@ class JDE_Detector(Detector):
# tracker config
assert self.pred_config.tracker, "The exported JDE Detector model should have tracker."
cfg = self.pred_config.tracker
-min_box_area = cfg.get('min_box_area', 200)
-vertical_ratio = cfg.get('vertical_ratio', 1.6)
+min_box_area = cfg.get('min_box_area', 0.0)
+vertical_ratio = cfg.get('vertical_ratio', 0.0)
conf_thres = cfg.get('conf_thres', 0.0)
tracked_thresh = cfg.get('tracked_thresh', 0.7)
metric_type = cfg.get('metric_type', 'euclidean')
......@@ -164,7 +164,7 @@ class JDE_Detector(Detector):
repeats (int): repeats number for prediction
Returns:
result (dict): include 'pred_dets': np.ndarray: shape:[N,6], N: number of box,
-matrix element: [x_min, y_min, x_max, y_max, score, class]
+matrix element: [class, score, x_min, y_min, x_max, y_max]
FairMOT(JDE)'s result include 'pred_embs': np.ndarray:
shape: [N, 128]
'''
......
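Given that documented layout, the tracker input can be assembled like this (a hedged sketch: the array values are made up, and `JDETracker` is the class imported from pptracking elsewhere in this commit):

```python
import numpy as np
from pptracking.python.mot import JDETracker

# One detection per row: [class, score, x_min, y_min, x_max, y_max]
pred_dets = np.array([[0, 0.92, 10., 20., 60., 180.]], dtype=np.float32)
# FairMOT(JDE) ReID embeddings, shape [N, 128]
pred_embs = np.random.rand(1, 128).astype(np.float32)

tracker = JDETracker(num_classes=1)
online_targets_dict = tracker.update(pred_dets, pred_embs)  # dict keyed by class id
```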
......@@ -165,8 +165,8 @@ class SDE_Detector(Detector):
# use ByteTracker
use_byte = cfg.get('use_byte', False)
det_thresh = cfg.get('det_thresh', 0.3)
-min_box_area = cfg.get('min_box_area', 200)
-vertical_ratio = cfg.get('vertical_ratio', 1.6)
+min_box_area = cfg.get('min_box_area', 0)
+vertical_ratio = cfg.get('vertical_ratio', 0)
match_thres = cfg.get('match_thres', 0.9)
conf_thres = cfg.get('conf_thres', 0.6)
low_conf_thres = cfg.get('low_conf_thres', 0.1)
......@@ -194,7 +194,7 @@ class SDE_Detector(Detector):
return result
def reidprocess(self, det_results, repeats=1):
-pred_dets = det_results['boxes']
+pred_dets = det_results['boxes'] # cls_id, score, x0, y0, x1, y1
pred_xyxys = pred_dets[:, 2:6]
ori_image = det_results['ori_image']
......@@ -234,7 +234,7 @@ class SDE_Detector(Detector):
return det_results
def tracking(self, det_results):
-pred_dets = det_results['boxes']
+pred_dets = det_results['boxes'] # cls_id, score, x0, y0, x1, y1
pred_embs = det_results.get('embeddings', None)
if self.use_deepsort_tracker:
......
......@@ -11,8 +11,8 @@ JDETracker:
conf_thres: 0.6
low_conf_thres: 0.1
match_thres: 0.9
-min_box_area: 100
-vertical_ratio: 1.6 # for pedestrian
+min_box_area: 0
+vertical_ratio: 0 # 1.6 for pedestrian
DeepSORTTracker:
input_size: [64, 192]
......
......@@ -3,27 +3,76 @@
In PaddlePaddle, the inference engine and the training engine are optimized differently under the hood. The inference engine uses AnalysisPredictor, which is purpose-built for inference: it is the Python interface to the [C++ inference library](https://www.paddlepaddle.org.cn/documentation/docs/zh/advanced_guide/inference_deployment/inference/native_infer.html) and applies a series of graph optimizations to the model, removing unnecessary memory copies. For users with high performance requirements when deploying trained models, we provide inference scripts that are independent of PaddleDetection and easy to integrate.
-It mainly involves two steps:
+Python-side inference deployment mainly involves two steps:
- Export the inference model
- Run inference with Python
## 1. Export the inference model
-During training, PaddleDetection keeps both the forward-pass parameters of the network and the optimizer state, while deployment only needs the forward parameters. For details, see [Export model](https://github.com/PaddlePaddle/PaddleDetection/blob/develop/deploy/EXPORT_MODEL.md)
+During training, PaddleDetection keeps both the forward-pass parameters of the network and the optimizer state, while deployment only needs the forward parameters. For details, see [Export model](../deploy/EXPORT_MODEL.md). For example:
```bash
# Export a YOLOv3 detection model
python tools/export_model.py -c configs/yolov3/yolov3_darknet53_270e_coco.yml --output_dir=./inference_model \
 -o weights=https://paddledet.bj.bcebos.com/models/yolov3_darknet53_270e_coco.pdparams
# Export a HigherHRNet (bottom-up) keypoint detection model
python tools/export_model.py -c configs/keypoint/higherhrnet/higherhrnet_hrnet_w32_512.yml -o weights=https://paddledet.bj.bcebos.com/models/keypoint/higherhrnet_hrnet_w32_512.pdparams
# Export an HRNet (top-down) keypoint detection model
python tools/export_model.py -c configs/keypoint/hrnet/hrnet_w32_384x288.yml -o weights=https://paddledet.bj.bcebos.com/models/keypoint/hrnet_w32_384x288.pdparams
# Export a FairMOT multi-object tracking model
python tools/export_model.py -c configs/mot/fairmot/fairmot_dla34_30e_1088x608.yml -o weights=https://paddledet.bj.bcebos.com/models/mot/fairmot_dla34_30e_1088x608.pdparams
# Export a ByteTrack multi-object tracking model (equivalent to exporting only the detector)
python tools/export_model.py -c configs/mot/bytetrack/detector/ppyoloe_crn_l_36e_640x640_mot17half.yml -o weights=https://paddledet.bj.bcebos.com/models/mot/ppyoloe_crn_l_36e_640x640_mot17half.pdparams
```
After export, the output directory contains four files: `infer_cfg.yml`, `model.pdiparams`, `model.pdiparams.info`, and `model.pdmodel`.
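For instance, the YOLOv3 export command above should produce a layout like this (the directory name comes from the config file name; verify against your own `--output_dir`):

```
inference_model/yolov3_darknet53_270e_coco/
├── infer_cfg.yml          # preprocessing, labels and tracker/arch settings
├── model.pdiparams        # forward-pass weights
├── model.pdiparams.info   # parameter metadata
└── model.pdmodel          # serialized inference graph
```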
## 2. Run inference with Python
### 2.1 General detection
Run the following command in a terminal:
```bash
python deploy/python/infer.py --model_dir=./output_inference/yolov3_darknet53_270e_coco --image_file=./demo/000000014439.jpg --device=GPU
```
### 2.2 Keypoint detection
Run the following command in a terminal:
```bash
# Standalone inference for top-down (HRNet) / bottom-up (HigherHRNet) keypoint models; in this mode the top-down HRNet model only supports single-person cropped images
python deploy/python/keypoint_infer.py --model_dir=output_inference/hrnet_w32_384x288/ --image_file=./demo/hrnet_demo.jpg --device=GPU --threshold=0.5
python deploy/python/keypoint_infer.py --model_dir=output_inference/higherhrnet_hrnet_w32_512/ --image_file=./demo/000000014439_640x640.jpg --device=GPU --threshold=0.5
# Joint deployment of a detector plus a top-down keypoint model (joint inference only supports top-down keypoint models)
python deploy/python/det_keypoint_unite_infer.py --det_model_dir=output_inference/yolov3_darknet53_270e_coco/ --keypoint_model_dir=output_inference/hrnet_w32_384x288/ --video_file={your video name}.mp4 --device=GPU
```
**Notes:**
- For details on exporting and running keypoint detection models, see [keypoint](../../configs/keypoint/README.md); usage specifics are documented per model;
- The keypoint deployment in this directory covers basic forward inference only; for more keypoint features, use the PP-Human project, see [pphuman](../pphuman/README.md)
### 2.3 Multi-object tracking
Run the following command in a terminal:
```bash
-python deploy/python/infer.py --model_dir=./output_inference/yolov3_mobilenet_v1_roadsign --image_file=./demo/road554.png --device=GPU
# FairMOT tracking
python deploy/python/mot_jde_infer.py --model_dir=output_inference/fairmot_dla34_30e_1088x608 --video_file={your video name}.mp4 --device=GPU
# ByteTrack tracking
python deploy/python/mot_sde_infer.py --model_dir=output_inference/ppyoloe_crn_l_36e_640x640_mot17half/ --tracker_config=deploy/python/tracker_config.yml --video_file={your video name}.mp4 --device=GPU --scaled=True
# FairMOT tracking combined with HRNet keypoint detection (joint inference only supports top-down keypoint models)
python deploy/python/mot_keypoint_unite_infer.py --mot_model_dir=output_inference/fairmot_dla34_30e_1088x608/ --keypoint_model_dir=output_inference/hrnet_w32_384x288/ --video_file={your video name}.mp4 --device=GPU
```
**Notes:**
- For details on exporting and running multi-object tracking models, see [mot](../../configs/mot/README.md); usage specifics are documented per model;
- The tracking deployment in this directory covers basic forward inference plus joint keypoint deployment; for more tracking features, use the PP-Human project, see [pphuman](../pphuman/README.md), or the PP-Tracking project (trajectory drawing, entrance/exit flow counting), see [pptracking](../pptracking/README.md)
The parameters are described below:
| Parameter | Required | Description |
......
......@@ -32,7 +32,7 @@ sys.path.insert(0, parent_path)
from pptracking.python.mot import JDETracker
from pptracking.python.mot.utils import MOTTimer, write_mot_results
-from pptracking.python.visualize import plot_tracking, plot_tracking_dict
+from pptracking.python.mot.visualize import plot_tracking_dict
# Global dictionary
MOT_JDE_SUPPORT_MODELS = {
......@@ -55,9 +55,14 @@ class JDE_Detector(Detector):
calibration, trt_calib_mode need to set True
cpu_threads (int): cpu threads
enable_mkldnn (bool): whether to open MKLDNN
output_dir (string): The path of output, default as 'output'
threshold (float): Score threshold of the detected bbox, default as 0.5
save_images (bool): Whether to save visualization image results, default as False
save_mot_txts (bool): Whether to save tracking results (txt), default as False
"""
-def __init__(self,
+def __init__(
+self,
model_dir,
tracker_config=None,
device='CPU',
......@@ -70,7 +75,9 @@ class JDE_Detector(Detector):
cpu_threads=1,
enable_mkldnn=False,
output_dir='output',
-threshold=0.5):
+threshold=0.5,
+save_images=False,
+save_mot_txts=False, ):
super(JDE_Detector, self).__init__(
model_dir=model_dir,
device=device,
......@@ -84,6 +91,8 @@ class JDE_Detector(Detector):
enable_mkldnn=enable_mkldnn,
output_dir=output_dir,
threshold=threshold, )
self.save_images = save_images
self.save_mot_txts = save_mot_txts
assert batch_size == 1, "MOT model only supports batch_size=1."
self.det_times = Timer(with_tracker=True)
self.num_classes = len(self.pred_config.labels)
......@@ -91,8 +100,8 @@ class JDE_Detector(Detector):
# tracker config
assert self.pred_config.tracker, "The exported JDE Detector model should have tracker."
cfg = self.pred_config.tracker
-min_box_area = cfg.get('min_box_area', 200)
-vertical_ratio = cfg.get('vertical_ratio', 1.6)
+min_box_area = cfg.get('min_box_area', 0.0)
+vertical_ratio = cfg.get('vertical_ratio', 0.0)
conf_thres = cfg.get('conf_thres', 0.0)
tracked_thresh = cfg.get('tracked_thresh', 0.7)
metric_type = cfg.get('metric_type', 'euclidean')
......@@ -115,7 +124,7 @@ class JDE_Detector(Detector):
return result
def tracking(self, det_results):
-pred_dets = det_results['pred_dets'] # 'cls_id, score, x0, y0, x1, y1'
+pred_dets = det_results['pred_dets'] # cls_id, score, x0, y0, x1, y1
pred_embs = det_results['pred_embs']
online_targets_dict = self.tracker.update(pred_dets, pred_embs)
......@@ -164,7 +173,8 @@ class JDE_Detector(Detector):
image_list,
run_benchmark=False,
repeats=1,
-visual=True):
+visual=True,
+seq_name=None):
mot_results = []
num_classes = self.num_classes
image_list.sort()
......@@ -225,7 +235,7 @@ class JDE_Detector(Detector):
self.det_times.img_num += 1
if visual:
-if frame_id % 10 == 0:
+if len(image_list) > 1 and frame_id % 10 == 0:
print('Tracking frame {}'.format(frame_id))
frame, _ = decode_image(img_file, {})
......@@ -237,6 +247,7 @@ class JDE_Detector(Detector):
online_scores,
frame_id=frame_id,
ids2names=ids2names)
if seq_name is None:
seq_name = image_list[0].split('/')[-2]
save_dir = os.path.join(self.output_dir, seq_name)
if not os.path.exists(save_dir):
......@@ -264,7 +275,8 @@ class JDE_Detector(Detector):
if not os.path.exists(self.output_dir):
os.makedirs(self.output_dir)
out_path = os.path.join(self.output_dir, video_out_name)
-fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+video_format = 'mp4v'
+fourcc = cv2.VideoWriter_fourcc(*video_format)
writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
frame_id = 1
......@@ -282,7 +294,9 @@ class JDE_Detector(Detector):
frame_id += 1
timer.tic()
-mot_results = self.predict_image([frame], visual=False)
+seq_name = video_out_name.split('.')[0]
+mot_results = self.predict_image(
+[frame], visual=False, seq_name=seq_name)
timer.toc()
online_tlwhs, online_scores, online_ids = mot_results[0]
......@@ -307,20 +321,33 @@ class JDE_Detector(Detector):
cv2.imshow('Mask Detection', im)
if cv2.waitKey(1) & 0xFF == ord('q'):
break
if self.save_mot_txts:
result_filename = os.path.join(
self.output_dir, video_out_name.split('.')[-2] + '.txt')
write_mot_results(result_filename, results, data_type, num_classes)
writer.release()
def main():
detector = JDE_Detector(
FLAGS.model_dir,
tracker_config=None,
device=FLAGS.device,
run_mode=FLAGS.run_mode,
batch_size=1,
trt_min_shape=FLAGS.trt_min_shape,
trt_max_shape=FLAGS.trt_max_shape,
trt_opt_shape=FLAGS.trt_opt_shape,
trt_calib_mode=FLAGS.trt_calib_mode,
cpu_threads=FLAGS.cpu_threads,
-enable_mkldnn=FLAGS.enable_mkldnn)
+enable_mkldnn=FLAGS.enable_mkldnn,
+output_dir=FLAGS.output_dir,
+threshold=FLAGS.threshold,
+save_images=FLAGS.save_images,
+save_mot_txts=FLAGS.save_mot_txts)
# predict from video file or camera video stream
if FLAGS.video_file is not None or FLAGS.camera_id != -1:
......
......@@ -24,7 +24,7 @@ from collections import defaultdict
from mot_keypoint_unite_utils import argsparser
from preprocess import decode_image
-from infer import print_arguments, get_test_images
+from infer import print_arguments, get_test_images, bench_log
from mot_sde_infer import SDE_Detector
from mot_jde_infer import JDE_Detector, MOT_JDE_SUPPORT_MODELS
from keypoint_infer import KeyPointDetector, KEYPOINT_SUPPORT_MODELS
......@@ -39,7 +39,7 @@ import sys
parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2)))
sys.path.insert(0, parent_path)
-from pptracking.python.visualize import plot_tracking, plot_tracking_dict
+from pptracking.python.mot.visualize import plot_tracking, plot_tracking_dict
from pptracking.python.mot.utils import MOTTimer as FPSTimer
......@@ -92,7 +92,7 @@ def mot_topdown_unite_predict(mot_detector,
keypoint_res = predict_with_given_det(
image, results, topdown_keypoint_detector, keypoint_batch_size,
-FLAGS.mot_threshold, FLAGS.keypoint_threshold, FLAGS.run_benchmark)
+FLAGS.run_benchmark)
if save_res:
store_res.append([
......@@ -146,7 +146,7 @@ def mot_topdown_unite_predict_video(mot_detector,
if not os.path.exists(FLAGS.output_dir):
os.makedirs(FLAGS.output_dir)
out_path = os.path.join(FLAGS.output_dir, video_name)
-fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+fourcc = cv2.VideoWriter_fourcc(* 'mp4v')
writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
frame_id = 0
timer_mot, timer_kp, timer_mot_kp = FPSTimer(), FPSTimer(), FPSTimer()
......@@ -179,7 +179,7 @@ def mot_topdown_unite_predict_video(mot_detector,
timer_kp.tic()
keypoint_res = predict_with_given_det(
frame, results, topdown_keypoint_detector, keypoint_batch_size,
-FLAGS.mot_threshold, FLAGS.keypoint_threshold, FLAGS.run_benchmark)
+FLAGS.run_benchmark)
timer_kp.toc()
timer_mot_kp.toc()
......
-# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -23,15 +23,15 @@ import paddle
from benchmark_utils import PaddleInferBenchmark
from preprocess import decode_image
from utils import argsparser, Timer, get_current_memory_mb
-from infer import Detector, get_test_images, print_arguments, bench_log, PredictConfig
+from infer import Detector, get_test_images, print_arguments, bench_log, PredictConfig, load_predictor
# add python path
import sys
parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2)))
sys.path.insert(0, parent_path)
-from pptracking.python.mot import JDETracker
-from pptracking.python.mot.utils import MOTTimer, write_mot_results
+from pptracking.python.mot import JDETracker, DeepSORTTracker
+from pptracking.python.mot.utils import MOTTimer, write_mot_results, get_crops, clip_box
+from pptracking.python.mot.visualize import plot_tracking, plot_tracking_dict
......@@ -50,7 +50,11 @@ class SDE_Detector(Detector):
calibration, trt_calib_mode need to set True
cpu_threads (int): cpu threads
enable_mkldnn (bool): whether to open MKLDNN
-use_dark(bool): whether to use postprocess in DarkPose
+output_dir (string): The path of output, default as 'output'
+threshold (float): Score threshold of the detected bbox, default as 0.5
+save_images (bool): Whether to save visualization image results, default as False
+save_mot_txts (bool): Whether to save tracking results (txt), default as False
+reid_model_dir (str): reid model dir, default None for ByteTrack, but set for DeepSORT
"""
def __init__(self,
......@@ -66,7 +70,10 @@ class SDE_Detector(Detector):
cpu_threads=1,
enable_mkldnn=False,
output_dir='output',
-threshold=0.5):
+threshold=0.5,
+save_images=False,
+save_mot_txts=False,
+reid_model_dir=None):
super(SDE_Detector, self).__init__(
model_dir=model_dir,
device=device,
......@@ -80,37 +87,163 @@ class SDE_Detector(Detector):
enable_mkldnn=enable_mkldnn,
output_dir=output_dir,
threshold=threshold, )
self.save_images = save_images
self.save_mot_txts = save_mot_txts
assert batch_size == 1, "MOT model only supports batch_size=1."
self.det_times = Timer(with_tracker=True)
self.num_classes = len(self.pred_config.labels)
-# tracker config
+# reid config
self.use_reid = False if reid_model_dir is None else True
if self.use_reid:
self.reid_pred_config = self.set_config(reid_model_dir)
self.reid_predictor, self.config = load_predictor(
reid_model_dir,
run_mode=run_mode,
batch_size=50, # reid_batch_size
min_subgraph_size=self.reid_pred_config.min_subgraph_size,
device=device,
use_dynamic_shape=self.reid_pred_config.use_dynamic_shape,
trt_min_shape=trt_min_shape,
trt_max_shape=trt_max_shape,
trt_opt_shape=trt_opt_shape,
trt_calib_mode=trt_calib_mode,
cpu_threads=cpu_threads,
enable_mkldnn=enable_mkldnn)
else:
self.reid_pred_config = None
self.reid_predictor = None
assert tracker_config is not None, 'Note that tracker_config should be set.'
self.tracker_config = tracker_config
-cfg = yaml.safe_load(open(self.tracker_config))['tracker']
-min_box_area = cfg.get('min_box_area', 200)
-vertical_ratio = cfg.get('vertical_ratio', 1.6)
-use_byte = cfg.get('use_byte', True)
+tracker_cfg = yaml.safe_load(open(self.tracker_config))
+cfg = tracker_cfg[tracker_cfg['type']]
+# tracker config
+self.use_deepsort_tracker = True if tracker_cfg[
+'type'] == 'DeepSORTTracker' else False
if self.use_deepsort_tracker:
# use DeepSORTTracker
if self.reid_pred_config is not None and hasattr(
self.reid_pred_config, 'tracker'):
cfg = self.reid_pred_config.tracker
budget = cfg.get('budget', 100)
max_age = cfg.get('max_age', 30)
max_iou_distance = cfg.get('max_iou_distance', 0.7)
matching_threshold = cfg.get('matching_threshold', 0.2)
min_box_area = cfg.get('min_box_area', 0)
vertical_ratio = cfg.get('vertical_ratio', 0)
self.tracker = DeepSORTTracker(
budget=budget,
max_age=max_age,
max_iou_distance=max_iou_distance,
matching_threshold=matching_threshold,
min_box_area=min_box_area,
vertical_ratio=vertical_ratio, )
else:
# use ByteTracker
use_byte = cfg.get('use_byte', False)
det_thresh = cfg.get('det_thresh', 0.3)
min_box_area = cfg.get('min_box_area', 0)
vertical_ratio = cfg.get('vertical_ratio', 0)
match_thres = cfg.get('match_thres', 0.9)
conf_thres = cfg.get('conf_thres', 0.6)
low_conf_thres = cfg.get('low_conf_thres', 0.1)
self.tracker = JDETracker(
use_byte=use_byte,
det_thresh=det_thresh,
num_classes=self.num_classes,
min_box_area=min_box_area,
vertical_ratio=vertical_ratio,
match_thres=match_thres,
conf_thres=conf_thres,
-low_conf_thres=low_conf_thres)
+low_conf_thres=low_conf_thres, )
def postprocess(self, inputs, result):
# postprocess output of predictor
np_boxes_num = result['boxes_num']
if np_boxes_num[0] <= 0:
print('[WARNING] No object detected.')
result = {'boxes': np.zeros([0, 6]), 'boxes_num': [0]}
result = {k: v for k, v in result.items() if v is not None}
return result
def reidprocess(self, det_results, repeats=1):
pred_dets = det_results['boxes']
pred_xyxys = pred_dets[:, 2:6]
ori_image = det_results['ori_image']
ori_image_shape = ori_image.shape[:2]
pred_xyxys, keep_idx = clip_box(pred_xyxys, ori_image_shape)
if len(keep_idx[0]) == 0:
det_results['boxes'] = np.zeros((1, 6), dtype=np.float32)
det_results['embeddings'] = None
return det_results
pred_dets = pred_dets[keep_idx[0]]
pred_xyxys = pred_dets[:, 2:6]
w, h = self.tracker.input_size
crops = get_crops(pred_xyxys, ori_image, w, h)
# to keep inference fast, use only the top-k crops
crops = crops[:50] # reid_batch_size
det_results['crops'] = np.array(crops).astype('float32')
det_results['boxes'] = pred_dets[:50]
input_names = self.reid_predictor.get_input_names()
for i in range(len(input_names)):
input_tensor = self.reid_predictor.get_input_handle(input_names[i])
input_tensor.copy_from_cpu(det_results[input_names[i]])
# model prediction
for i in range(repeats):
self.reid_predictor.run()
output_names = self.reid_predictor.get_output_names()
feature_tensor = self.reid_predictor.get_output_handle(output_names[
0])
pred_embs = feature_tensor.copy_to_cpu()
det_results['embeddings'] = pred_embs
return det_results
def tracking(self, det_results):
pred_dets = det_results['boxes'] # cls_id, score, x0, y0, x1, y1
-pred_embs = None
-online_targets_dict = self.tracker.update(pred_dets, pred_embs)
+pred_embs = det_results.get('embeddings', None)
if self.use_deepsort_tracker:
# use DeepSORTTracker, only supports single class
self.tracker.predict()
online_targets = self.tracker.update(pred_dets, pred_embs)
online_tlwhs, online_scores, online_ids = [], [], []
for t in online_targets:
if not t.is_confirmed() or t.time_since_update > 1:
continue
tlwh = t.to_tlwh()
tscore = t.score
tid = t.track_id
if self.tracker.vertical_ratio > 0 and tlwh[2] / tlwh[
3] > self.tracker.vertical_ratio:
continue
online_tlwhs.append(tlwh)
online_scores.append(tscore)
online_ids.append(tid)
tracking_outs = {
'online_tlwhs': online_tlwhs,
'online_scores': online_scores,
'online_ids': online_ids,
}
return tracking_outs
else:
# use ByteTracker, supports multiple classes
online_tlwhs = defaultdict(list)
online_scores = defaultdict(list)
online_ids = defaultdict(list)
online_targets_dict = self.tracker.update(pred_dets, pred_embs)
for cls_id in range(self.num_classes):
online_targets = online_targets_dict[cls_id]
for t in online_targets:
......@@ -126,19 +259,26 @@ class SDE_Detector(Detector):
online_ids[cls_id].append(tid)
online_scores[cls_id].append(tscore)
-return online_tlwhs, online_scores, online_ids
+tracking_outs = {
+'online_tlwhs': online_tlwhs,
+'online_scores': online_scores,
+'online_ids': online_ids,
+}
+return tracking_outs
def predict_image(self,
image_list,
run_benchmark=False,
repeats=1,
-visual=True):
-mot_results = []
+visual=True,
+seq_name=None):
num_classes = self.num_classes
image_list.sort()
ids2names = self.pred_config.labels
+mot_results = []
for frame_id, img_file in enumerate(image_list):
batch_image_list = [img_file] # bs=1 in MOT model
frame, _ = decode_image(img_file, {})
if run_benchmark:
# preprocess
inputs = self.preprocess(batch_image_list) # warmup
......@@ -159,10 +299,16 @@ class SDE_Detector(Detector):
self.det_times.postprocess_time_s.end()
# tracking
if self.use_reid:
det_result['frame_id'] = frame_id
det_result['seq_name'] = seq_name
det_result['ori_image'] = frame
det_result = self.reidprocess(det_result)
result_warmup = self.tracking(det_result)
self.det_times.tracking_time_s.start()
-online_tlwhs, online_scores, online_ids = self.tracking(
-det_result)
+if self.use_reid:
+det_result = self.reidprocess(det_result)
+tracking_outs = self.tracking(det_result)
self.det_times.tracking_time_s.end()
self.det_times.img_num += 1
......@@ -186,16 +332,26 @@ class SDE_Detector(Detector):
# tracking process
self.det_times.tracking_time_s.start()
-online_tlwhs, online_scores, online_ids = self.tracking(
-det_result)
+if self.use_reid:
+det_result['frame_id'] = frame_id
+det_result['seq_name'] = seq_name
+det_result['ori_image'] = frame
+det_result = self.reidprocess(det_result)
+tracking_outs = self.tracking(det_result)
self.det_times.tracking_time_s.end()
self.det_times.img_num += 1
online_tlwhs = tracking_outs['online_tlwhs']
online_scores = tracking_outs['online_scores']
online_ids = tracking_outs['online_ids']
+mot_results.append([online_tlwhs, online_scores, online_ids])
if visual:
-if frame_id % 10 == 0:
+if len(image_list) > 1 and frame_id % 10 == 0:
print('Tracking frame {}'.format(frame_id))
frame, _ = decode_image(img_file, {})
if isinstance(online_tlwhs, defaultdict):
im = plot_tracking_dict(
frame,
num_classes,
......@@ -204,14 +360,19 @@ class SDE_Detector(Detector):
online_scores,
frame_id=frame_id,
ids2names=[])
seq_name = image_list[0].split('/')[-2]
else:
im = plot_tracking(
frame,
online_tlwhs,
online_ids,
online_scores,
frame_id=frame_id)
save_dir = os.path.join(self.output_dir, seq_name)
if not os.path.exists(save_dir):
os.makedirs(save_dir)
cv2.imwrite(
os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)), im)
-mot_results.append([online_tlwhs, online_scores, online_ids])
return mot_results
def predict_video(self, video_file, camera_id):
......@@ -231,13 +392,17 @@ class SDE_Detector(Detector):
if not os.path.exists(self.output_dir):
os.makedirs(self.output_dir)
out_path = os.path.join(self.output_dir, video_out_name)
-fourcc = cv2.VideoWriter_fourcc(* 'mp4v')
+video_format = 'mp4v'
+fourcc = cv2.VideoWriter_fourcc(*video_format)
writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
frame_id = 1
timer = MOTTimer()
-results = defaultdict(list) # support single class and multi classes
+results = defaultdict(list)
num_classes = self.num_classes
data_type = 'mcmot' if num_classes > 1 else 'mot'
ids2names = self.pred_config.labels
while (1):
ret, frame = capture.read()
if not ret:
......@@ -247,16 +412,32 @@ class SDE_Detector(Detector):
frame_id += 1
timer.tic()
-mot_results = self.predict_image([frame], visual=False)
+seq_name = video_out_name.split('.')[0]
+mot_results = self.predict_image(
+[frame], visual=False, seq_name=seq_name)
timer.toc()
# bs=1 in MOT model
online_tlwhs, online_scores, online_ids = mot_results[0]
-for cls_id in range(num_classes):
-results[cls_id].append(
-(frame_id + 1, online_tlwhs[cls_id], online_scores[cls_id],
-online_ids[cls_id]))
fps = 1. / timer.duration
if self.use_deepsort_tracker:
# use DeepSORTTracker, only supports single class
results[0].append(
(frame_id + 1, online_tlwhs, online_scores, online_ids))
im = plot_tracking(
frame,
online_tlwhs,
online_ids,
online_scores,
frame_id=frame_id,
fps=fps)
else:
# use ByteTracker, supports multiple classes
for cls_id in range(num_classes):
results[cls_id].append(
(frame_id + 1, online_tlwhs[cls_id],
online_scores[cls_id], online_ids[cls_id]))
im = plot_tracking_dict(
frame,
num_classes,
......@@ -265,13 +446,19 @@ class SDE_Detector(Detector):
online_scores,
frame_id=frame_id,
fps=fps,
-ids2names=[])
+ids2names=ids2names)
writer.write(im)
if camera_id != -1:
cv2.imshow('Mask Detection', im)
if cv2.waitKey(1) & 0xFF == ord('q'):
break
if self.save_mot_txts:
result_filename = os.path.join(
self.output_dir, video_out_name.split('.')[-2] + '.txt')
write_mot_results(result_filename, results)
writer.release()
......@@ -282,18 +469,20 @@ def main():
arch = yml_conf['arch']
detector = SDE_Detector(
FLAGS.model_dir,
-FLAGS.tracker_config,
+tracker_config=FLAGS.tracker_config,
device=FLAGS.device,
run_mode=FLAGS.run_mode,
-batch_size=FLAGS.batch_size,
+batch_size=1,
trt_min_shape=FLAGS.trt_min_shape,
trt_max_shape=FLAGS.trt_max_shape,
trt_opt_shape=FLAGS.trt_opt_shape,
trt_calib_mode=FLAGS.trt_calib_mode,
cpu_threads=FLAGS.cpu_threads,
enable_mkldnn=FLAGS.enable_mkldnn,
-output_dir=FLAGS.output_dir)
+output_dir=FLAGS.output_dir,
+threshold=FLAGS.threshold,
+save_images=FLAGS.save_images,
+save_mot_txts=FLAGS.save_mot_txts, )
# predict from video file or camera video stream
if FLAGS.video_file is not None or FLAGS.camera_id != -1:
......@@ -303,7 +492,9 @@ def main():
if FLAGS.image_dir is None and FLAGS.image_file is not None:
assert FLAGS.batch_size == 1, "--batch_size should be 1 in MOT models."
img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file)
-detector.predict_image(img_list, FLAGS.run_benchmark, repeats=10)
+seq_name = FLAGS.image_dir.split('/')[-1]
+detector.predict_image(
+img_list, FLAGS.run_benchmark, repeats=10, seq_name=seq_name)
if not FLAGS.run_benchmark:
detector.det_times.info(average=True)
......
-# config of tracker for MOT SDE Detector, use ByteTracker as default.
-# The tracker of MOT JDE Detector is exported together with the model.
+# config of tracker for MOT SDE Detector, use 'JDETracker' as default.
+# The tracker of MOT JDE Detector (such as FairMOT) is exported together with the model.
+# Here 'min_box_area' and 'vertical_ratio' are set for pedestrian; you can modify them for other object tracking.
-tracker:
-use_byte: true
+type: JDETracker # 'JDETracker' or 'DeepSORTTracker'
+# BYTETracker
+JDETracker:
+use_byte: True
det_thresh: 0.3
conf_thres: 0.6
low_conf_thres: 0.1
match_thres: 0.9
-min_box_area: 100
-vertical_ratio: 1.6
+min_box_area: 0
+vertical_ratio: 0 # 1.6 for pedestrian
DeepSORTTracker:
input_size: [64, 192]
min_box_area: 0
vertical_ratio: -1
budget: 100
max_age: 70
n_init: 3
metric_type: cosine
matching_threshold: 0.2
max_iou_distance: 0.9
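For reference, the SDE deploy script picks one of the two sections above via the top-level `type` key. A condensed sketch of that lookup, mirroring the `SDE_Detector.__init__` change in this commit (the config path here is an assumption):

```python
import yaml

# Select the tracker section named by 'type', then read per-tracker
# values with cfg.get() defaults, as mot_sde_infer.py does.
tracker_cfg = yaml.safe_load(open('deploy/python/tracker_config.yml'))
cfg = tracker_cfg[tracker_cfg['type']]  # 'JDETracker' or 'DeepSORTTracker'
use_deepsort_tracker = tracker_cfg['type'] == 'DeepSORTTracker'
min_box_area = cfg.get('min_box_area', 0)      # 0 disables the area filter
vertical_ratio = cfg.get('vertical_ratio', 0)  # 0 disables the w/h filter
```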
......@@ -44,7 +44,7 @@ class JDETracker(object):
track_buffer (int): buffer for tracker
min_box_area (int): min box area to filter out low quality boxes
vertical_ratio (float): w/h, the vertical ratio of the bbox to filter
-bad results. If set <0 means no need to filter bboxes,usually set
+bad results. If set <= 0 means no need to filter bboxes, usually set
1.6 for pedestrian tracking.
tracked_thresh (float): linear assignment threshold of tracked
stracks and detections
......@@ -70,8 +70,8 @@ class JDETracker(object):
num_classes=1,
det_thresh=0.3,
track_buffer=30,
-min_box_area=200,
-vertical_ratio=1.6,
+min_box_area=0,
+vertical_ratio=0,
tracked_thresh=0.7,
r_tracked_thresh=0.5,
unconfirmed_thresh=0.7,
......@@ -167,9 +167,8 @@ class JDETracker(object):
detections = [
STrack(
STrack.tlbr_to_tlwh(tlbrs[2:6]), tlbrs[1], cls_id,
-30, temp_feat)
-for (tlbrs, temp_feat
-) in zip(pred_dets_cls, pred_embs_cls)
+30, temp_feat) for (tlbrs, temp_feat) in
+zip(pred_dets_cls, pred_embs_cls)
]
else:
detections = []
......@@ -244,15 +243,13 @@ class JDETracker(object):
for tlbrs in pred_dets_cls_second
]
else:
-pred_embs_cls_second = pred_embs_dict[cls_id][inds_second]
+pred_embs_cls_second = pred_embs_dict[cls_id][
+inds_second]
detections_second = [
STrack(
-STrack.tlbr_to_tlwh(tlbrs[2:6]),
-tlbrs[1],
-cls_id,
-30,
-temp_feat)
-for (tlbrs, temp_feat) in zip(pred_dets_cls_second, pred_embs_cls_second)
+STrack.tlbr_to_tlwh(tlbrs[2:6]), tlbrs[1],
+cls_id, 30, temp_feat) for (tlbrs, temp_feat) in
+zip(pred_dets_cls_second, pred_embs_cls_second)
]
else:
detections_second = []
......