未验证 提交 39ff9f2f 编写于 作者: W wangguanzhong 提交者: GitHub

fix score threshold in mot_infer (#3444)

上级 1264fde9
......@@ -93,7 +93,7 @@ class MOT_Detector(object):
inputs = create_inputs(im, im_info)
return inputs
def postprocess(self, pred_dets, pred_embs):
def postprocess(self, pred_dets, pred_embs, threshold):
online_targets = self.tracker.update(pred_dets, pred_embs)
online_tlwhs, online_ids = [], []
online_scores = []
......@@ -101,6 +101,7 @@ class MOT_Detector(object):
tlwh = t.tlwh
tid = t.track_id
tscore = t.score
if tscore < threshold: continue
vertical = tlwh[2] / tlwh[3] > 1.6
if tlwh[2] * tlwh[3] > self.tracker.min_box_area and not vertical:
online_tlwhs.append(tlwh)
......@@ -137,8 +138,8 @@ class MOT_Detector(object):
self.det_times.inference_time_s.end(repeats=repeats)
self.det_times.postprocess_time_s.start()
online_tlwhs, online_scores, online_ids = self.postprocess(pred_dets,
pred_embs)
online_tlwhs, online_scores, online_ids = self.postprocess(
pred_dets, pred_embs, threshold)
self.det_times.postprocess_time_s.end()
self.det_times.img_num += 1
return online_tlwhs, online_scores, online_ids
......@@ -363,7 +364,8 @@ def predict_video(detector, camera_id):
online_ids,
online_scores,
frame_id=frame_id,
fps=fps)
fps=fps,
threhold=FLAGS.threshold)
if FLAGS.save_images:
save_dir = os.path.join(FLAGS.output_dir, video_name.split('.')[-2])
if not os.path.exists(save_dir):
......
......@@ -112,7 +112,8 @@ class Tracker(object):
dataloader,
save_dir=None,
show_image=False,
frame_rate=30):
frame_rate=30,
draw_threshold=0):
if save_dir:
if not os.path.exists(save_dir): os.makedirs(save_dir)
tracker = self.model.tracker
......@@ -140,6 +141,7 @@ class Tracker(object):
tlwh = t.tlwh
tid = t.track_id
tscore = t.score
if tscore < draw_threshold: continue
vertical = tlwh[2] / tlwh[3] > 1.6
if tlwh[2] * tlwh[3] > tracker.min_box_area and not vertical:
online_tlwhs.append(tlwh)
......@@ -162,7 +164,8 @@ class Tracker(object):
save_dir=None,
show_image=False,
frame_rate=30,
det_file=''):
det_file='',
draw_threshold=0):
if save_dir:
if not os.path.exists(save_dir): os.makedirs(save_dir)
tracker = self.model.tracker
......@@ -191,6 +194,7 @@ class Tracker(object):
dets = dets_list[frame_id]
bbox_tlwh = paddle.to_tensor(dets['bbox'], dtype='float32')
pred_scores = paddle.to_tensor(dets['score'], dtype='float32')
if pred_scores < draw_threshold: continue
if bbox_tlwh.shape[0] > 0:
pred_bboxes = paddle.concat(
(bbox_tlwh[:, 0:2],
......@@ -343,7 +347,8 @@ class Tracker(object):
save_images=False,
save_videos=True,
show_image=False,
det_results_dir=''):
det_results_dir='',
draw_threshold=0.5):
if not os.path.exists(output_dir): os.makedirs(output_dir)
result_root = os.path.join(output_dir, 'mot_results')
if not os.path.exists(result_root): os.makedirs(result_root)
......@@ -369,7 +374,8 @@ class Tracker(object):
dataloader,
save_dir=save_dir,
show_image=show_image,
frame_rate=frame_rate)
frame_rate=frame_rate,
draw_threshold=draw_threshold)
elif model_type in ['DeepSORT']:
results, nf, ta, tc = self._eval_seq_sde(
dataloader,
......@@ -377,7 +383,8 @@ class Tracker(object):
show_image=show_image,
frame_rate=frame_rate,
det_file=os.path.join(det_results_dir,
'{}.txt'.format(seq)))
'{}.txt'.format(seq)),
draw_threshold=draw_threshold)
else:
raise ValueError(model_type)
......
......@@ -68,6 +68,11 @@ def parse_args():
'--show_image',
action='store_true',
help='Show tracking results (image).')
parser.add_argument(
"--draw_threshold",
type=float,
default=0.5,
help="Threshold to reserve the result for visualization.")
args = parser.parse_args()
return args
......@@ -94,7 +99,8 @@ def run(FLAGS, cfg):
save_images=FLAGS.save_images,
save_videos=FLAGS.save_videos,
show_image=FLAGS.show_image,
det_results_dir=FLAGS.det_results_dir)
det_results_dir=FLAGS.det_results_dir,
draw_threshold=FLAGS.draw_threshold)
def main():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册