From fa52c445c51f12ced98b879365950deb59fdf394 Mon Sep 17 00:00:00 2001 From: JYChen Date: Sat, 2 Apr 2022 16:05:05 +0800 Subject: [PATCH] [Cherry-Pick] fix error when all bbox are filtered (#5571) * fix error when all bbox are filered * remove filter in function get_person_from_rect --- deploy/python/det_keypoint_unite_infer.py | 38 +++++++++++++---------- deploy/python/keypoint_infer.py | 6 ++-- 2 files changed, 23 insertions(+), 21 deletions(-) diff --git a/deploy/python/det_keypoint_unite_infer.py b/deploy/python/det_keypoint_unite_infer.py index a82c2c58c..2b901dcaa 100644 --- a/deploy/python/det_keypoint_unite_infer.py +++ b/deploy/python/det_keypoint_unite_infer.py @@ -36,10 +36,9 @@ KEYPOINT_SUPPORT_MODELS = { def predict_with_given_det(image, det_res, keypoint_detector, - keypoint_batch_size, det_threshold, - keypoint_threshold, run_benchmark): + keypoint_batch_size, run_benchmark): rec_images, records, det_rects = keypoint_detector.get_person_from_rect( - image, det_res, det_threshold) + image, det_res) keypoint_vector = [] score_vector = [] @@ -79,19 +78,20 @@ def topdown_unite_predict(detector, detector.gpu_util += gu else: results = detector.predict_image([image], visual=False) + results = detector.filter_box(results, FLAGS.det_threshold) + if results['boxes_num'] > 0: + keypoint_res = predict_with_given_det( + image, results, topdown_keypoint_detector, keypoint_batch_size, + FLAGS.run_benchmark) - if results['boxes_num'] == 0: - continue - - keypoint_res = predict_with_given_det( - image, results, topdown_keypoint_detector, keypoint_batch_size, - FLAGS.det_threshold, FLAGS.keypoint_threshold, FLAGS.run_benchmark) - - if save_res: - store_res.append([ - i, keypoint_res['bbox'], - [keypoint_res['keypoint'][0], keypoint_res['keypoint'][1]] - ]) + if save_res: + store_res.append([ + i, keypoint_res['bbox'], + [keypoint_res['keypoint'][0], keypoint_res['keypoint'][1]] + ]) + else: + results["keypoint"] = [[], []] + keypoint_res = results if FLAGS.run_benchmark: cm, gm, gu = get_current_memory_mb() topdown_keypoint_detector.cpu_mem += cm @@ -138,7 +138,7 @@ def topdown_unite_predict_video(detector, if not os.path.exists(FLAGS.output_dir): os.makedirs(FLAGS.output_dir) out_path = os.path.join(FLAGS.output_dir, video_name) - fourcc = cv2.VideoWriter_fourcc(*'mp4v') + fourcc = cv2.VideoWriter_fourcc(* 'mp4v') writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height)) index = 0 store_res = [] @@ -152,10 +152,14 @@ def topdown_unite_predict_video(detector, frame2 = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) results = detector.predict_image([frame2], visual=False) + results = detector.filter_box(results, FLAGS.det_threshold) + if results['boxes_num'] == 0: + writer.write(frame) + continue keypoint_res = predict_with_given_det( frame2, results, topdown_keypoint_detector, keypoint_batch_size, - FLAGS.det_threshold, FLAGS.keypoint_threshold, FLAGS.run_benchmark) + FLAGS.run_benchmark) im = visualize_pose( frame, diff --git a/deploy/python/keypoint_infer.py b/deploy/python/keypoint_infer.py index e16ddd647..b87bac92c 100644 --- a/deploy/python/keypoint_infer.py +++ b/deploy/python/keypoint_infer.py @@ -95,12 +95,10 @@ class KeyPointDetector(Detector): def set_config(self, model_dir): return PredictConfig_KeyPoint(model_dir) - def get_person_from_rect(self, image, results, det_threshold=0.5): + def get_person_from_rect(self, image, results): # crop the person result from image self.det_times.preprocess_time_s.start() - det_results = results['boxes'] - mask = det_results[:, 1] > det_threshold - valid_rects = det_results[mask] + valid_rects = results['boxes'] rect_images = [] new_rects = [] org_rects = [] -- GitLab