Unverified commit fea33f2f authored by George Ni, committed by GitHub

[MOT] fix deploy tracker diff (#3636)

* fix tracker trt, add tracker to infer_cfg yml

* fix doc and deploy tracker
Parent 6eed850e
......@@ -95,7 +95,7 @@ CUDA_VISIBLE_DEVICES=0 python tools/export_model.py -c configs/mot/deepsort/deep
### 4. Using the exported model for Python inference
```bash
python deploy/python/mot_reid_infer.py --model_dir=output_inference/jde_yolov3_darknet53_30e_1088x608/ --reid_model_dir=output_inference/deepsort_yolov3_pcb_pyramid_r101/ --video_file={your video name}.mp4 --device=GPU --save_mot_txts
python deploy/python/mot_sde_infer.py --model_dir=output_inference/jde_yolov3_darknet53_30e_1088x608/ --reid_model_dir=output_inference/deepsort_yolov3_pcb_pyramid_r101/ --video_file={your video name}.mp4 --device=GPU --save_mot_txts
```
**Notes:**
The tracking model predicts videos and does not support single-image prediction. The video visualizing the tracking results is saved by default. You can add `--save_mot_txts` to save the txt result files, or `--save_images` to save the visualization images.
......
......@@ -98,7 +98,7 @@ CUDA_VISIBLE_DEVICES=0 python tools/export_model.py -c configs/mot/deepsort/deep
### 4. Using the exported model for Python inference
```bash
python deploy/python/mot_reid_infer.py --model_dir=output_inference/jde_yolov3_darknet53_30e_1088x608/ --reid_model_dir=output_inference/deepsort_yolov3_pcb_pyramid_r101/ --video_file={your video name}.mp4 --device=GPU --save_mot_txts
python deploy/python/mot_sde_infer.py --model_dir=output_inference/jde_yolov3_darknet53_30e_1088x608/ --reid_model_dir=output_inference/deepsort_yolov3_pcb_pyramid_r101/ --video_file={your video name}.mp4 --device=GPU --save_mot_txts
```
**Notes:**
The tracking model predicts videos and does not support single-image prediction. The video visualizing the tracking results is saved by default. You can add `--save_mot_txts` to save the txt result files, or `--save_images` to save the visualization images.
......
......@@ -345,6 +345,9 @@ class PredictConfig():
self.use_dynamic_shape = yml_conf['use_dynamic_shape']
if 'mask' in yml_conf:
self.mask = yml_conf['mask']
self.tracker = None
if 'tracker' in yml_conf:
self.tracker = yml_conf['tracker']
self.print_config()
def check_model(self, yml_conf):
......
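With this change, the exported `infer_cfg.yml` may carry an optional `tracker` section that `PredictConfig` exposes as `self.tracker` (defaulting to `None` when absent). A minimal standalone sketch of that lookup, using a hypothetical config excerpt (the key values are illustrative, not taken from a real export):

```python
import yaml

# Hypothetical excerpt of an exported infer_cfg.yml; the exact keys come from
# the JDETracker/DeepSORTTracker section of the training config at export time.
INFER_CFG_EXAMPLE = """
arch: JDE
use_dynamic_shape: false
tracker:
  conf_thres: 0.0
  tracked_thresh: 0.7
  metric_type: euclidean
"""

yml_conf = yaml.safe_load(INFER_CFG_EXAMPLE)
# Mirrors the PredictConfig logic above: the tracker section is optional.
tracker_cfg = yml_conf.get('tracker', None)
print(tracker_cfg)  # {'conf_thres': 0.0, 'tracked_thresh': 0.7, 'metric_type': 'euclidean'}
```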
......@@ -79,7 +79,12 @@ class JDE_Detector(Detector):
cpu_threads=cpu_threads,
enable_mkldnn=enable_mkldnn)
assert batch_size == 1, "The JDE Detector only supports batch size=1 now"
self.tracker = JDETracker()
assert pred_config.tracker, "Tracking model should have tracker"
tp = pred_config.tracker
conf_thres = tp['conf_thres'] if 'conf_thres' in tp else 0.
tracked_thresh = tp['tracked_thresh'] if 'tracked_thresh' in tp else 0.7
metric_type = tp['metric_type'] if 'metric_type' in tp else 'euclidean'
self.tracker = JDETracker(conf_thres=conf_thres, tracked_thresh=tracked_thresh, metric_type=metric_type)
def postprocess(self, pred_dets, pred_embs, threshold):
online_targets = self.tracker.update(pred_dets, pred_embs)
......
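The three `'key' in tp` fallbacks above are equivalent to `dict.get` with the JDETracker defaults; a small self-contained sketch of that pattern, with example values standing in for `pred_config.tracker`:

```python
# tp stands in for pred_config.tracker loaded from infer_cfg.yml; individual
# keys may be missing, so each lookup falls back to the JDETracker default.
tp = {'tracked_thresh': 0.7, 'metric_type': 'euclidean'}  # example values

conf_thres = tp.get('conf_thres', 0.)
tracked_thresh = tp.get('tracked_thresh', 0.7)
metric_type = tp.get('metric_type', 'euclidean')

print(conf_thres, tracked_thresh, metric_type)  # 0.0 0.7 euclidean
```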
......@@ -214,7 +214,7 @@ class SDE_ReID(object):
enable_mkldnn=enable_mkldnn)
self.det_times = Timer()
self.cpu_mem, self.gpu_mem, self.gpu_util = 0, 0, 0
assert pred_config.tracker, "Tracking model should have tracker"
self.tracker = DeepSORTTracker()
def preprocess(self, crops):
......@@ -246,7 +246,6 @@ class SDE_ReID(object):
inputs = self.preprocess(crops)
self.det_times.preprocess_time_s.end()
pred_dets, pred_embs = None, None
input_names = self.predictor.get_input_names()
for i in range(len(input_names)):
input_tensor = self.predictor.get_input_handle(input_names[i])
......
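For SDE/DeepSORT deployment, the new assertion only verifies that the exported config contains a `tracker` section before `DeepSORTTracker()` is built with its defaults. A hypothetical guard expressing the same check with a clearer error message (the helper name and wording are illustrative, not part of this commit):

```python
def check_tracker_config(pred_config):
    # Hypothetical helper: fail early if an older exported model lacks the
    # tracker section that this commit adds to infer_cfg.yml at export time.
    if not getattr(pred_config, 'tracker', None):
        raise ValueError(
            "Tracking model should have a 'tracker' section in infer_cfg.yml; "
            "re-export the model with the updated export tooling.")
```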
......@@ -82,6 +82,11 @@ def _parse_reader(reader_cfg, dataset_cfg, metric, arch, image_shape):
return preprocess_list, label_list
def _parse_tracker(tracker_cfg):
tracker_params = {}
for k, v in tracker_cfg.items():
tracker_params.update({k: v})
return tracker_params
def _dump_infer_config(config, path, image_shape, model):
arch_state = False
......@@ -96,6 +101,13 @@ def _dump_infer_config(config, path, image_shape, model):
})
infer_arch = config['architecture']
if infer_arch in MOT_ARCH:
if infer_arch == 'DeepSORT':
tracker_cfg = config['DeepSORTTracker']
else:
tracker_cfg = config['JDETracker']
infer_cfg['tracker'] = _parse_tracker(tracker_cfg)
for arch, min_subgraph_size in TRT_MIN_SUBGRAPH.items():
if arch in infer_arch:
infer_cfg['arch'] = arch
......
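On the export side, `_parse_tracker` copies the tracker section of the training config into the dumped `infer_cfg`, choosing `DeepSORTTracker` for DeepSORT and `JDETracker` otherwise. A self-contained sketch of that flow with a hypothetical training-config excerpt (values illustrative):

```python
def _parse_tracker(tracker_cfg):
    # Equivalent to the key-by-key copy above: return a plain dict of the section.
    return dict(tracker_cfg)

# Hypothetical training-config excerpt for a JDE model.
config = {
    'architecture': 'JDE',
    'JDETracker': {'conf_thres': 0.0, 'tracked_thresh': 0.7,
                   'metric_type': 'euclidean'},
}

infer_cfg = {}
tracker_key = 'DeepSORTTracker' if config['architecture'] == 'DeepSORT' else 'JDETracker'
infer_cfg['tracker'] = _parse_tracker(config[tracker_key])
print(infer_cfg['tracker'])
```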