Unverified commit a229530b, authored by zhiboniu, committed by GitHub

add deploy keypoint infer save results (#4480)

Parent: 8a87d99e
@@ -13,7 +13,7 @@
 # limitations under the License.
 import os
+import json
 import cv2
 import math
 import numpy as np
@@ -80,7 +80,7 @@ def predict_with_given_det(image, det_res, keypoint_detector,
     keypoint_res = {}
     keypoint_res['keypoint'] = [
-        np.vstack(keypoint_vector), np.vstack(score_vector)
+        np.vstack(keypoint_vector).tolist(), np.vstack(score_vector).tolist()
     ] if len(keypoint_vector) > 0 else [[], []]
     keypoint_res['bbox'] = rect_vector
     return keypoint_res
@@ -89,8 +89,10 @@ def predict_with_given_det(image, det_res, keypoint_detector,
 def topdown_unite_predict(detector,
                           topdown_keypoint_detector,
                           image_list,
-                          keypoint_batch_size=1):
+                          keypoint_batch_size=1,
+                          save_res=False):
     det_timer = detector.get_timer()
+    store_res = []
     for i, img_file in enumerate(image_list):
         # Decode image in advance in det + pose prediction
         det_timer.preprocess_time_s.start()
@@ -114,6 +116,11 @@ def topdown_unite_predict(detector,
             image, results, topdown_keypoint_detector, keypoint_batch_size,
             FLAGS.det_threshold, FLAGS.keypoint_threshold, FLAGS.run_benchmark)
+        if save_res:
+            store_res.append([
+                i, keypoint_res['bbox'],
+                [keypoint_res['keypoint'][0], keypoint_res['keypoint'][1]]
+            ])
         if FLAGS.run_benchmark:
             cm, gm, gu = get_current_memory_mb()
             topdown_keypoint_detector.cpu_mem += cm
@@ -127,12 +134,23 @@ def topdown_unite_predict(detector,
                 keypoint_res,
                 visual_thread=FLAGS.keypoint_threshold,
                 save_dir=FLAGS.output_dir)
+    if save_res:
+        """
+        1) store_res: a list of image_data
+        2) image_data: [imageid, rects, [keypoints, scores]]
+        3) rects: list of rect [xmin, ymin, xmax, ymax]
+        4) keypoints: 17(joint numbers)*[x, y, conf], total 51 data in list
+        5) scores: mean of all joint conf
+        """
+        with open("det_keypoint_unite_image_results.json", 'w') as wf:
+            json.dump(store_res, wf, indent=4)


 def topdown_unite_predict_video(detector,
                                 topdown_keypoint_detector,
                                 camera_id,
-                                keypoint_batch_size=1):
+                                keypoint_batch_size=1,
+                                save_res=False):
     video_name = 'output.mp4'
     if camera_id != -1:
         capture = cv2.VideoCapture(camera_id)
@@ -150,9 +168,10 @@ def topdown_unite_predict_video(detector,
     if not os.path.exists(FLAGS.output_dir):
         os.makedirs(FLAGS.output_dir)
     out_path = os.path.join(FLAGS.output_dir, video_name)
-    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+    fourcc = cv2.VideoWriter_fourcc(* 'mp4v')
     writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
     index = 0
+    store_res = []
     while (1):
         ret, frame = capture.read()
         if not ret:
@@ -172,6 +191,11 @@ def topdown_unite_predict_video(detector,
             keypoint_res,
             visual_thread=FLAGS.keypoint_threshold,
             returnimg=True)
+        if save_res:
+            store_res.append([
+                index, keypoint_res['bbox'],
+                [keypoint_res['keypoint'][0], keypoint_res['keypoint'][1]]
+            ])
         writer.write(im)
         if camera_id != -1:
@@ -179,6 +203,16 @@ def topdown_unite_predict_video(detector,
         if cv2.waitKey(1) & 0xFF == ord('q'):
             break
     writer.release()
+    if save_res:
+        """
+        1) store_res: a list of frame_data
+        2) frame_data: [frameid, rects, [keypoints, scores]]
+        3) rects: list of rect [xmin, ymin, xmax, ymax]
+        4) keypoints: 17(joint numbers)*[x, y, conf], total 51 data in list
+        5) scores: mean of all joint conf
+        """
+        with open("det_keypoint_unite_video_results.json", 'w') as wf:
+            json.dump(store_res, wf, indent=4)


 def main():
@@ -216,12 +250,13 @@ def main():
     # predict from video file or camera video stream
     if FLAGS.video_file is not None or FLAGS.camera_id != -1:
         topdown_unite_predict_video(detector, topdown_keypoint_detector,
-                                     FLAGS.camera_id, FLAGS.keypoint_batch_size)
+                                     FLAGS.camera_id, FLAGS.keypoint_batch_size,
+                                     FLAGS.save_res)
     else:
         # predict from image
         img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file)
         topdown_unite_predict(detector, topdown_keypoint_detector, img_list,
-                              FLAGS.keypoint_batch_size)
+                              FLAGS.keypoint_batch_size, FLAGS.save_res)
     if not FLAGS.run_benchmark:
         detector.det_times.info(average=True)
         topdown_keypoint_detector.det_times.info(average=True)
...
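The JSON file written by the new image branch can be read back with nothing but the standard library. A minimal sketch (the filename and record layout come from the diff above; the loop body and print are illustrative only):

import json

# Load the results produced by topdown_unite_predict(..., save_res=True).
with open("det_keypoint_unite_image_results.json") as rf:
    store_res = json.load(rf)

for image_id, rects, (keypoints, scores) in store_res:
    # rects:     one [xmin, ymin, xmax, ymax] box per detected person
    # keypoints: per person, 17 joints of [x, y, conf] (51 values in total)
    # scores:    per person, the mean confidence over all joints
    print(image_id, len(rects), 'person(s)')

The video branch writes det_keypoint_unite_video_results.json with the same layout, keyed by frame id instead of image id.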
@@ -115,5 +115,15 @@ def argsparser():
         type=bool,
         default=True,
         help='whether to use darkpose to get better keypoint position predict ')
+    parser.add_argument(
+        '--save_res',
+        type=bool,
+        default=False,
+        help=(
+            "whether to save predict results to json file"
+            "1) store_res: a list of image_data"
+            "2) image_data: [imageid, rects, [keypoints, scores]]"
+            "3) rects: list of rect [xmin, ymin, xmax, ymax]"
+            "4) keypoints: 17(joint numbers)*[x, y, conf], total 51 data in list"
+            "5) scores: mean of all joint conf"))
     return parser
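One practical note on the new flag: argparse applies type to the raw command-line string, and bool('False') is True in Python, so any non-empty value passed to --save_res enables saving; only omitting the flag keeps the default of False. A tiny standalone sketch of that behaviour (plain argparse, nothing repository-specific):

import argparse

# Mirror the flag definition added above.
parser = argparse.ArgumentParser()
parser.add_argument('--save_res', type=bool, default=False)

print(parser.parse_args([]).save_res)                    # False (default)
print(parser.parse_args(['--save_res=True']).save_res)   # True
print(parser.parse_args(['--save_res=False']).save_res)  # True: bool('False') is truthy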
@@ -240,6 +240,7 @@ def draw_pose(imgfile,
         raise e
     skeletons, scores = results['keypoint']
+    skeletons = np.array(skeletons)
     kpt_nums = 17
     if len(skeletons) > 0:
         kpt_nums = skeletons.shape[1]
...
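The single added line in draw_pose pairs with the .tolist() change above: keypoints now arrive as nested Python lists (so they survive json.dump), and converting them back to an ndarray restores the shape-based access the drawing code relies on. A small round-trip sketch with made-up data (only np.array(skeletons) and shape[1] mirror the diff):

import json
import numpy as np

# Keypoints serialized as nested lists, as predict_with_given_det now returns them:
# 2 people x 17 joints x [x, y, conf].
payload = json.dumps({'keypoint': [np.zeros((2, 17, 3)).tolist(), [[0.9], [0.8]]]})

results = json.loads(payload)
skeletons, scores = results['keypoint']
skeletons = np.array(skeletons)   # back to an ndarray before indexing
kpt_nums = skeletons.shape[1]     # 17 joints, as draw_pose expects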