diff --git a/deploy/python/mot_infer.py b/deploy/python/mot_infer.py index 94e4e5f8beee0083eb02c141ec6bc00a95dc5459..e13b8f32071afc6144747d4298e1cdb617fe34b7 100644 --- a/deploy/python/mot_infer.py +++ b/deploy/python/mot_infer.py @@ -93,7 +93,7 @@ class MOT_Detector(object): inputs = create_inputs(im, im_info) return inputs - def postprocess(self, pred_dets, pred_embs): + def postprocess(self, pred_dets, pred_embs, threshold): online_targets = self.tracker.update(pred_dets, pred_embs) online_tlwhs, online_ids = [], [] online_scores = [] @@ -101,6 +101,7 @@ class MOT_Detector(object): tlwh = t.tlwh tid = t.track_id tscore = t.score + if tscore < threshold: continue vertical = tlwh[2] / tlwh[3] > 1.6 if tlwh[2] * tlwh[3] > self.tracker.min_box_area and not vertical: online_tlwhs.append(tlwh) @@ -137,8 +138,8 @@ class MOT_Detector(object): self.det_times.inference_time_s.end(repeats=repeats) self.det_times.postprocess_time_s.start() - online_tlwhs, online_scores, online_ids = self.postprocess(pred_dets, - pred_embs) + online_tlwhs, online_scores, online_ids = self.postprocess( + pred_dets, pred_embs, threshold) self.det_times.postprocess_time_s.end() self.det_times.img_num += 1 return online_tlwhs, online_scores, online_ids @@ -363,7 +364,8 @@ def predict_video(detector, camera_id): online_ids, online_scores, frame_id=frame_id, - fps=fps) + fps=fps, + threhold=FLAGS.threshold) if FLAGS.save_images: save_dir = os.path.join(FLAGS.output_dir, video_name.split('.')[-2]) if not os.path.exists(save_dir): diff --git a/ppdet/engine/tracker.py b/ppdet/engine/tracker.py index cd3060ebac8c6356265edbeca6ee543aab1e7251..fb803948448a14805a84f47bd937cc2fd495a6c5 100644 --- a/ppdet/engine/tracker.py +++ b/ppdet/engine/tracker.py @@ -112,7 +112,8 @@ class Tracker(object): dataloader, save_dir=None, show_image=False, - frame_rate=30): + frame_rate=30, + draw_threshold=0): if save_dir: if not os.path.exists(save_dir): os.makedirs(save_dir) tracker = self.model.tracker @@ -140,6 +141,7 @@ class Tracker(object): tlwh = t.tlwh tid = t.track_id tscore = t.score + if tscore < draw_threshold: continue vertical = tlwh[2] / tlwh[3] > 1.6 if tlwh[2] * tlwh[3] > tracker.min_box_area and not vertical: online_tlwhs.append(tlwh) @@ -162,7 +164,8 @@ class Tracker(object): save_dir=None, show_image=False, frame_rate=30, - det_file=''): + det_file='', + draw_threshold=0): if save_dir: if not os.path.exists(save_dir): os.makedirs(save_dir) tracker = self.model.tracker @@ -191,6 +194,7 @@ class Tracker(object): dets = dets_list[frame_id] bbox_tlwh = paddle.to_tensor(dets['bbox'], dtype='float32') pred_scores = paddle.to_tensor(dets['score'], dtype='float32') + if pred_scores < draw_threshold: continue if bbox_tlwh.shape[0] > 0: pred_bboxes = paddle.concat( (bbox_tlwh[:, 0:2], @@ -343,7 +347,8 @@ class Tracker(object): save_images=False, save_videos=True, show_image=False, - det_results_dir=''): + det_results_dir='', + draw_threshold=0.5): if not os.path.exists(output_dir): os.makedirs(output_dir) result_root = os.path.join(output_dir, 'mot_results') if not os.path.exists(result_root): os.makedirs(result_root) @@ -369,7 +374,8 @@ class Tracker(object): dataloader, save_dir=save_dir, show_image=show_image, - frame_rate=frame_rate) + frame_rate=frame_rate, + draw_threshold=draw_threshold) elif model_type in ['DeepSORT']: results, nf, ta, tc = self._eval_seq_sde( dataloader, @@ -377,7 +383,8 @@ class Tracker(object): show_image=show_image, frame_rate=frame_rate, det_file=os.path.join(det_results_dir, - '{}.txt'.format(seq))) + '{}.txt'.format(seq)), + draw_threshold=draw_threshold) else: raise ValueError(model_type) diff --git a/tools/infer_mot.py b/tools/infer_mot.py index 2067375776c58ac98b5f02282becefd0c5d07cce..407e9aae4aabcab3e718c050a2137a63f8e7e7c8 100644 --- a/tools/infer_mot.py +++ b/tools/infer_mot.py @@ -68,6 +68,11 @@ def parse_args(): '--show_image', action='store_true', help='Show tracking results (image).') + parser.add_argument( + "--draw_threshold", + type=float, + default=0.5, + help="Threshold to reserve the result for visualization.") args = parser.parse_args() return args @@ -94,7 +99,8 @@ def run(FLAGS, cfg): save_images=FLAGS.save_images, save_videos=FLAGS.save_videos, show_image=FLAGS.show_image, - det_results_dir=FLAGS.det_results_dir) + det_results_dir=FLAGS.det_results_dir, + draw_threshold=FLAGS.draw_threshold) def main():