From 5f9b0bc3d5d6fd9d8152545b2dbdead417c75e00 Mon Sep 17 00:00:00 2001
From: wangguanzhong
Date: Mon, 28 Jun 2021 11:21:56 +0800
Subject: [PATCH] refine keypoint deploy (#3473)

* refine keypoint deploy

* fit video infer

* fix post process

* update comments for keypoint_batch_size
---
 deploy/python/infer.py                    |   7 +-
 deploy/python/keypoint_det_unite_infer.py | 169 ++++++++++++++--------
 deploy/python/keypoint_infer.py           |  63 +++++---
 deploy/python/keypoint_preprocess.py      |  18 +++
 deploy/python/topdown_unite_utils.py      |   7 +
 deploy/python/utils.py                    |   2 +-
 ppdet/modeling/post_process.py            |   3 +-
 7 files changed, 189 insertions(+), 80 deletions(-)

diff --git a/deploy/python/infer.py b/deploy/python/infer.py
index 2ea06c9b5..d07128d2b 100644
--- a/deploy/python/infer.py
+++ b/deploy/python/infer.py
@@ -171,6 +171,9 @@ class Detector(object):
         self.det_times.img_num += len(image_list)
         return results
 
+    def get_timer(self):
+        return self.det_times
+
 
 class DetectorSOLOv2(Detector):
     """
@@ -269,8 +272,8 @@ class DetectorSOLOv2(Detector):
 def create_inputs(imgs, im_info):
     """generate input for different model type
     Args:
-        im (np.ndarray): image (np.ndarray)
-        im_info (dict): info of image
+        imgs (list(numpy)): list of images (np.ndarray)
+        im_info (list(dict)): list of image info
     Returns:
         inputs (dict): input of model
     """
diff --git a/deploy/python/keypoint_det_unite_infer.py b/deploy/python/keypoint_det_unite_infer.py
index a9b0ea69b..056821e5c 100644
--- a/deploy/python/keypoint_det_unite_infer.py
+++ b/deploy/python/keypoint_det_unite_infer.py
@@ -15,6 +15,7 @@
 import os
 from PIL import Image
 import cv2
+import math
 import numpy as np
 import paddle
 
@@ -23,80 +24,107 @@ from preprocess import decode_image
 from infer import Detector, PredictConfig, print_arguments, get_test_images
 from keypoint_infer import KeyPoint_Detector, PredictConfig_KeyPoint
 from keypoint_visualize import draw_pose
+from benchmark_utils import PaddleInferBenchmark
+from utils import get_current_memory_mb
 
 
-def expand_crop(images, rect, expand_ratio=0.3):
-    imgh, imgw, c = images.shape
-    label, conf, xmin, ymin, xmax, ymax = [int(x) for x in rect.tolist()]
-    if label != 0:
-        return None, None, None
-    org_rect = [xmin, ymin, xmax, ymax]
-    h_half = (ymax - ymin) * (1 + expand_ratio) / 2.
-    w_half = (xmax - xmin) * (1 + expand_ratio) / 2.
-    if h_half > w_half * 4 / 3:
-        w_half = h_half * 0.75
-    center = [(ymin + ymax) / 2., (xmin + xmax) / 2.]
-    ymin = max(0, int(center[0] - h_half))
-    ymax = min(imgh - 1, int(center[0] + h_half))
-    xmin = max(0, int(center[1] - w_half))
-    xmax = min(imgw - 1, int(center[1] + w_half))
-    return images[ymin:ymax, xmin:xmax, :], [xmin, ymin, xmax, ymax], org_rect
-
-
-def get_person_from_rect(images, results):
-    det_results = results['boxes']
-    mask = det_results[:, 1] > FLAGS.det_threshold
-    valid_rects = det_results[mask]
-    image_buff = []
-    org_rects = []
-    for rect in valid_rects:
-        rect_image, new_rect, org_rect = expand_crop(images, rect)
-        if rect_image is None or rect_image.size == 0:
-            continue
-        image_buff.append([rect_image, new_rect])
-        org_rects.append(org_rect)
-    return image_buff, org_rects
+def bench_log(detector, img_list, model_info, batch_size=1, name=None):
+    mems = {
+        'cpu_rss_mb': detector.cpu_mem / len(img_list),
+        'gpu_rss_mb': detector.gpu_mem / len(img_list),
+        'gpu_util': detector.gpu_util * 100 / len(img_list)
+    }
+    perf_info = detector.det_times.report(average=True)
+    data_info = {
+        'batch_size': batch_size,
+        'shape': "dynamic_shape",
+        'data_num': perf_info['img_num']
+    }
+
+    log = PaddleInferBenchmark(detector.config, model_info, data_info,
+                               perf_info, mems)
+    log(name)
 
 
 def affine_backto_orgimages(keypoint_result, batch_records):
     kpts, scores = keypoint_result['keypoint']
-    kpts[..., 0] += batch_records[0]
-    kpts[..., 1] += batch_records[1]
+    kpts[..., 0] += batch_records[:, 0:1]
+    kpts[..., 1] += batch_records[:, 1:2]
     return kpts, scores
 
 
-def topdown_unite_predict(detector, topdown_keypoint_detector, image_list):
+def topdown_unite_predict(detector,
+                          topdown_keypoint_detector,
+                          image_list,
+                          keypoint_batch_size=1):
+    det_timer = detector.get_timer()
     for i, img_file in enumerate(image_list):
+        # Decode the image once; both det and keypoint prediction reuse it
+        det_timer.preprocess_time_s.start()
         image, _ = decode_image(img_file, {})
-        results = detector.predict([image], FLAGS.det_threshold)
+        det_timer.preprocess_time_s.end()
+
+        if FLAGS.run_benchmark:
+            results = detector.predict(
+                [image], FLAGS.det_threshold, warmup=10, repeats=10)
+            cm, gm, gu = get_current_memory_mb()
+            detector.cpu_mem += cm
+            detector.gpu_mem += gm
+            detector.gpu_util += gu
+        else:
+            results = detector.predict([image], FLAGS.det_threshold)
+
         if results['boxes_num'] == 0:
             continue
-        batchs_images, det_rects = get_person_from_rect(image, results)
+        rec_images, records, det_rects = topdown_keypoint_detector.get_person_from_rect(
+            image, results, FLAGS.det_threshold)
         keypoint_vector = []
         score_vector = []
-        rect_vecotr = det_rects
-        for batch_images, batch_records in batchs_images:
-            keypoint_result = topdown_keypoint_detector.predict(
-                batch_images, FLAGS.keypoint_threshold)
+        rect_vector = det_rects
+        batch_loop_cnt = math.ceil(float(len(rec_images)) / keypoint_batch_size)
+
+        for i in range(batch_loop_cnt):
+            start_index = i * keypoint_batch_size
+            end_index = min((i + 1) * keypoint_batch_size, len(rec_images))
+            batch_images = rec_images[start_index:end_index]
+            batch_records = np.array(records[start_index:end_index])
+            if FLAGS.run_benchmark:
+                keypoint_result = topdown_keypoint_detector.predict(
+                    batch_images,
+                    FLAGS.keypoint_threshold,
+                    warmup=10,
+                    repeats=10)
+            else:
+                keypoint_result = topdown_keypoint_detector.predict(
+                    batch_images, FLAGS.keypoint_threshold)
             orgkeypoints, scores = affine_backto_orgimages(keypoint_result,
                                                            batch_records)
             keypoint_vector.append(orgkeypoints)
             score_vector.append(scores)
-        keypoint_res = {}
-        keypoint_res['keypoint'] = [
-            np.vstack(keypoint_vector), np.vstack(score_vector)
-        ]
-        keypoint_res['bbox'] = rect_vecotr
-        if not os.path.exists(FLAGS.output_dir):
-            os.makedirs(FLAGS.output_dir)
-        draw_pose(
-            img_file,
-            keypoint_res,
-            visual_thread=FLAGS.keypoint_threshold,
-            save_dir=FLAGS.output_dir)
+        if FLAGS.run_benchmark:
+            cm, gm, gu = get_current_memory_mb()
+            topdown_keypoint_detector.cpu_mem += cm
+            topdown_keypoint_detector.gpu_mem += gm
+            topdown_keypoint_detector.gpu_util += gu
+        else:
+            keypoint_res = {}
+            keypoint_res['keypoint'] = [
+                np.vstack(keypoint_vector), np.vstack(score_vector)
+            ]
+            keypoint_res['bbox'] = rect_vector
+            if not os.path.exists(FLAGS.output_dir):
+                os.makedirs(FLAGS.output_dir)
+            draw_pose(
+                img_file,
+                keypoint_res,
+                visual_thread=FLAGS.keypoint_threshold,
+                save_dir=FLAGS.output_dir)
 
 
-def topdown_unite_predict_video(detector, topdown_keypoint_detector, camera_id):
+def topdown_unite_predict_video(detector,
+                                topdown_keypoint_detector,
+                                camera_id,
+                                keypoint_batch_size=1):
     if camera_id != -1:
         capture = cv2.VideoCapture(camera_id)
         video_name = 'output.mp4'
@@ -124,10 +152,16 @@ def topdown_unite_predict_video(detector, topdown_keypoint_detector, camera_id):
         frame2 = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
 
         results = detector.predict([frame2], FLAGS.det_threshold)
-        batchs_images, rect_vecotr = get_person_from_rect(frame2, results)
+        rec_images, records, rect_vector = topdown_keypoint_detector.get_person_from_rect(
+            frame2, results)
         keypoint_vector = []
         score_vector = []
-        for batch_images, batch_records in batchs_images:
+        batch_loop_cnt = math.ceil(float(len(rec_images)) / keypoint_batch_size)
+        for i in range(batch_loop_cnt):
+            start_index = i * keypoint_batch_size
+            end_index = min((i + 1) * keypoint_batch_size, len(rec_images))
+            batch_images = rec_images[start_index:end_index]
+            batch_records = np.array(records[start_index:end_index])
             keypoint_result = topdown_keypoint_detector.predict(
                 batch_images, FLAGS.keypoint_threshold)
             orgkeypoints, scores = affine_backto_orgimages(keypoint_result,
@@ -138,7 +172,7 @@ def topdown_unite_predict_video(detector, topdown_keypoint_detector, camera_id):
         keypoint_res['keypoint'] = [
             np.vstack(keypoint_vector), np.vstack(score_vector)
         ] if len(keypoint_vector) > 0 else [[], []]
-        keypoint_res['bbox'] = rect_vecotr
+        keypoint_res['bbox'] = rect_vector
         im = draw_pose(
             frame,
             keypoint_res,
@@ -184,11 +218,30 @@ def main():
     # predict from video file or camera video stream
     if FLAGS.video_file is not None or FLAGS.camera_id != -1:
         topdown_unite_predict_video(detector, topdown_keypoint_detector,
-                                    FLAGS.camera_id)
+                                    FLAGS.camera_id, FLAGS.keypoint_batch_size)
     else:
         # predict from image
         img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file)
-        topdown_unite_predict(detector, topdown_keypoint_detector, img_list)
+        topdown_unite_predict(detector, topdown_keypoint_detector, img_list,
+                              FLAGS.keypoint_batch_size)
+        if not FLAGS.run_benchmark:
+            detector.det_times.info(average=True)
+            topdown_keypoint_detector.det_times.info(average=True)
+        else:
+            mode = FLAGS.run_mode
+            det_model_dir = FLAGS.det_model_dir
+            det_model_info = {
+                'model_name': det_model_dir.strip('/').split('/')[-1],
+                'precision': mode.split('_')[-1]
+            }
+            bench_log(detector, img_list, det_model_info, name='Det')
+            keypoint_model_dir = FLAGS.keypoint_model_dir
+            keypoint_model_info = {
+                'model_name': keypoint_model_dir.strip('/').split('/')[-1],
+                'precision': mode.split('_')[-1]
+            }
+            bench_log(topdown_keypoint_detector, img_list, keypoint_model_info,
+                      FLAGS.keypoint_batch_size, 'KeyPoint')
 
 
 if __name__ == '__main__':
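
The batching loop added above is the core of this change: person crops are run
through the keypoint model in fixed-size batches, and each crop's keypoints are
then shifted back into full-image coordinates. A minimal standalone sketch of
that pattern (NumPy only; run_pose_model is a hypothetical stand-in for
KeyPoint_Detector.predict, not part of this patch):

import math
import numpy as np

def run_pose_model(crops):
    # Hypothetical stand-in for KeyPoint_Detector.predict: returns
    # (num_crops, num_joints, 2) keypoints in crop-local coordinates.
    return np.zeros((len(crops), 17, 2), dtype=np.float32)

def batched_keypoints(crops, rects, keypoint_batch_size=1):
    # Mirrors topdown_unite_predict: run fixed-size batches of person
    # crops, then shift the results back to full-image coordinates.
    keypoint_vector = []
    batch_loop_cnt = math.ceil(float(len(crops)) / keypoint_batch_size)
    for i in range(batch_loop_cnt):
        start = i * keypoint_batch_size
        end = min((i + 1) * keypoint_batch_size, len(crops))
        # (B, 4) rows of [xmin, ymin, xmax, ymax] for each expanded crop.
        batch_records = np.array(rects[start:end])
        kpts = run_pose_model(crops[start:end])
        # batch_records[:, 0:1] has shape (B, 1), so each crop's origin
        # broadcasts over all of its joints, as in affine_backto_orgimages.
        kpts[..., 0] += batch_records[:, 0:1]
        kpts[..., 1] += batch_records[:, 1:2]
        keypoint_vector.append(kpts)
    return np.vstack(keypoint_vector)

crops = [np.zeros((192, 144, 3), dtype=np.float32) for _ in range(5)]
rects = [[100, 40, 244, 232]] * 5
print(batched_keypoints(crops, rects, keypoint_batch_size=2).shape)  # (5, 17, 2)

Note the design change this reflects: the old affine_backto_orgimages added
scalar offsets (batch_records[0], batch_records[1]) and so was only correct for
one crop at a time; the (B, 1) column slices make the same two lines work for
any batch size.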
diff --git a/deploy/python/keypoint_infer.py b/deploy/python/keypoint_infer.py
index 981ccd080..b3a9c9a37 100644
--- a/deploy/python/keypoint_infer.py
+++ b/deploy/python/keypoint_infer.py
@@ -20,10 +20,11 @@ from functools import reduce
 
 from PIL import Image
 import cv2
+import math
 import numpy as np
 import paddle
 from preprocess import preprocess, NormalizeImage, Permute
-from keypoint_preprocess import EvalAffine, TopDownEvalAffine
+from keypoint_preprocess import EvalAffine, TopDownEvalAffine, expand_crop
 from keypoint_postprocess import HrHRNetPostProcess, HRNetPostProcess
 from keypoint_visualize import draw_pose
 from paddle.inference import Config
@@ -82,14 +83,39 @@ class KeyPoint_Detector(object):
         self.cpu_mem, self.gpu_mem, self.gpu_util = 0, 0, 0
         self.use_dark = use_dark
 
-    def preprocess(self, im):
+    def get_person_from_rect(self, image, results, det_threshold=0.5):
+        # crop the detected person regions from the image
+        self.det_times.preprocess_time_s.start()
+        det_results = results['boxes']
+        mask = det_results[:, 1] > det_threshold
+        valid_rects = det_results[mask]
+        rect_images = []
+        new_rects = []
+        org_rects = []
+        for rect in valid_rects:
+            rect_image, new_rect, org_rect = expand_crop(image, rect)
+            if rect_image is None or rect_image.size == 0:
+                continue
+            rect_images.append(rect_image)
+            new_rects.append(new_rect)
+            org_rects.append(org_rect)
+        self.det_times.preprocess_time_s.end()
+        return rect_images, new_rects, org_rects
+
+    def preprocess(self, image_list):
         preprocess_ops = []
         for op_info in self.pred_config.preprocess_infos:
             new_op_info = op_info.copy()
             op_type = new_op_info.pop('type')
             preprocess_ops.append(eval(op_type)(**new_op_info))
-        im, im_info = preprocess(im, preprocess_ops)
-        inputs = create_inputs(im, im_info)
+
+        input_im_lst = []
+        input_im_info_lst = []
+        for im in image_list:
+            im, im_info = preprocess(im, preprocess_ops)
+            input_im_lst.append(im)
+            input_im_info_lst.append(im_info)
+        inputs = create_inputs(input_im_lst, input_im_info_lst)
         return inputs
 
     def postprocess(self, np_boxes, np_masks, inputs, threshold=0.5):
@@ -118,10 +146,10 @@ class KeyPoint_Detector(object):
             raise ValueError("Unsupported arch: {}, expect {}".format(
                 self.pred_config.arch, KEYPOINT_SUPPORT_MODELS))
 
-    def predict(self, image, threshold=0.5, warmup=0, repeats=1):
+    def predict(self, image_list, threshold=0.5, warmup=0, repeats=1):
         '''
         Args:
-            image (str/np.ndarray): path of image/ np.ndarray read by cv2
+            image_list (list): list of images
             threshold (float): threshold of predicted box' score
         Returns:
             results (dict): include 'boxes': np.ndarray: shape:[N,6], N: number of box,
                             shape: [N, im_h, im_w]
         '''
         self.det_times.preprocess_time_s.start()
-        inputs = self.preprocess(image)
+        inputs = self.preprocess(image_list)
         np_boxes, np_masks = None, None
         input_names = self.predictor.get_input_names()
@@ -172,23 +200,24 @@ class KeyPoint_Detector(object):
             results = self.postprocess(
                 np_boxes, np_masks, inputs, threshold=threshold)
             self.det_times.postprocess_time_s.end()
-        self.det_times.img_num += 1
+        self.det_times.img_num += len(image_list)
         return results
 
 
-def create_inputs(im, im_info):
+def create_inputs(imgs, im_info):
     """generate input for different model type
     Args:
-        im (np.ndarray): image (np.ndarray)
-        im_info (dict): info of image
-        model_arch (str): model type
+        imgs (list(numpy)): list of images (np.ndarray)
+        im_info (list(dict)): list of image info
     Returns:
         inputs (dict): input of model
     """
     inputs = {}
-    inputs['image'] = np.array((im, )).astype('float32')
-    inputs['im_shape'] = np.array((im_info['im_shape'], )).astype('float32')
-
+    inputs['image'] = np.stack(imgs, axis=0)
+    im_shape = []
+    for e in im_info:
+        im_shape.append(np.array((e['im_shape'])).astype('float32'))
+    inputs['im_shape'] = np.stack(im_shape, axis=0)
     return inputs
 
 
@@ -326,14 +355,14 @@ def load_predictor(model_dir,
 def predict_image(detector, image_list):
     for i, img_file in enumerate(image_list):
         if FLAGS.run_benchmark:
-            detector.predict(img_file, FLAGS.threshold, warmup=10, repeats=10)
+            detector.predict([img_file], FLAGS.threshold, warmup=10, repeats=10)
             cm, gm, gu = get_current_memory_mb()
             detector.cpu_mem += cm
             detector.gpu_mem += gm
             detector.gpu_util += gu
             print('Test iter {}, file name:{}'.format(i, img_file))
         else:
-            results = detector.predict(img_file, FLAGS.threshold)
+            results = detector.predict([img_file], FLAGS.threshold)
             if not os.path.exists(FLAGS.output_dir):
                 os.makedirs(FLAGS.output_dir)
             draw_pose(
diff --git a/deploy/python/keypoint_preprocess.py b/deploy/python/keypoint_preprocess.py
index 345f2d7c2..6619c7db5 100644
--- a/deploy/python/keypoint_preprocess.py
+++ b/deploy/python/keypoint_preprocess.py
@@ -176,3 +176,21 @@ class TopDownEvalAffine(object):
             flags=cv2.INTER_LINEAR)
 
         return image, im_info
+
+
+def expand_crop(images, rect, expand_ratio=0.3):
+    imgh, imgw, c = images.shape
+    label, conf, xmin, ymin, xmax, ymax = [int(x) for x in rect.tolist()]
+    if label != 0:
+        return None, None, None
+    org_rect = [xmin, ymin, xmax, ymax]
+    h_half = (ymax - ymin) * (1 + expand_ratio) / 2.
+    w_half = (xmax - xmin) * (1 + expand_ratio) / 2.
+    if h_half > w_half * 4 / 3:
+        w_half = h_half * 0.75
+    center = [(ymin + ymax) / 2., (xmin + xmax) / 2.]
+    ymin = max(0, int(center[0] - h_half))
+    ymax = min(imgh - 1, int(center[0] + h_half))
+    xmin = max(0, int(center[1] - w_half))
+    xmax = min(imgw - 1, int(center[1] + w_half))
+    return images[ymin:ymax, xmin:xmax, :], [xmin, ymin, xmax, ymax], org_rect
diff --git a/deploy/python/topdown_unite_utils.py b/deploy/python/topdown_unite_utils.py
index 6f7b63df6..1af963c06 100644
--- a/deploy/python/topdown_unite_utils.py
+++ b/deploy/python/topdown_unite_utils.py
@@ -39,6 +39,13 @@ def argsparser():
         type=str,
         default=None,
         help="Dir of image file, `image_file` has a higher priority.")
+    parser.add_argument(
+        "--keypoint_batch_size",
+        type=int,
+        default=1,
+        help=("batch_size for keypoint inference. In detection-keypoint "
+              "united inference, the batch size in detection is 1. The detection "
+              "results are then collated into batches for keypoint inference."))
     parser.add_argument(
         "--video_file",
         type=str,
diff --git a/deploy/python/utils.py b/deploy/python/utils.py
index 35ad43714..a9f48cf28 100644
--- a/deploy/python/utils.py
+++ b/deploy/python/utils.py
@@ -35,7 +35,7 @@ def argsparser():
         default=None,
         help="Dir of image file, `image_file` has a higher priority.")
     parser.add_argument(
-        "--batch_size", type=int, default=1, help="batch_size for infer.")
+        "--batch_size", type=int, default=1, help="batch_size for inference.")
     parser.add_argument(
         "--video_file",
         type=str,
diff --git a/ppdet/modeling/post_process.py b/ppdet/modeling/post_process.py
index df0b467de..80eef2547 100644
--- a/ppdet/modeling/post_process.py
+++ b/ppdet/modeling/post_process.py
@@ -46,8 +46,7 @@ class BBoxPostProcess(nn.Layer):
         self.nms = nms
         self.fake_bboxes = paddle.to_tensor(
             np.array(
-                [[-1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]],
-                dtype='float32'))
+                [[-1, 0.0, 0.0, 0.0, 0.0, 0.0]], dtype='float32'))
         self.fake_bbox_num = paddle.to_tensor(np.array([1], dtype='int32'))
 
     def forward(self, head_out, rois, im_shape, scale_factor):
-- 
GitLab
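
As a sanity check of the expand_crop helper moved into keypoint_preprocess.py
above: it widens the detected box by expand_ratio, raises the width to at least
3/4 of the height (which lines up with common top-down pose input sizes such as
192x256), and clips the result to the image bounds. A standalone example,
assuming PaddleDetection's deploy/python directory is on sys.path; the expected
values in the comments are worked out by hand from the formula above:

import numpy as np
from keypoint_preprocess import expand_crop

img = np.zeros((480, 640, 3), dtype=np.uint8)
# Detector output layout: [class_id, score, xmin, ymin, xmax, ymax].
rect = np.array([0, 0.9, 300, 100, 340, 300], dtype=np.float32)

crop, new_rect, org_rect = expand_crop(img, rect, expand_ratio=0.3)
print(org_rect)    # [300, 100, 340, 300] -- the original box
print(new_rect)    # [222, 70, 417, 330]  -- widened and clipped
print(crop.shape)  # (260, 195, 3)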