diff --git a/deploy/python/det_keypoint_unite_infer.py b/deploy/python/det_keypoint_unite_infer.py index a82c2c58cefb0926958e911c4f2f8ee2fc5bfd75..2b901dcaa9dd7af8fcff4dcc99e15acd4ee9dede 100644 --- a/deploy/python/det_keypoint_unite_infer.py +++ b/deploy/python/det_keypoint_unite_infer.py @@ -36,10 +36,9 @@ KEYPOINT_SUPPORT_MODELS = { def predict_with_given_det(image, det_res, keypoint_detector, - keypoint_batch_size, det_threshold, - keypoint_threshold, run_benchmark): + keypoint_batch_size, run_benchmark): rec_images, records, det_rects = keypoint_detector.get_person_from_rect( - image, det_res, det_threshold) + image, det_res) keypoint_vector = [] score_vector = [] @@ -79,19 +78,20 @@ def topdown_unite_predict(detector, detector.gpu_util += gu else: results = detector.predict_image([image], visual=False) + results = detector.filter_box(results, FLAGS.det_threshold) + if results['boxes_num'] > 0: + keypoint_res = predict_with_given_det( + image, results, topdown_keypoint_detector, keypoint_batch_size, + FLAGS.run_benchmark) - if results['boxes_num'] == 0: - continue - - keypoint_res = predict_with_given_det( - image, results, topdown_keypoint_detector, keypoint_batch_size, - FLAGS.det_threshold, FLAGS.keypoint_threshold, FLAGS.run_benchmark) - - if save_res: - store_res.append([ - i, keypoint_res['bbox'], - [keypoint_res['keypoint'][0], keypoint_res['keypoint'][1]] - ]) + if save_res: + store_res.append([ + i, keypoint_res['bbox'], + [keypoint_res['keypoint'][0], keypoint_res['keypoint'][1]] + ]) + else: + results["keypoint"] = [[], []] + keypoint_res = results if FLAGS.run_benchmark: cm, gm, gu = get_current_memory_mb() topdown_keypoint_detector.cpu_mem += cm @@ -138,7 +138,7 @@ def topdown_unite_predict_video(detector, if not os.path.exists(FLAGS.output_dir): os.makedirs(FLAGS.output_dir) out_path = os.path.join(FLAGS.output_dir, video_name) - fourcc = cv2.VideoWriter_fourcc(*'mp4v') + fourcc = cv2.VideoWriter_fourcc(* 'mp4v') writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height)) index = 0 store_res = [] @@ -152,10 +152,14 @@ def topdown_unite_predict_video(detector, frame2 = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) results = detector.predict_image([frame2], visual=False) + results = detector.filter_box(results, FLAGS.det_threshold) + if results['boxes_num'] == 0: + writer.write(frame) + continue keypoint_res = predict_with_given_det( frame2, results, topdown_keypoint_detector, keypoint_batch_size, - FLAGS.det_threshold, FLAGS.keypoint_threshold, FLAGS.run_benchmark) + FLAGS.run_benchmark) im = visualize_pose( frame, diff --git a/deploy/python/keypoint_infer.py b/deploy/python/keypoint_infer.py index e16ddd647cf58a58bb9b4c8cb239fd9e3d472673..b87bac92c3958c23c392e0d43a688a1346c9744c 100644 --- a/deploy/python/keypoint_infer.py +++ b/deploy/python/keypoint_infer.py @@ -95,12 +95,10 @@ class KeyPointDetector(Detector): def set_config(self, model_dir): return PredictConfig_KeyPoint(model_dir) - def get_person_from_rect(self, image, results, det_threshold=0.5): + def get_person_from_rect(self, image, results): # crop the person result from image self.det_times.preprocess_time_s.start() - det_results = results['boxes'] - mask = det_results[:, 1] > det_threshold - valid_rects = det_results[mask] + valid_rects = results['boxes'] rect_images = [] new_rects = [] org_rects = []