From 7d5380fdb9d1283ce85f3c8449e1344ef4c18c2e Mon Sep 17 00:00:00 2001 From: wangguanzhong Date: Thu, 2 Apr 2020 21:23:06 +0800 Subject: [PATCH] update cpp_infer for paddle latest (#424) * update cpp_infer for paddle latest * polish code --- tools/cpp_infer.py | 26 ++++++++++++++------------ tools/export_model.py | 11 +++++++---- 2 files changed, 21 insertions(+), 16 deletions(-) diff --git a/tools/cpp_infer.py b/tools/cpp_infer.py index 2dd2cce44..28d0e4790 100644 --- a/tools/cpp_infer.py +++ b/tools/cpp_infer.py @@ -104,14 +104,16 @@ class Resize(object): target_size, max_size=0, interp=cv2.INTER_LINEAR, - use_cv2=True): + use_cv2=True, + image_shape=None): super(Resize, self).__init__() self.target_size = target_size self.max_size = max_size self.interp = interp self.use_cv2 = use_cv2 + self.image_shape = image_shape - def __call__(self, im, use_trt=False): + def __call__(self, im): origin_shape = im.shape[:2] im_c = im.shape[2] if self.max_size != 0: @@ -147,10 +149,7 @@ class Resize(object): im = im.resize((int(resize_w), int(resize_h)), self.interp) im = np.array(im) # padding im - if self.max_size != 0 and use_trt: - logger.warning('Due to the limitation of tensorRT, padding the ' - 'image shape to {} * {}'.format(self.max_size, - self.max_size)) + if self.max_size != 0 and self.image_shape is not None: padding_im = np.zeros( (self.max_size, self.max_size, im_c), dtype=np.float32) im_h, im_w = im.shape[:2] @@ -189,10 +188,10 @@ class Permute(object): def __call__(self, im): if self.channel_first: - im = im.transpose((2, 0, 1)).copy() + im = im.transpose((2, 0, 1)) if self.to_bgr: im = im[[2, 1, 0], :, :] - return im + return im.copy() class PadStride(object): @@ -214,7 +213,7 @@ class PadStride(object): return padding_im -def Preprocess(img_path, arch, config, use_trt): +def Preprocess(img_path, arch, config): img = DecodeImage(img_path) orig_shape = img.shape scale = 1. @@ -224,7 +223,7 @@ def Preprocess(img_path, arch, config, use_trt): obj = data_aug_conf.pop('type') preprocess = eval(obj)(**data_aug_conf) if obj == 'Resize': - img, scale = preprocess(img, use_trt) + img, scale = preprocess(img) else: img = preprocess(img) @@ -509,8 +508,11 @@ def infer(): conf = yaml.safe_load(f) use_trt = not conf['use_python_inference'] and 'trt' in conf['mode'] - img_data = Preprocess(FLAGS.infer_img, conf['arch'], conf['Preprocess'], - use_trt) + if use_trt: + logger.warning( + "Due to the limitation of tensorRT, the image shape needs to set in export_model" + ) + img_data = Preprocess(FLAGS.infer_img, conf['arch'], conf['Preprocess']) if 'SSD' in conf['arch']: img_data, res['im_shape'] = img_data img_data = [img_data] diff --git a/tools/export_model.py b/tools/export_model.py index e2e166f91..c6b632bd5 100644 --- a/tools/export_model.py +++ b/tools/export_model.py @@ -34,7 +34,7 @@ logger = logging.getLogger(__name__) def parse_reader(reader_cfg, metric, arch): preprocess_list = [] - image_shape = reader_cfg['inputs_def'].get('image_shape', [None]) + image_shape = reader_cfg['inputs_def'].get('image_shape', [3, None, None]) has_shape_def = not None in image_shape scale_set = {'RCNN', 'RetinaNet'} @@ -58,9 +58,11 @@ def parse_reader(reader_cfg, metric, arch): params = st.__dict__ params.pop('_id') if p['type'] == 'Resize' and has_shape_def: - params['target_size'] = image_shape[1] - params['max_size'] = image_shape[2] if arch in scale_set else 0 - + params['target_size'] = min(image_shape[ + 1:]) if arch in scale_set else image_shape[1] + params['max_size'] = max(image_shape[ + 1:]) if arch in scale_set else 0 + params['image_shape'] = image_shape[1:] p.update(params) preprocess_list.append(p) batch_transforms = reader_cfg.get('batch_transforms', None) @@ -102,6 +104,7 @@ def dump_infer_config(config): infer_cfg['with_background'], infer_cfg['Preprocess'], infer_cfg[ 'label_list'] = parse_reader(config['TestReader'], config['metric'], infer_cfg['arch']) + yaml.dump(infer_cfg, open(os.path.join(save_dir, 'infer_cfg.yml'), 'w')) logger.info("Export inference config file to {}".format( os.path.join(save_dir, 'infer_cfg.yml'))) -- GitLab