未验证 提交 ffc14930 编写于 作者: J JYChen 提交者: GitHub

fix error when all bbox are filered (#5564)

* fix error when all bbox are filered

* remove filter in function get_person_from_rect
上级 18f8eaf8
...@@ -36,10 +36,9 @@ KEYPOINT_SUPPORT_MODELS = { ...@@ -36,10 +36,9 @@ KEYPOINT_SUPPORT_MODELS = {
def predict_with_given_det(image, det_res, keypoint_detector, def predict_with_given_det(image, det_res, keypoint_detector,
keypoint_batch_size, det_threshold, keypoint_batch_size, run_benchmark):
keypoint_threshold, run_benchmark):
rec_images, records, det_rects = keypoint_detector.get_person_from_rect( rec_images, records, det_rects = keypoint_detector.get_person_from_rect(
image, det_res, det_threshold) image, det_res)
keypoint_vector = [] keypoint_vector = []
score_vector = [] score_vector = []
...@@ -79,19 +78,20 @@ def topdown_unite_predict(detector, ...@@ -79,19 +78,20 @@ def topdown_unite_predict(detector,
detector.gpu_util += gu detector.gpu_util += gu
else: else:
results = detector.predict_image([image], visual=False) results = detector.predict_image([image], visual=False)
results = detector.filter_box(results, FLAGS.det_threshold)
if results['boxes_num'] > 0:
keypoint_res = predict_with_given_det(
image, results, topdown_keypoint_detector, keypoint_batch_size,
FLAGS.run_benchmark)
if results['boxes_num'] == 0: if save_res:
continue store_res.append([
i, keypoint_res['bbox'],
keypoint_res = predict_with_given_det( [keypoint_res['keypoint'][0], keypoint_res['keypoint'][1]]
image, results, topdown_keypoint_detector, keypoint_batch_size, ])
FLAGS.det_threshold, FLAGS.keypoint_threshold, FLAGS.run_benchmark) else:
results["keypoint"] = [[], []]
if save_res: keypoint_res = results
store_res.append([
i, keypoint_res['bbox'],
[keypoint_res['keypoint'][0], keypoint_res['keypoint'][1]]
])
if FLAGS.run_benchmark: if FLAGS.run_benchmark:
cm, gm, gu = get_current_memory_mb() cm, gm, gu = get_current_memory_mb()
topdown_keypoint_detector.cpu_mem += cm topdown_keypoint_detector.cpu_mem += cm
...@@ -138,7 +138,7 @@ def topdown_unite_predict_video(detector, ...@@ -138,7 +138,7 @@ def topdown_unite_predict_video(detector,
if not os.path.exists(FLAGS.output_dir): if not os.path.exists(FLAGS.output_dir):
os.makedirs(FLAGS.output_dir) os.makedirs(FLAGS.output_dir)
out_path = os.path.join(FLAGS.output_dir, video_name) out_path = os.path.join(FLAGS.output_dir, video_name)
fourcc = cv2.VideoWriter_fourcc(*'mp4v') fourcc = cv2.VideoWriter_fourcc(* 'mp4v')
writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height)) writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
index = 0 index = 0
store_res = [] store_res = []
...@@ -152,10 +152,14 @@ def topdown_unite_predict_video(detector, ...@@ -152,10 +152,14 @@ def topdown_unite_predict_video(detector,
frame2 = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) frame2 = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
results = detector.predict_image([frame2], visual=False) results = detector.predict_image([frame2], visual=False)
results = detector.filter_box(results, FLAGS.det_threshold)
if results['boxes_num'] == 0:
writer.write(frame)
continue
keypoint_res = predict_with_given_det( keypoint_res = predict_with_given_det(
frame2, results, topdown_keypoint_detector, keypoint_batch_size, frame2, results, topdown_keypoint_detector, keypoint_batch_size,
FLAGS.det_threshold, FLAGS.keypoint_threshold, FLAGS.run_benchmark) FLAGS.run_benchmark)
im = visualize_pose( im = visualize_pose(
frame, frame,
......
...@@ -95,12 +95,10 @@ class KeyPointDetector(Detector): ...@@ -95,12 +95,10 @@ class KeyPointDetector(Detector):
def set_config(self, model_dir): def set_config(self, model_dir):
return PredictConfig_KeyPoint(model_dir) return PredictConfig_KeyPoint(model_dir)
def get_person_from_rect(self, image, results, det_threshold=0.5): def get_person_from_rect(self, image, results):
# crop the person result from image # crop the person result from image
self.det_times.preprocess_time_s.start() self.det_times.preprocess_time_s.start()
det_results = results['boxes'] valid_rects = results['boxes']
mask = det_results[:, 1] > det_threshold
valid_rects = det_results[mask]
rect_images = [] rect_images = []
new_rects = [] new_rects = []
org_rects = [] org_rects = []
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册