未验证 提交 39ff9f2f 编写于 作者: W wangguanzhong 提交者: GitHub

fix score threshold in mot_infer (#3444)

上级 1264fde9
...@@ -93,7 +93,7 @@ class MOT_Detector(object): ...@@ -93,7 +93,7 @@ class MOT_Detector(object):
inputs = create_inputs(im, im_info) inputs = create_inputs(im, im_info)
return inputs return inputs
def postprocess(self, pred_dets, pred_embs): def postprocess(self, pred_dets, pred_embs, threshold):
online_targets = self.tracker.update(pred_dets, pred_embs) online_targets = self.tracker.update(pred_dets, pred_embs)
online_tlwhs, online_ids = [], [] online_tlwhs, online_ids = [], []
online_scores = [] online_scores = []
...@@ -101,6 +101,7 @@ class MOT_Detector(object): ...@@ -101,6 +101,7 @@ class MOT_Detector(object):
tlwh = t.tlwh tlwh = t.tlwh
tid = t.track_id tid = t.track_id
tscore = t.score tscore = t.score
if tscore < threshold: continue
vertical = tlwh[2] / tlwh[3] > 1.6 vertical = tlwh[2] / tlwh[3] > 1.6
if tlwh[2] * tlwh[3] > self.tracker.min_box_area and not vertical: if tlwh[2] * tlwh[3] > self.tracker.min_box_area and not vertical:
online_tlwhs.append(tlwh) online_tlwhs.append(tlwh)
...@@ -137,8 +138,8 @@ class MOT_Detector(object): ...@@ -137,8 +138,8 @@ class MOT_Detector(object):
self.det_times.inference_time_s.end(repeats=repeats) self.det_times.inference_time_s.end(repeats=repeats)
self.det_times.postprocess_time_s.start() self.det_times.postprocess_time_s.start()
online_tlwhs, online_scores, online_ids = self.postprocess(pred_dets, online_tlwhs, online_scores, online_ids = self.postprocess(
pred_embs) pred_dets, pred_embs, threshold)
self.det_times.postprocess_time_s.end() self.det_times.postprocess_time_s.end()
self.det_times.img_num += 1 self.det_times.img_num += 1
return online_tlwhs, online_scores, online_ids return online_tlwhs, online_scores, online_ids
...@@ -363,7 +364,8 @@ def predict_video(detector, camera_id): ...@@ -363,7 +364,8 @@ def predict_video(detector, camera_id):
online_ids, online_ids,
online_scores, online_scores,
frame_id=frame_id, frame_id=frame_id,
fps=fps) fps=fps,
threhold=FLAGS.threshold)
if FLAGS.save_images: if FLAGS.save_images:
save_dir = os.path.join(FLAGS.output_dir, video_name.split('.')[-2]) save_dir = os.path.join(FLAGS.output_dir, video_name.split('.')[-2])
if not os.path.exists(save_dir): if not os.path.exists(save_dir):
......
...@@ -112,7 +112,8 @@ class Tracker(object): ...@@ -112,7 +112,8 @@ class Tracker(object):
dataloader, dataloader,
save_dir=None, save_dir=None,
show_image=False, show_image=False,
frame_rate=30): frame_rate=30,
draw_threshold=0):
if save_dir: if save_dir:
if not os.path.exists(save_dir): os.makedirs(save_dir) if not os.path.exists(save_dir): os.makedirs(save_dir)
tracker = self.model.tracker tracker = self.model.tracker
...@@ -140,6 +141,7 @@ class Tracker(object): ...@@ -140,6 +141,7 @@ class Tracker(object):
tlwh = t.tlwh tlwh = t.tlwh
tid = t.track_id tid = t.track_id
tscore = t.score tscore = t.score
if tscore < draw_threshold: continue
vertical = tlwh[2] / tlwh[3] > 1.6 vertical = tlwh[2] / tlwh[3] > 1.6
if tlwh[2] * tlwh[3] > tracker.min_box_area and not vertical: if tlwh[2] * tlwh[3] > tracker.min_box_area and not vertical:
online_tlwhs.append(tlwh) online_tlwhs.append(tlwh)
...@@ -162,7 +164,8 @@ class Tracker(object): ...@@ -162,7 +164,8 @@ class Tracker(object):
save_dir=None, save_dir=None,
show_image=False, show_image=False,
frame_rate=30, frame_rate=30,
det_file=''): det_file='',
draw_threshold=0):
if save_dir: if save_dir:
if not os.path.exists(save_dir): os.makedirs(save_dir) if not os.path.exists(save_dir): os.makedirs(save_dir)
tracker = self.model.tracker tracker = self.model.tracker
...@@ -191,6 +194,7 @@ class Tracker(object): ...@@ -191,6 +194,7 @@ class Tracker(object):
dets = dets_list[frame_id] dets = dets_list[frame_id]
bbox_tlwh = paddle.to_tensor(dets['bbox'], dtype='float32') bbox_tlwh = paddle.to_tensor(dets['bbox'], dtype='float32')
pred_scores = paddle.to_tensor(dets['score'], dtype='float32') pred_scores = paddle.to_tensor(dets['score'], dtype='float32')
if pred_scores < draw_threshold: continue
if bbox_tlwh.shape[0] > 0: if bbox_tlwh.shape[0] > 0:
pred_bboxes = paddle.concat( pred_bboxes = paddle.concat(
(bbox_tlwh[:, 0:2], (bbox_tlwh[:, 0:2],
...@@ -343,7 +347,8 @@ class Tracker(object): ...@@ -343,7 +347,8 @@ class Tracker(object):
save_images=False, save_images=False,
save_videos=True, save_videos=True,
show_image=False, show_image=False,
det_results_dir=''): det_results_dir='',
draw_threshold=0.5):
if not os.path.exists(output_dir): os.makedirs(output_dir) if not os.path.exists(output_dir): os.makedirs(output_dir)
result_root = os.path.join(output_dir, 'mot_results') result_root = os.path.join(output_dir, 'mot_results')
if not os.path.exists(result_root): os.makedirs(result_root) if not os.path.exists(result_root): os.makedirs(result_root)
...@@ -369,7 +374,8 @@ class Tracker(object): ...@@ -369,7 +374,8 @@ class Tracker(object):
dataloader, dataloader,
save_dir=save_dir, save_dir=save_dir,
show_image=show_image, show_image=show_image,
frame_rate=frame_rate) frame_rate=frame_rate,
draw_threshold=draw_threshold)
elif model_type in ['DeepSORT']: elif model_type in ['DeepSORT']:
results, nf, ta, tc = self._eval_seq_sde( results, nf, ta, tc = self._eval_seq_sde(
dataloader, dataloader,
...@@ -377,7 +383,8 @@ class Tracker(object): ...@@ -377,7 +383,8 @@ class Tracker(object):
show_image=show_image, show_image=show_image,
frame_rate=frame_rate, frame_rate=frame_rate,
det_file=os.path.join(det_results_dir, det_file=os.path.join(det_results_dir,
'{}.txt'.format(seq))) '{}.txt'.format(seq)),
draw_threshold=draw_threshold)
else: else:
raise ValueError(model_type) raise ValueError(model_type)
......
...@@ -68,6 +68,11 @@ def parse_args(): ...@@ -68,6 +68,11 @@ def parse_args():
'--show_image', '--show_image',
action='store_true', action='store_true',
help='Show tracking results (image).') help='Show tracking results (image).')
parser.add_argument(
"--draw_threshold",
type=float,
default=0.5,
help="Threshold to reserve the result for visualization.")
args = parser.parse_args() args = parser.parse_args()
return args return args
...@@ -94,7 +99,8 @@ def run(FLAGS, cfg): ...@@ -94,7 +99,8 @@ def run(FLAGS, cfg):
save_images=FLAGS.save_images, save_images=FLAGS.save_images,
save_videos=FLAGS.save_videos, save_videos=FLAGS.save_videos,
show_image=FLAGS.show_image, show_image=FLAGS.show_image,
det_results_dir=FLAGS.det_results_dir) det_results_dir=FLAGS.det_results_dir,
draw_threshold=FLAGS.draw_threshold)
def main(): def main():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册