Unverified · Commit 822c7933 · Authored by Feng Ni · Committed by GitHub

[cherry-pick][MOT] fix deploy python mot infer (#5630)

* fix deploy python mot infer, fix cfgs

* fix doc, test=document_fix
Parent 20b2f101
@@ -53,4 +53,4 @@ JDETracker:
  det_thresh: 0.3
  track_buffer: 30
  min_box_area: 200
- motion: KalmanFilter
+ vertical_ratio: 1.6 # for pedestrian
@@ -11,8 +11,8 @@ JDETracker:
  conf_thres: 0.6
  low_conf_thres: 0.1
  match_thres: 0.9
- min_box_area: 100
- vertical_ratio: 1.6 # for pedestrian
+ min_box_area: 0
+ vertical_ratio: 0 # 1.6 for pedestrian

DeepSORTTracker:
  input_size: [64, 192]
...
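For context, these two keys gate which tracked boxes are reported: `min_box_area` drops tiny boxes and `vertical_ratio` drops boxes that are too wide relative to their height (1.6 is the usual pedestrian setting). A minimal sketch of that filtering logic, with illustrative names rather than the exact PaddleDetection internals:

```python
def keep_box(tlwh, min_box_area=0, vertical_ratio=0):
    # tlwh: (x, y, w, h). A box is dropped if its area is too small,
    # or, when vertical_ratio > 0, if it is wider than vertical_ratio * height.
    # Setting either threshold to 0 disables that filter, matching the new defaults.
    x, y, w, h = tlwh
    if min_box_area > 0 and w * h <= min_box_area:
        return False
    if vertical_ratio > 0 and w / h > vertical_ratio:
        return False
    return True
```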
@@ -38,7 +38,7 @@ class JDETracker(object):
        track_buffer (int): buffer for tracker
        min_box_area (int): min box area to filter out low quality boxes
        vertical_ratio (float): w/h, the vertical ratio of the bbox to filter
-           bad results. If set <0 means no need to filter bboxes,usually set
+           bad results. If set <= 0 means no need to filter bboxes, usually set
            1.6 for pedestrian tracking.
        tracked_thresh (float): linear assignment threshold of tracked
            stracks and detections
@@ -64,8 +64,8 @@ class JDETracker(object):
                 num_classes=1,
                 det_thresh=0.3,
                 track_buffer=30,
-                min_box_area=200,
-                vertical_ratio=1.6,
+                min_box_area=0,
+                vertical_ratio=0,
                 tracked_thresh=0.7,
                 r_tracked_thresh=0.5,
                 unconfirmed_thresh=0.7,
@@ -161,9 +161,8 @@ class JDETracker(object):
                detections = [
                    STrack(
                        STrack.tlbr_to_tlwh(tlbrs[2:6]), tlbrs[1], cls_id,
-                       30, temp_feat)
-                   for (tlbrs, temp_feat
-                        ) in zip(pred_dets_cls, pred_embs_cls)
+                       30, temp_feat) for (tlbrs, temp_feat) in
+                   zip(pred_dets_cls, pred_embs_cls)
                ]
            else:
                detections = []
@@ -238,15 +237,13 @@ class JDETracker(object):
                        for tlbrs in pred_dets_cls_second
                    ]
                else:
-                   pred_embs_cls_second = pred_embs_dict[cls_id][inds_second]
+                   pred_embs_cls_second = pred_embs_dict[cls_id][
+                       inds_second]
                    detections_second = [
                        STrack(
-                           STrack.tlbr_to_tlwh(tlbrs[2:6]),
-                           tlbrs[1],
-                           cls_id,
-                           30,
-                           temp_feat)
-                       for (tlbrs, temp_feat) in zip(pred_dets_cls_second, pred_embs_cls_second)
+                           STrack.tlbr_to_tlwh(tlbrs[2:6]), tlbrs[1],
+                           cls_id, 30, temp_feat) for (tlbrs, temp_feat) in
+                       zip(pred_dets_cls_second, pred_embs_cls_second)
                    ]
                else:
                    detections_second = []
...
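As background for the `STrack.tlbr_to_tlwh(tlbrs[2:6])` calls above: detections arrive as (x1, y1, x2, y2) corner boxes and tracks are stored as (x, y, w, h). A minimal sketch of that conversion (illustrative standalone function, not the exact class method):

```python
import numpy as np

def tlbr_to_tlwh(tlbr):
    # (x1, y1, x2, y2) -> (x, y, w, h): keep the top-left corner and
    # turn the bottom-right corner into a width/height offset.
    ret = np.asarray(tlbr, dtype=np.float32).copy()
    ret[2:] -= ret[:2]
    return ret
```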
@@ -112,8 +112,8 @@ class JDE_Detector(Detector):
        # tracker config
        assert self.pred_config.tracker, "The exported JDE Detector model should have tracker."
        cfg = self.pred_config.tracker
-       min_box_area = cfg.get('min_box_area', 200)
-       vertical_ratio = cfg.get('vertical_ratio', 1.6)
+       min_box_area = cfg.get('min_box_area', 0.0)
+       vertical_ratio = cfg.get('vertical_ratio', 0.0)
        conf_thres = cfg.get('conf_thres', 0.0)
        tracked_thresh = cfg.get('tracked_thresh', 0.7)
        metric_type = cfg.get('metric_type', 'euclidean')
@@ -164,7 +164,7 @@ class JDE_Detector(Detector):
            repeats (int): repeats number for prediction
        Returns:
            result (dict): include 'pred_dets': np.ndarray: shape:[N,6], N: number of box,
-               matrix element: [x_min, y_min, x_max, y_max, score, class]
+               matrix element: [class, score, x_min, y_min, x_max, y_max]
            FairMOT(JDE)'s result include 'pred_embs': np.ndarray:
                shape: [N, 128]
        '''
...
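The docstring fix above records the column order the detector actually produces. A minimal sketch of unpacking a `pred_dets` array under that layout (variable names and the sample row are illustrative):

```python
import numpy as np

# pred_dets: np.ndarray of shape [N, 6], one row per detected box,
# columns [class, score, x_min, y_min, x_max, y_max].
pred_dets = np.array([[0, 0.92, 10.0, 20.0, 60.0, 180.0]], dtype=np.float32)

cls_ids = pred_dets[:, 0].astype(int)  # class id per box, cf. tlbrs[0]
scores = pred_dets[:, 1]               # confidence per box, cf. tlbrs[1]
xyxys = pred_dets[:, 2:6]              # corner boxes, cf. pred_dets[:, 2:6] in reidprocess()
```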
@@ -165,8 +165,8 @@ class SDE_Detector(Detector):
            # use ByteTracker
            use_byte = cfg.get('use_byte', False)
            det_thresh = cfg.get('det_thresh', 0.3)
-           min_box_area = cfg.get('min_box_area', 200)
-           vertical_ratio = cfg.get('vertical_ratio', 1.6)
+           min_box_area = cfg.get('min_box_area', 0)
+           vertical_ratio = cfg.get('vertical_ratio', 0)
            match_thres = cfg.get('match_thres', 0.9)
            conf_thres = cfg.get('conf_thres', 0.6)
            low_conf_thres = cfg.get('low_conf_thres', 0.1)
@@ -194,7 +194,7 @@ class SDE_Detector(Detector):
        return result

    def reidprocess(self, det_results, repeats=1):
-       pred_dets = det_results['boxes']
+       pred_dets = det_results['boxes']  # cls_id, score, x0, y0, x1, y1
        pred_xyxys = pred_dets[:, 2:6]
        ori_image = det_results['ori_image']
@@ -234,7 +234,7 @@ class SDE_Detector(Detector):
        return det_results

    def tracking(self, det_results):
-       pred_dets = det_results['boxes']
+       pred_dets = det_results['boxes']  # cls_id, score, x0, y0, x1, y1
        pred_embs = det_results.get('embeddings', None)

        if self.use_deepsort_tracker:
...
@@ -11,8 +11,8 @@ JDETracker:
  conf_thres: 0.6
  low_conf_thres: 0.1
  match_thres: 0.9
- min_box_area: 100
- vertical_ratio: 1.6 # for pedestrian
+ min_box_area: 0
+ vertical_ratio: 0 # 1.6 for pedestrian

DeepSORTTracker:
  input_size: [64, 192]
...
@@ -3,27 +3,76 @@
 In PaddlePaddle, the inference engine and the training engine use different underlying optimizations. The inference engine uses AnalysisPredictor, which is optimized specifically for inference; it is the Python interface to the [C++ inference library](https://www.paddlepaddle.org.cn/documentation/docs/zh/advanced_guide/inference_deployment/inference/native_infer.html), applies a number of graph optimizations to the model, and reduces unnecessary memory copies. For users with high performance requirements when deploying trained models, we provide inference scripts independent of PaddleDetection for direct integration and deployment.

-It mainly consists of two steps:
+Python-side inference deployment consists of two steps:

 - Export the inference model
 - Run inference with Python

 ## 1. Export the inference model

-During training, PaddleDetection keeps both the network's forward parameters and the optimizer-related parameters, while deployment only needs the forward parameters. For details see [Export model](https://github.com/PaddlePaddle/PaddleDetection/blob/develop/deploy/EXPORT_MODEL.md)
+During training, PaddleDetection keeps both the network's forward parameters and the optimizer-related parameters, while deployment only needs the forward parameters. For details see [Export model](../deploy/EXPORT_MODEL.md), for example:
+
+```bash
+# Export the YOLOv3 detection model
+python tools/export_model.py -c configs/yolov3/yolov3_darknet53_270e_coco.yml --output_dir=./inference_model \
+ -o weights=https://paddledet.bj.bcebos.com/models/yolov3_darknet53_270e_coco.pdparams
+
+# Export the HigherHRNet (bottom-up) keypoint detection model
+python tools/export_model.py -c configs/keypoint/higherhrnet/higherhrnet_hrnet_w32_512.yml -o weights=https://paddledet.bj.bcebos.com/models/keypoint/higherhrnet_hrnet_w32_512.pdparams
+
+# Export the HRNet (top-down) keypoint detection model
+python tools/export_model.py -c configs/keypoint/hrnet/hrnet_w32_384x288.yml -o weights=https://paddledet.bj.bcebos.com/models/keypoint/hrnet_w32_384x288.pdparams
+
+# Export the FairMOT multi-object tracking model
+python tools/export_model.py -c configs/mot/fairmot/fairmot_dla34_30e_1088x608.yml -o weights=https://paddledet.bj.bcebos.com/models/mot/fairmot_dla34_30e_1088x608.pdparams
+
+# Export the ByteTrack multi-object tracking model (equivalent to exporting only the detector)
+python tools/export_model.py -c configs/mot/bytetrack/detector/ppyoloe_crn_l_36e_640x640_mot17half.yml -o weights=https://paddledet.bj.bcebos.com/models/mot/ppyoloe_crn_l_36e_640x640_mot17half.pdparams
+```

 After export, the directory contains four files: `infer_cfg.yml`, `model.pdiparams`, `model.pdiparams.info`, and `model.pdmodel`.

 ## 2. Run inference with Python

+### 2.1 General object detection
 Run the following command in the terminal:
 ```bash
-python deploy/python/infer.py --model_dir=./output_inference/yolov3_mobilenet_v1_roadsign --image_file=./demo/road554.png --device=GPU
+python deploy/python/infer.py --model_dir=./output_inference/yolov3_darknet53_270e_coco --image_file=./demo/000000014439.jpg --device=GPU
 ```
+
+### 2.2 Keypoint detection
+Run the following command in the terminal:
+```bash
+# standalone keypoint inference, top-down (HRNet) / bottom-up (HigherHRNet); in this mode the top-down HRNet model only supports single-person cropped images
+python deploy/python/keypoint_infer.py --model_dir=output_inference/hrnet_w32_384x288/ --image_file=./demo/hrnet_demo.jpg --device=GPU --threshold=0.5
+python deploy/python/keypoint_infer.py --model_dir=output_inference/higherhrnet_hrnet_w32_512/ --image_file=./demo/000000014439_640x640.jpg --device=GPU --threshold=0.5
+
+# joint deployment of a detector and a top-down keypoint model (joint inference only supports top-down keypoint models)
+python deploy/python/det_keypoint_unite_infer.py --det_model_dir=output_inference/yolov3_darknet53_270e_coco/ --keypoint_model_dir=output_inference/hrnet_w32_384x288/ --video_file={your video name}.mp4 --device=GPU
+```
+
+**Note:**
+- For details on exporting and running keypoint detection models, see [keypoint](../../configs/keypoint/README.md); usage specifics can be found in each model's documentation;
+- The keypoint deployment in this directory provides basic forward inference only; for more keypoint features, use the PP-Human project, see [pphuman](../pphuman/README.md)
+
+### 2.3 Multi-object tracking
+Run the following command in the terminal:
+```bash
+# FairMOT tracking
+python deploy/python/mot_jde_infer.py --model_dir=output_inference/fairmot_dla34_30e_1088x608 --video_file={your video name}.mp4 --device=GPU
+
+# ByteTrack tracking
+python deploy/python/mot_sde_infer.py --model_dir=output_inference/ppyoloe_crn_l_36e_640x640_mot17half/ --tracker_config=deploy/python/tracker_config.yml --video_file={your video name}.mp4 --device=GPU --scaled=True
+
+# FairMOT tracking combined with HRNet keypoint detection (joint inference only supports top-down keypoint models)
+python deploy/python/mot_keypoint_unite_infer.py --mot_model_dir=output_inference/fairmot_dla34_30e_1088x608/ --keypoint_model_dir=output_inference/hrnet_w32_384x288/ --video_file={your video name}.mp4 --device=GPU
+```
+
+**Note:**
+- For details on exporting and running MOT models, see [mot](../../configs/mot/README.md); usage specifics can be found in each model's documentation;
+- The tracking deployment in this directory provides basic forward inference and joint keypoint deployment; for more tracking features, use the PP-Human project, see [pphuman](../pphuman/README.md), or the PP-Tracking project (trajectory plotting, entrance/exit counting), see [pptracking](../pptracking/README.md)

 The parameters are described as follows:

 | Parameter | Required | Description |
...
@@ -32,7 +32,7 @@
 sys.path.insert(0, parent_path)

 from pptracking.python.mot import JDETracker
 from pptracking.python.mot.utils import MOTTimer, write_mot_results
-from pptracking.python.visualize import plot_tracking, plot_tracking_dict
+from pptracking.python.mot.visualize import plot_tracking_dict

 # Global dictionary
 MOT_JDE_SUPPORT_MODELS = {
@@ -55,9 +55,14 @@ class JDE_Detector(Detector):
            calibration, trt_calib_mode need to set True
        cpu_threads (int): cpu threads
        enable_mkldnn (bool): whether to open MKLDNN
+       output_dir (string): The path of output, default as 'output'
+       threshold (float): Score threshold of the detected bbox, default as 0.5
+       save_images (bool): Whether to save visualization image results, default as False
+       save_mot_txts (bool): Whether to save tracking results (txt), default as False
    """

-   def __init__(self,
+   def __init__(
+           self,
            model_dir,
            tracker_config=None,
            device='CPU',
@@ -70,7 +75,9 @@ class JDE_Detector(Detector):
            cpu_threads=1,
            enable_mkldnn=False,
            output_dir='output',
-           threshold=0.5):
+           threshold=0.5,
+           save_images=False,
+           save_mot_txts=False, ):
        super(JDE_Detector, self).__init__(
            model_dir=model_dir,
            device=device,
@@ -84,6 +91,8 @@ class JDE_Detector(Detector):
            enable_mkldnn=enable_mkldnn,
            output_dir=output_dir,
            threshold=threshold, )
+       self.save_images = save_images
+       self.save_mot_txts = save_mot_txts
        assert batch_size == 1, "MOT model only supports batch_size=1."
        self.det_times = Timer(with_tracker=True)
        self.num_classes = len(self.pred_config.labels)
@@ -91,8 +100,8 @@ class JDE_Detector(Detector):
        # tracker config
        assert self.pred_config.tracker, "The exported JDE Detector model should have tracker."
        cfg = self.pred_config.tracker
-       min_box_area = cfg.get('min_box_area', 200)
-       vertical_ratio = cfg.get('vertical_ratio', 1.6)
+       min_box_area = cfg.get('min_box_area', 0.0)
+       vertical_ratio = cfg.get('vertical_ratio', 0.0)
        conf_thres = cfg.get('conf_thres', 0.0)
        tracked_thresh = cfg.get('tracked_thresh', 0.7)
        metric_type = cfg.get('metric_type', 'euclidean')
@@ -115,7 +124,7 @@ class JDE_Detector(Detector):
        return result

    def tracking(self, det_results):
-       pred_dets = det_results['pred_dets']  # 'cls_id, score, x0, y0, x1, y1'
+       pred_dets = det_results['pred_dets']  # cls_id, score, x0, y0, x1, y1
        pred_embs = det_results['pred_embs']
        online_targets_dict = self.tracker.update(pred_dets, pred_embs)
@@ -164,7 +173,8 @@ class JDE_Detector(Detector):
                      image_list,
                      run_benchmark=False,
                      repeats=1,
-                     visual=True):
+                     visual=True,
+                     seq_name=None):
        mot_results = []
        num_classes = self.num_classes
        image_list.sort()
@@ -225,7 +235,7 @@ class JDE_Detector(Detector):
                self.det_times.img_num += 1

            if visual:
-               if frame_id % 10 == 0:
+               if len(image_list) > 1 and frame_id % 10 == 0:
                    print('Tracking frame {}'.format(frame_id))
                frame, _ = decode_image(img_file, {})
@@ -237,6 +247,7 @@ class JDE_Detector(Detector):
                    online_scores,
                    frame_id=frame_id,
                    ids2names=ids2names)
-               seq_name = image_list[0].split('/')[-2]
+               if seq_name is None:
+                   seq_name = image_list[0].split('/')[-2]
                save_dir = os.path.join(self.output_dir, seq_name)
                if not os.path.exists(save_dir):
@@ -264,7 +275,8 @@ class JDE_Detector(Detector):
        if not os.path.exists(self.output_dir):
            os.makedirs(self.output_dir)
        out_path = os.path.join(self.output_dir, video_out_name)
-       fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+       video_format = 'mp4v'
+       fourcc = cv2.VideoWriter_fourcc(*video_format)
        writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
        frame_id = 1
@@ -282,7 +294,9 @@ class JDE_Detector(Detector):
            frame_id += 1

            timer.tic()
-           mot_results = self.predict_image([frame], visual=False)
+           seq_name = video_out_name.split('.')[0]
+           mot_results = self.predict_image(
+               [frame], visual=False, seq_name=seq_name)
            timer.toc()

            online_tlwhs, online_scores, online_ids = mot_results[0]
@@ -307,20 +321,33 @@ class JDE_Detector(Detector):
                cv2.imshow('Mask Detection', im)
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break

+       if self.save_mot_txts:
+           result_filename = os.path.join(
+               self.output_dir, video_out_name.split('.')[-2] + '.txt')
+           write_mot_results(result_filename, results, data_type, num_classes)

        writer.release()


def main():
    detector = JDE_Detector(
        FLAGS.model_dir,
+       tracker_config=None,
        device=FLAGS.device,
        run_mode=FLAGS.run_mode,
+       batch_size=1,
        trt_min_shape=FLAGS.trt_min_shape,
        trt_max_shape=FLAGS.trt_max_shape,
        trt_opt_shape=FLAGS.trt_opt_shape,
        trt_calib_mode=FLAGS.trt_calib_mode,
        cpu_threads=FLAGS.cpu_threads,
-       enable_mkldnn=FLAGS.enable_mkldnn)
+       enable_mkldnn=FLAGS.enable_mkldnn,
+       output_dir=FLAGS.output_dir,
+       threshold=FLAGS.threshold,
+       save_images=FLAGS.save_images,
+       save_mot_txts=FLAGS.save_mot_txts)

    # predict from video file or camera video stream
    if FLAGS.video_file is not None or FLAGS.camera_id != -1:
...
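For reference on the new `save_mot_txts` option: `write_mot_results` dumps one text record per tracked box. A minimal sketch of such a record, assuming the standard MOT Challenge layout (`frame,id,x,y,w,h,score`, padded with -1 for the unused 3D columns); this is illustrative, not the exact helper:

```python
def format_mot_line(frame_id, track_id, tlwh, score):
    # One MOT Challenge style line for a single tracked box.
    x, y, w, h = tlwh
    return '{},{},{:.2f},{:.2f},{:.2f},{:.2f},{:.2f},-1,-1,-1'.format(
        frame_id, track_id, x, y, w, h, score)

# e.g. format_mot_line(1, 7, (10.0, 20.0, 50.0, 160.0), 0.92)
# -> '1,7,10.00,20.00,50.00,160.00,0.92,-1,-1,-1'
```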
@@ -24,7 +24,7 @@ from collections import defaultdict
 from mot_keypoint_unite_utils import argsparser
 from preprocess import decode_image
-from infer import print_arguments, get_test_images
+from infer import print_arguments, get_test_images, bench_log
 from mot_sde_infer import SDE_Detector
 from mot_jde_infer import JDE_Detector, MOT_JDE_SUPPORT_MODELS
 from keypoint_infer import KeyPointDetector, KEYPOINT_SUPPORT_MODELS
@@ -39,7 +39,7 @@ import sys
 parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2)))
 sys.path.insert(0, parent_path)

-from pptracking.python.visualize import plot_tracking, plot_tracking_dict
+from pptracking.python.mot.visualize import plot_tracking, plot_tracking_dict
 from pptracking.python.mot.utils import MOTTimer as FPSTimer
@@ -92,7 +92,7 @@ def mot_topdown_unite_predict(mot_detector,
            keypoint_res = predict_with_given_det(
                image, results, topdown_keypoint_detector, keypoint_batch_size,
-               FLAGS.mot_threshold, FLAGS.keypoint_threshold, FLAGS.run_benchmark)
+               FLAGS.run_benchmark)

            if save_res:
                store_res.append([
@@ -146,7 +146,7 @@ def mot_topdown_unite_predict_video(mot_detector,
    if not os.path.exists(FLAGS.output_dir):
        os.makedirs(FLAGS.output_dir)
    out_path = os.path.join(FLAGS.output_dir, video_name)
-   fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+   fourcc = cv2.VideoWriter_fourcc(* 'mp4v')
    writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
    frame_id = 0
    timer_mot, timer_kp, timer_mot_kp = FPSTimer(), FPSTimer(), FPSTimer()
@@ -179,7 +179,7 @@ def mot_topdown_unite_predict_video(mot_detector,
        timer_kp.tic()
        keypoint_res = predict_with_given_det(
            frame, results, topdown_keypoint_detector, keypoint_batch_size,
-           FLAGS.mot_threshold, FLAGS.keypoint_threshold, FLAGS.run_benchmark)
+           FLAGS.run_benchmark)
        timer_kp.toc()
        timer_mot_kp.toc()
...
-# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -23,15 +23,15 @@ import paddle
 from benchmark_utils import PaddleInferBenchmark
 from preprocess import decode_image
 from utils import argsparser, Timer, get_current_memory_mb
-from infer import Detector, get_test_images, print_arguments, bench_log, PredictConfig
+from infer import Detector, get_test_images, print_arguments, bench_log, PredictConfig, load_predictor

 # add python path
 import sys
 parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2)))
 sys.path.insert(0, parent_path)

-from pptracking.python.mot import JDETracker
-from pptracking.python.mot.utils import MOTTimer, write_mot_results
+from pptracking.python.mot import JDETracker, DeepSORTTracker
+from pptracking.python.mot.utils import MOTTimer, write_mot_results, get_crops, clip_box
 from pptracking.python.mot.visualize import plot_tracking, plot_tracking_dict
@@ -50,7 +50,11 @@ class SDE_Detector(Detector):
            calibration, trt_calib_mode need to set True
        cpu_threads (int): cpu threads
        enable_mkldnn (bool): whether to open MKLDNN
-       use_dark(bool): whether to use postprocess in DarkPose
+       output_dir (string): The path of output, default as 'output'
+       threshold (float): Score threshold of the detected bbox, default as 0.5
+       save_images (bool): Whether to save visualization image results, default as False
+       save_mot_txts (bool): Whether to save tracking results (txt), default as False
+       reid_model_dir (str): reid model dir, default None for ByteTrack, but set for DeepSORT
    """

    def __init__(self,
@@ -66,7 +70,10 @@ class SDE_Detector(Detector):
                 cpu_threads=1,
                 enable_mkldnn=False,
                 output_dir='output',
-                threshold=0.5):
+                threshold=0.5,
+                save_images=False,
+                save_mot_txts=False,
+                reid_model_dir=None):
        super(SDE_Detector, self).__init__(
            model_dir=model_dir,
            device=device,
@@ -80,37 +87,163 @@ class SDE_Detector(Detector):
            enable_mkldnn=enable_mkldnn,
            output_dir=output_dir,
            threshold=threshold, )
+       self.save_images = save_images
+       self.save_mot_txts = save_mot_txts
        assert batch_size == 1, "MOT model only supports batch_size=1."
        self.det_times = Timer(with_tracker=True)
        self.num_classes = len(self.pred_config.labels)

-       # tracker config
+       # reid config
+       self.use_reid = False if reid_model_dir is None else True
+       if self.use_reid:
+           self.reid_pred_config = self.set_config(reid_model_dir)
+           self.reid_predictor, self.config = load_predictor(
+               reid_model_dir,
+               run_mode=run_mode,
+               batch_size=50,  # reid_batch_size
+               min_subgraph_size=self.reid_pred_config.min_subgraph_size,
+               device=device,
+               use_dynamic_shape=self.reid_pred_config.use_dynamic_shape,
+               trt_min_shape=trt_min_shape,
+               trt_max_shape=trt_max_shape,
+               trt_opt_shape=trt_opt_shape,
+               trt_calib_mode=trt_calib_mode,
+               cpu_threads=cpu_threads,
+               enable_mkldnn=enable_mkldnn)
+       else:
+           self.reid_pred_config = None
+           self.reid_predictor = None
+
+       assert tracker_config is not None, 'Note that tracker_config should be set.'
        self.tracker_config = tracker_config
-       cfg = yaml.safe_load(open(self.tracker_config))['tracker']
-       min_box_area = cfg.get('min_box_area', 200)
-       vertical_ratio = cfg.get('vertical_ratio', 1.6)
-       use_byte = cfg.get('use_byte', True)
-       match_thres = cfg.get('match_thres', 0.9)
-       conf_thres = cfg.get('conf_thres', 0.6)
-       low_conf_thres = cfg.get('low_conf_thres', 0.1)
-       self.tracker = JDETracker(
-           use_byte=use_byte,
-           num_classes=self.num_classes,
-           min_box_area=min_box_area,
-           vertical_ratio=vertical_ratio,
-           match_thres=match_thres,
-           conf_thres=conf_thres,
-           low_conf_thres=low_conf_thres)
+       tracker_cfg = yaml.safe_load(open(self.tracker_config))
+       cfg = tracker_cfg[tracker_cfg['type']]
+
+       # tracker config
+       self.use_deepsort_tracker = True if tracker_cfg[
+           'type'] == 'DeepSORTTracker' else False
+       if self.use_deepsort_tracker:
+           # use DeepSORTTracker
+           if self.reid_pred_config is not None and hasattr(
+                   self.reid_pred_config, 'tracker'):
+               cfg = self.reid_pred_config.tracker
+           budget = cfg.get('budget', 100)
+           max_age = cfg.get('max_age', 30)
+           max_iou_distance = cfg.get('max_iou_distance', 0.7)
+           matching_threshold = cfg.get('matching_threshold', 0.2)
+           min_box_area = cfg.get('min_box_area', 0)
+           vertical_ratio = cfg.get('vertical_ratio', 0)
+
+           self.tracker = DeepSORTTracker(
+               budget=budget,
+               max_age=max_age,
+               max_iou_distance=max_iou_distance,
+               matching_threshold=matching_threshold,
+               min_box_area=min_box_area,
+               vertical_ratio=vertical_ratio, )
+       else:
+           # use ByteTracker
+           use_byte = cfg.get('use_byte', False)
+           det_thresh = cfg.get('det_thresh', 0.3)
+           min_box_area = cfg.get('min_box_area', 0)
+           vertical_ratio = cfg.get('vertical_ratio', 0)
            match_thres = cfg.get('match_thres', 0.9)
            conf_thres = cfg.get('conf_thres', 0.6)
            low_conf_thres = cfg.get('low_conf_thres', 0.1)
+
            self.tracker = JDETracker(
                use_byte=use_byte,
+               det_thresh=det_thresh,
                num_classes=self.num_classes,
                min_box_area=min_box_area,
                vertical_ratio=vertical_ratio,
                match_thres=match_thres,
                conf_thres=conf_thres,
-               low_conf_thres=low_conf_thres)
+               low_conf_thres=low_conf_thres, )
+
+   def postprocess(self, inputs, result):
+       # postprocess output of predictor
+       np_boxes_num = result['boxes_num']
+       if np_boxes_num[0] <= 0:
+           print('[WARNNING] No object detected.')
+           result = {'boxes': np.zeros([0, 6]), 'boxes_num': [0]}
+       result = {k: v for k, v in result.items() if v is not None}
+       return result
+
+   def reidprocess(self, det_results, repeats=1):
+       pred_dets = det_results['boxes']
+       pred_xyxys = pred_dets[:, 2:6]
+
+       ori_image = det_results['ori_image']
+       ori_image_shape = ori_image.shape[:2]
+       pred_xyxys, keep_idx = clip_box(pred_xyxys, ori_image_shape)
+
+       if len(keep_idx[0]) == 0:
+           det_results['boxes'] = np.zeros((1, 6), dtype=np.float32)
+           det_results['embeddings'] = None
+           return det_results
+
+       pred_dets = pred_dets[keep_idx[0]]
+       pred_xyxys = pred_dets[:, 2:6]
+
+       w, h = self.tracker.input_size
+       crops = get_crops(pred_xyxys, ori_image, w, h)
+
+       # to keep fast speed, only use topk crops
+       crops = crops[:50]  # reid_batch_size
+       det_results['crops'] = np.array(crops).astype('float32')
+       det_results['boxes'] = pred_dets[:50]
+
+       input_names = self.reid_predictor.get_input_names()
+       for i in range(len(input_names)):
+           input_tensor = self.reid_predictor.get_input_handle(input_names[i])
+           input_tensor.copy_from_cpu(det_results[input_names[i]])
+
+       # model prediction
+       for i in range(repeats):
+           self.reid_predictor.run()
+           output_names = self.reid_predictor.get_output_names()
+           feature_tensor = self.reid_predictor.get_output_handle(output_names[
+               0])
+           pred_embs = feature_tensor.copy_to_cpu()
+
+       det_results['embeddings'] = pred_embs
+       return det_results

    def tracking(self, det_results):
        pred_dets = det_results['boxes']  # 'cls_id, score, x0, y0, x1, y1'
-       pred_embs = None
-       online_targets_dict = self.tracker.update(pred_dets, pred_embs)
+       pred_embs = det_results.get('embeddings', None)

-       online_tlwhs = defaultdict(list)
-       online_scores = defaultdict(list)
-       online_ids = defaultdict(list)
+       if self.use_deepsort_tracker:
+           # use DeepSORTTracker, only support single class
+           self.tracker.predict()
+           online_targets = self.tracker.update(pred_dets, pred_embs)
+
+           online_tlwhs, online_scores, online_ids = [], [], []
+           for t in online_targets:
+               if not t.is_confirmed() or t.time_since_update > 1:
+                   continue
+               tlwh = t.to_tlwh()
+               tscore = t.score
+               tid = t.track_id
+               if self.tracker.vertical_ratio > 0 and tlwh[2] / tlwh[
+                       3] > self.tracker.vertical_ratio:
+                   continue
+               online_tlwhs.append(tlwh)
+               online_scores.append(tscore)
+               online_ids.append(tid)
+
+           tracking_outs = {
+               'online_tlwhs': online_tlwhs,
+               'online_scores': online_scores,
+               'online_ids': online_ids,
+           }
+           return tracking_outs
+       else:
+           # use ByteTracker, support multiple classes
+           online_tlwhs = defaultdict(list)
+           online_scores = defaultdict(list)
+           online_ids = defaultdict(list)
+           online_targets_dict = self.tracker.update(pred_dets, pred_embs)
            for cls_id in range(self.num_classes):
                online_targets = online_targets_dict[cls_id]
                for t in online_targets:
@@ -126,19 +259,26 @@ class SDE_Detector(Detector):
                    online_ids[cls_id].append(tid)
                    online_scores[cls_id].append(tscore)
-       return online_tlwhs, online_scores, online_ids
+           tracking_outs = {
+               'online_tlwhs': online_tlwhs,
+               'online_scores': online_scores,
+               'online_ids': online_ids,
+           }
+           return tracking_outs

    def predict_image(self,
                      image_list,
                      run_benchmark=False,
                      repeats=1,
-                     visual=True):
-       mot_results = []
+                     visual=True,
+                     seq_name=None):
        num_classes = self.num_classes
        image_list.sort()
        ids2names = self.pred_config.labels
+       mot_results = []
        for frame_id, img_file in enumerate(image_list):
            batch_image_list = [img_file]  # bs=1 in MOT model
+           frame, _ = decode_image(img_file, {})
            if run_benchmark:
                # preprocess
                inputs = self.preprocess(batch_image_list)  # warmup
@@ -159,10 +299,16 @@ class SDE_Detector(Detector):
                self.det_times.postprocess_time_s.end()

                # tracking
+               if self.use_reid:
+                   det_result['frame_id'] = frame_id
+                   det_result['seq_name'] = seq_name
+                   det_result['ori_image'] = frame
+                   det_result = self.reidprocess(det_result)
                result_warmup = self.tracking(det_result)
                self.det_times.tracking_time_s.start()
-               online_tlwhs, online_scores, online_ids = self.tracking(
-                   det_result)
+               if self.use_reid:
+                   det_result = self.reidprocess(det_result)
+               tracking_outs = self.tracking(det_result)
                self.det_times.tracking_time_s.end()
                self.det_times.img_num += 1
@@ -186,16 +332,26 @@ class SDE_Detector(Detector):
                # tracking process
                self.det_times.tracking_time_s.start()
-               online_tlwhs, online_scores, online_ids = self.tracking(
-                   det_result)
+               if self.use_reid:
+                   det_result['frame_id'] = frame_id
+                   det_result['seq_name'] = seq_name
+                   det_result['ori_image'] = frame
+                   det_result = self.reidprocess(det_result)
+               tracking_outs = self.tracking(det_result)
                self.det_times.tracking_time_s.end()
                self.det_times.img_num += 1

+           online_tlwhs = tracking_outs['online_tlwhs']
+           online_scores = tracking_outs['online_scores']
+           online_ids = tracking_outs['online_ids']
+           mot_results.append([online_tlwhs, online_scores, online_ids])
+
            if visual:
-               if frame_id % 10 == 0:
+               if len(image_list) > 1 and frame_id % 10 == 0:
                    print('Tracking frame {}'.format(frame_id))
                frame, _ = decode_image(img_file, {})
+               if isinstance(online_tlwhs, defaultdict):
                    im = plot_tracking_dict(
                        frame,
                        num_classes,
@@ -204,14 +360,19 @@ class SDE_Detector(Detector):
                        online_scores,
                        frame_id=frame_id,
                        ids2names=[])
-               seq_name = image_list[0].split('/')[-2]
+               else:
+                   im = plot_tracking(
+                       frame,
+                       online_tlwhs,
+                       online_ids,
+                       online_scores,
+                       frame_id=frame_id)
                save_dir = os.path.join(self.output_dir, seq_name)
                if not os.path.exists(save_dir):
                    os.makedirs(save_dir)
                cv2.imwrite(
                    os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)), im)
-           mot_results.append([online_tlwhs, online_scores, online_ids])

        return mot_results

    def predict_video(self, video_file, camera_id):
@@ -231,13 +392,17 @@ class SDE_Detector(Detector):
        if not os.path.exists(self.output_dir):
            os.makedirs(self.output_dir)
        out_path = os.path.join(self.output_dir, video_out_name)
-       fourcc = cv2.VideoWriter_fourcc(* 'mp4v')
+       video_format = 'mp4v'
+       fourcc = cv2.VideoWriter_fourcc(*video_format)
        writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
        frame_id = 1

        timer = MOTTimer()
-       results = defaultdict(list)  # support single class and multi classes
+       results = defaultdict(list)
        num_classes = self.num_classes
+       data_type = 'mcmot' if num_classes > 1 else 'mot'
+       ids2names = self.pred_config.labels
        while (1):
            ret, frame = capture.read()
            if not ret:
@@ -247,16 +412,32 @@ class SDE_Detector(Detector):
            frame_id += 1

            timer.tic()
-           mot_results = self.predict_image([frame], visual=False)
+           seq_name = video_out_name.split('.')[0]
+           mot_results = self.predict_image(
+               [frame], visual=False, seq_name=seq_name)
            timer.toc()

-           # bs=1 in MOT model
            online_tlwhs, online_scores, online_ids = mot_results[0]
-           for cls_id in range(num_classes):
-               results[cls_id].append(
-                   (frame_id + 1, online_tlwhs[cls_id], online_scores[cls_id],
-                    online_ids[cls_id]))

            fps = 1. / timer.duration
+           if self.use_deepsort_tracker:
+               # use DeepSORTTracker, only support single class
+               results[0].append(
+                   (frame_id + 1, online_tlwhs, online_scores, online_ids))
+               im = plot_tracking(
+                   frame,
+                   online_tlwhs,
+                   online_ids,
+                   online_scores,
+                   frame_id=frame_id,
+                   fps=fps)
+           else:
+               # use ByteTracker, support multiple classes
+               for cls_id in range(num_classes):
+                   results[cls_id].append(
+                       (frame_id + 1, online_tlwhs[cls_id],
+                        online_scores[cls_id], online_ids[cls_id]))
                im = plot_tracking_dict(
                    frame,
                    num_classes,
@@ -265,13 +446,19 @@ class SDE_Detector(Detector):
                    online_scores,
                    frame_id=frame_id,
                    fps=fps,
-                   ids2names=[])
+                   ids2names=ids2names)

            writer.write(im)
            if camera_id != -1:
                cv2.imshow('Mask Detection', im)
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break

+       if self.save_mot_txts:
+           result_filename = os.path.join(
+               self.output_dir, video_out_name.split('.')[-2] + '.txt')
+           write_mot_results(result_filename, results)
+
        writer.release()
@@ -282,18 +469,20 @@ def main():
    arch = yml_conf['arch']
    detector = SDE_Detector(
        FLAGS.model_dir,
-       FLAGS.tracker_config,
+       tracker_config=FLAGS.tracker_config,
        device=FLAGS.device,
        run_mode=FLAGS.run_mode,
-       batch_size=FLAGS.batch_size,
+       batch_size=1,
        trt_min_shape=FLAGS.trt_min_shape,
        trt_max_shape=FLAGS.trt_max_shape,
        trt_opt_shape=FLAGS.trt_opt_shape,
        trt_calib_mode=FLAGS.trt_calib_mode,
        cpu_threads=FLAGS.cpu_threads,
        enable_mkldnn=FLAGS.enable_mkldnn,
+       output_dir=FLAGS.output_dir,
        threshold=FLAGS.threshold,
-       output_dir=FLAGS.output_dir)
+       save_images=FLAGS.save_images,
+       save_mot_txts=FLAGS.save_mot_txts, )

    # predict from video file or camera video stream
    if FLAGS.video_file is not None or FLAGS.camera_id != -1:
@@ -303,7 +492,9 @@ def main():
    if FLAGS.image_dir is None and FLAGS.image_file is not None:
        assert FLAGS.batch_size == 1, "--batch_size should be 1 in MOT models."
        img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file)
-       detector.predict_image(img_list, FLAGS.run_benchmark, repeats=10)
+       seq_name = FLAGS.image_dir.split('/')[-1]
+       detector.predict_image(
+           img_list, FLAGS.run_benchmark, repeats=10, seq_name=seq_name)

    if not FLAGS.run_benchmark:
        detector.det_times.info(average=True)
...
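To illustrate the new `reid_model_dir` switch introduced above: with the default config the detector runs ByteTrack, and supplying an exported ReID model directory (together with `type: DeepSORTTracker` in the YAML) enables the DeepSORT path. A minimal usage sketch, assuming exported model directories (paths are illustrative placeholders):

```python
# Sketch only: the model directories below are placeholders.
detector = SDE_Detector(
    model_dir='output_inference/ppyoloe_crn_l_36e_640x640_mot17half',
    tracker_config='deploy/python/tracker_config.yml',
    device='GPU',
    save_mot_txts=True,   # also dump MOT txt results
    reid_model_dir=None)  # None -> ByteTrack; set a ReID model dir for DeepSORT

# bs=1 per call; seq_name controls the visualization output subfolder.
mot_results = detector.predict_image(
    ['demo/000000014439.jpg'], visual=False, seq_name='demo_seq')
```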
-# config of tracker for MOT SDE Detector, use ByteTracker as default.
-# The tracker of MOT JDE Detector is exported together with the model.
+# config of tracker for MOT SDE Detector, use 'JDETracker' as default.
+# The tracker of MOT JDE Detector (such as FairMOT) is exported together with the model.
 # Here 'min_box_area' and 'vertical_ratio' are set for pedestrian, you can modify for other objects tracking.

-tracker:
-  use_byte: true
-  conf_thres: 0.6
-  low_conf_thres: 0.1
-  match_thres: 0.9
-  min_box_area: 100
-  vertical_ratio: 1.6
+type: JDETracker # 'JDETracker' or 'DeepSORTTracker'
+
+# BYTETracker
+JDETracker:
+  use_byte: True
+  det_thresh: 0.3
+  conf_thres: 0.6
+  low_conf_thres: 0.1
+  match_thres: 0.9
+  min_box_area: 0
+  vertical_ratio: 0 # 1.6 for pedestrian
+
+DeepSORTTracker:
+  input_size: [64, 192]
+  min_box_area: 0
+  vertical_ratio: -1
+  budget: 100
+  max_age: 70
+  n_init: 3
+  metric_type: cosine
+  matching_threshold: 0.2
+  max_iou_distance: 0.9
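The new top-level `type` key is what `SDE_Detector` dispatches on, as the diff above shows (`cfg = tracker_cfg[tracker_cfg['type']]`). A minimal standalone sketch of reading this file, assuming the path `deploy/python/tracker_config.yml`:

```python
import yaml

with open('deploy/python/tracker_config.yml') as f:
    tracker_cfg = yaml.safe_load(f)

tracker_type = tracker_cfg['type']   # 'JDETracker' or 'DeepSORTTracker'
cfg = tracker_cfg[tracker_type]      # sub-dict for the selected tracker
use_deepsort = tracker_type == 'DeepSORTTracker'
print(tracker_type, cfg.get('min_box_area', 0), cfg.get('vertical_ratio', 0))
```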
@@ -44,7 +44,7 @@ class JDETracker(object):
        track_buffer (int): buffer for tracker
        min_box_area (int): min box area to filter out low quality boxes
        vertical_ratio (float): w/h, the vertical ratio of the bbox to filter
-           bad results. If set <0 means no need to filter bboxes,usually set
+           bad results. If set <= 0 means no need to filter bboxes, usually set
            1.6 for pedestrian tracking.
        tracked_thresh (float): linear assignment threshold of tracked
            stracks and detections
@@ -70,8 +70,8 @@ class JDETracker(object):
                 num_classes=1,
                 det_thresh=0.3,
                 track_buffer=30,
-                min_box_area=200,
-                vertical_ratio=1.6,
+                min_box_area=0,
+                vertical_ratio=0,
                 tracked_thresh=0.7,
                 r_tracked_thresh=0.5,
                 unconfirmed_thresh=0.7,
@@ -167,9 +167,8 @@ class JDETracker(object):
                detections = [
                    STrack(
                        STrack.tlbr_to_tlwh(tlbrs[2:6]), tlbrs[1], cls_id,
-                       30, temp_feat)
-                   for (tlbrs, temp_feat
-                        ) in zip(pred_dets_cls, pred_embs_cls)
+                       30, temp_feat) for (tlbrs, temp_feat) in
+                   zip(pred_dets_cls, pred_embs_cls)
                ]
            else:
                detections = []
@@ -244,15 +243,13 @@ class JDETracker(object):
                        for tlbrs in pred_dets_cls_second
                    ]
                else:
-                   pred_embs_cls_second = pred_embs_dict[cls_id][inds_second]
+                   pred_embs_cls_second = pred_embs_dict[cls_id][
+                       inds_second]
                    detections_second = [
                        STrack(
-                           STrack.tlbr_to_tlwh(tlbrs[2:6]),
-                           tlbrs[1],
-                           cls_id,
-                           30,
-                           temp_feat)
-                       for (tlbrs, temp_feat) in zip(pred_dets_cls_second, pred_embs_cls_second)
+                           STrack.tlbr_to_tlwh(tlbrs[2:6]), tlbrs[1],
+                           cls_id, 30, temp_feat) for (tlbrs, temp_feat) in
+                       zip(pred_dets_cls_second, pred_embs_cls_second)
                    ]
                else:
                    detections_second = []
...