未验证 提交 fa52c445 编写于 作者: J JYChen 提交者: GitHub

[Cherry-Pick] fix error when all bbox are filtered (#5571)

* fix error when all bbox are filered

* remove filter in function get_person_from_rect
上级 1bc50816
......@@ -36,10 +36,9 @@ KEYPOINT_SUPPORT_MODELS = {
def predict_with_given_det(image, det_res, keypoint_detector,
keypoint_batch_size, det_threshold,
keypoint_threshold, run_benchmark):
keypoint_batch_size, run_benchmark):
rec_images, records, det_rects = keypoint_detector.get_person_from_rect(
image, det_res, det_threshold)
image, det_res)
keypoint_vector = []
score_vector = []
......@@ -79,19 +78,20 @@ def topdown_unite_predict(detector,
detector.gpu_util += gu
else:
results = detector.predict_image([image], visual=False)
results = detector.filter_box(results, FLAGS.det_threshold)
if results['boxes_num'] > 0:
keypoint_res = predict_with_given_det(
image, results, topdown_keypoint_detector, keypoint_batch_size,
FLAGS.run_benchmark)
if results['boxes_num'] == 0:
continue
keypoint_res = predict_with_given_det(
image, results, topdown_keypoint_detector, keypoint_batch_size,
FLAGS.det_threshold, FLAGS.keypoint_threshold, FLAGS.run_benchmark)
if save_res:
store_res.append([
i, keypoint_res['bbox'],
[keypoint_res['keypoint'][0], keypoint_res['keypoint'][1]]
])
if save_res:
store_res.append([
i, keypoint_res['bbox'],
[keypoint_res['keypoint'][0], keypoint_res['keypoint'][1]]
])
else:
results["keypoint"] = [[], []]
keypoint_res = results
if FLAGS.run_benchmark:
cm, gm, gu = get_current_memory_mb()
topdown_keypoint_detector.cpu_mem += cm
......@@ -138,7 +138,7 @@ def topdown_unite_predict_video(detector,
if not os.path.exists(FLAGS.output_dir):
os.makedirs(FLAGS.output_dir)
out_path = os.path.join(FLAGS.output_dir, video_name)
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
fourcc = cv2.VideoWriter_fourcc(* 'mp4v')
writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
index = 0
store_res = []
......@@ -152,10 +152,14 @@ def topdown_unite_predict_video(detector,
frame2 = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
results = detector.predict_image([frame2], visual=False)
results = detector.filter_box(results, FLAGS.det_threshold)
if results['boxes_num'] == 0:
writer.write(frame)
continue
keypoint_res = predict_with_given_det(
frame2, results, topdown_keypoint_detector, keypoint_batch_size,
FLAGS.det_threshold, FLAGS.keypoint_threshold, FLAGS.run_benchmark)
FLAGS.run_benchmark)
im = visualize_pose(
frame,
......
......@@ -95,12 +95,10 @@ class KeyPointDetector(Detector):
def set_config(self, model_dir):
return PredictConfig_KeyPoint(model_dir)
def get_person_from_rect(self, image, results, det_threshold=0.5):
def get_person_from_rect(self, image, results):
# crop the person result from image
self.det_times.preprocess_time_s.start()
det_results = results['boxes']
mask = det_results[:, 1] > det_threshold
valid_rects = det_results[mask]
valid_rects = results['boxes']
rect_images = []
new_rects = []
org_rects = []
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册