未验证 提交 a229530b 编写于 作者: Z zhiboniu 提交者: GitHub

add deploy keypoint infer save results (#4480)

上级 8a87d99e
......@@ -13,7 +13,7 @@
# limitations under the License.
import os
import json
import cv2
import math
import numpy as np
......@@ -80,7 +80,7 @@ def predict_with_given_det(image, det_res, keypoint_detector,
keypoint_res = {}
keypoint_res['keypoint'] = [
np.vstack(keypoint_vector), np.vstack(score_vector)
np.vstack(keypoint_vector).tolist(), np.vstack(score_vector).tolist()
] if len(keypoint_vector) > 0 else [[], []]
keypoint_res['bbox'] = rect_vector
return keypoint_res
......@@ -89,8 +89,10 @@ def predict_with_given_det(image, det_res, keypoint_detector,
def topdown_unite_predict(detector,
topdown_keypoint_detector,
image_list,
keypoint_batch_size=1):
keypoint_batch_size=1,
save_res=False):
det_timer = detector.get_timer()
store_res = []
for i, img_file in enumerate(image_list):
# Decode image in advance in det + pose prediction
det_timer.preprocess_time_s.start()
......@@ -114,6 +116,11 @@ def topdown_unite_predict(detector,
image, results, topdown_keypoint_detector, keypoint_batch_size,
FLAGS.det_threshold, FLAGS.keypoint_threshold, FLAGS.run_benchmark)
if save_res:
store_res.append([
i, keypoint_res['bbox'],
[keypoint_res['keypoint'][0], keypoint_res['keypoint'][1]]
])
if FLAGS.run_benchmark:
cm, gm, gu = get_current_memory_mb()
topdown_keypoint_detector.cpu_mem += cm
......@@ -127,12 +134,23 @@ def topdown_unite_predict(detector,
keypoint_res,
visual_thread=FLAGS.keypoint_threshold,
save_dir=FLAGS.output_dir)
if save_res:
"""
1) store_res: a list of image_data
2) image_data: [imageid, rects, [keypoints, scores]]
3) rects: list of rect [xmin, ymin, xmax, ymax]
4) keypoints: 17(joint numbers)*[x, y, conf], total 51 data in list
5) scores: mean of all joint conf
"""
with open("det_keypoint_unite_image_results.json", 'w') as wf:
json.dump(store_res, wf, indent=4)
def topdown_unite_predict_video(detector,
topdown_keypoint_detector,
camera_id,
keypoint_batch_size=1):
keypoint_batch_size=1,
save_res=False):
video_name = 'output.mp4'
if camera_id != -1:
capture = cv2.VideoCapture(camera_id)
......@@ -150,9 +168,10 @@ def topdown_unite_predict_video(detector,
if not os.path.exists(FLAGS.output_dir):
os.makedirs(FLAGS.output_dir)
out_path = os.path.join(FLAGS.output_dir, video_name)
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
fourcc = cv2.VideoWriter_fourcc(* 'mp4v')
writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
index = 0
store_res = []
while (1):
ret, frame = capture.read()
if not ret:
......@@ -172,6 +191,11 @@ def topdown_unite_predict_video(detector,
keypoint_res,
visual_thread=FLAGS.keypoint_threshold,
returnimg=True)
if save_res:
store_res.append([
index, keypoint_res['bbox'],
[keypoint_res['keypoint'][0], keypoint_res['keypoint'][1]]
])
writer.write(im)
if camera_id != -1:
......@@ -179,6 +203,16 @@ def topdown_unite_predict_video(detector,
if cv2.waitKey(1) & 0xFF == ord('q'):
break
writer.release()
if save_res:
"""
1) store_res: a list of frame_data
2) frame_data: [frameid, rects, [keypoints, scores]]
3) rects: list of rect [xmin, ymin, xmax, ymax]
4) keypoints: 17(joint numbers)*[x, y, conf], total 51 data in list
5) scores: mean of all joint conf
"""
with open("det_keypoint_unite_video_results.json", 'w') as wf:
json.dump(store_res, wf, indent=4)
def main():
......@@ -216,12 +250,13 @@ def main():
# predict from video file or camera video stream
if FLAGS.video_file is not None or FLAGS.camera_id != -1:
topdown_unite_predict_video(detector, topdown_keypoint_detector,
FLAGS.camera_id, FLAGS.keypoint_batch_size)
FLAGS.camera_id, FLAGS.keypoint_batch_size,
FLAGS.save_res)
else:
# predict from image
img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file)
topdown_unite_predict(detector, topdown_keypoint_detector, img_list,
FLAGS.keypoint_batch_size)
FLAGS.keypoint_batch_size, FLAGS.save_res)
if not FLAGS.run_benchmark:
detector.det_times.info(average=True)
topdown_keypoint_detector.det_times.info(average=True)
......
......@@ -115,5 +115,15 @@ def argsparser():
type=bool,
default=True,
help='whether to use darkpose to get better keypoint position predict ')
parser.add_argument(
'--save_res',
type=bool,
default=False,
help=(
"whether to save predict results to json file"
"1) store_res: a list of image_data"
"2) image_data: [imageid, rects, [keypoints, scores]]"
"3) rects: list of rect [xmin, ymin, xmax, ymax]"
"4) keypoints: 17(joint numbers)*[x, y, conf], total 51 data in list"
"5) scores: mean of all joint conf"))
return parser
......@@ -240,6 +240,7 @@ def draw_pose(imgfile,
raise e
skeletons, scores = results['keypoint']
skeletons = np.array(skeletons)
kpt_nums = 17
if len(skeletons) > 0:
kpt_nums = skeletons.shape[1]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册