Unverified commit 7d5380fd authored by W wangguanzhong, committed by GitHub

update cpp_infer for paddle latest (#424)

* update cpp_infer for paddle latest

* polish code
Parent: ca199f73
@@ -104,14 +104,16 @@ class Resize(object):
                  target_size,
                  max_size=0,
                  interp=cv2.INTER_LINEAR,
-                 use_cv2=True):
+                 use_cv2=True,
+                 image_shape=None):
         super(Resize, self).__init__()
         self.target_size = target_size
         self.max_size = max_size
         self.interp = interp
         self.use_cv2 = use_cv2
+        self.image_shape = image_shape

-    def __call__(self, im, use_trt=False):
+    def __call__(self, im):
         origin_shape = im.shape[:2]
         im_c = im.shape[2]
         if self.max_size != 0:
@@ -147,10 +149,7 @@ class Resize(object):
             im = im.resize((int(resize_w), int(resize_h)), self.interp)
             im = np.array(im)
         # padding im
-        if self.max_size != 0 and use_trt:
-            logger.warning('Due to the limitation of tensorRT, padding the '
-                           'image shape to {} * {}'.format(self.max_size,
-                                                           self.max_size))
+        if self.max_size != 0 and self.image_shape is not None:
             padding_im = np.zeros(
                 (self.max_size, self.max_size, im_c), dtype=np.float32)
             im_h, im_w = im.shape[:2]
@@ -189,10 +188,10 @@ class Permute(object):
     def __call__(self, im):
         if self.channel_first:
-            im = im.transpose((2, 0, 1)).copy()
+            im = im.transpose((2, 0, 1))
         if self.to_bgr:
             im = im[[2, 1, 0], :, :]
-        return im
+        return im.copy()


 class PadStride(object):
@@ -214,7 +213,7 @@ class PadStride(object):
         return padding_im


-def Preprocess(img_path, arch, config, use_trt):
+def Preprocess(img_path, arch, config):
     img = DecodeImage(img_path)
     orig_shape = img.shape
     scale = 1.
@@ -224,7 +223,7 @@ def Preprocess(img_path, arch, config, use_trt):
         obj = data_aug_conf.pop('type')
         preprocess = eval(obj)(**data_aug_conf)
         if obj == 'Resize':
-            img, scale = preprocess(img, use_trt)
+            img, scale = preprocess(img)
         else:
             img = preprocess(img)
@@ -509,8 +508,11 @@ def infer():
         conf = yaml.safe_load(f)

     use_trt = not conf['use_python_inference'] and 'trt' in conf['mode']
-    img_data = Preprocess(FLAGS.infer_img, conf['arch'], conf['Preprocess'],
-                          use_trt)
+    if use_trt:
+        logger.warning(
+            "Due to the limitation of tensorRT, the image shape needs to set in export_model"
+        )
+    img_data = Preprocess(FLAGS.infer_img, conf['arch'], conf['Preprocess'])
     if 'SSD' in conf['arch']:
         img_data, res['im_shape'] = img_data
         img_data = [img_data]
......
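For reference, a minimal sketch of the padding behavior that the new image_shape argument to Resize enables (the helper name below is hypothetical; the padding logic mirrors the hunk above):

# Illustrative sketch, not part of the commit: pad a resized HWC float image
# into a fixed max_size x max_size canvas so the exported model sees a static
# input shape (needed for the TensorRT modes mentioned in the warning above).
import numpy as np

def pad_to_square(im, max_size):
    # Pixels outside the original image extent stay zero.
    im_c = im.shape[2]
    padding_im = np.zeros((max_size, max_size, im_c), dtype=np.float32)
    im_h, im_w = im.shape[:2]
    padding_im[:im_h, :im_w, :] = im
    return padding_im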
@@ -34,7 +34,7 @@ logger = logging.getLogger(__name__)

 def parse_reader(reader_cfg, metric, arch):
     preprocess_list = []

-    image_shape = reader_cfg['inputs_def'].get('image_shape', [None])
+    image_shape = reader_cfg['inputs_def'].get('image_shape', [3, None, None])
     has_shape_def = not None in image_shape

     scale_set = {'RCNN', 'RetinaNet'}
@@ -58,9 +58,11 @@ def parse_reader(reader_cfg, metric, arch):
             params = st.__dict__
             params.pop('_id')
             if p['type'] == 'Resize' and has_shape_def:
-                params['target_size'] = image_shape[1]
-                params['max_size'] = image_shape[2] if arch in scale_set else 0
+                params['target_size'] = min(image_shape[
+                    1:]) if arch in scale_set else image_shape[1]
+                params['max_size'] = max(image_shape[
+                    1:]) if arch in scale_set else 0
+                params['image_shape'] = image_shape[1:]
             p.update(params)
             preprocess_list.append(p)

     batch_transforms = reader_cfg.get('batch_transforms', None)
@@ -102,6 +104,7 @@ def dump_infer_config(config):
     infer_cfg['with_background'], infer_cfg['Preprocess'], infer_cfg[
         'label_list'] = parse_reader(config['TestReader'], config['metric'],
                                      infer_cfg['arch'])
     yaml.dump(infer_cfg, open(os.path.join(save_dir, 'infer_cfg.yml'), 'w'))
     logger.info("Export inference config file to {}".format(
         os.path.join(save_dir, 'infer_cfg.yml')))
......
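As a worked example of the new parse_reader behavior, with assumed input shapes that are not taken from the commit:

# Illustrative only: how target_size / max_size / image_shape are derived.
scale_set = {'RCNN', 'RetinaNet'}

# Multi-scale arch (e.g. an RCNN model), assumed image_shape = [3, 800, 1333]:
#   target_size = min([800, 1333]) = 800
#   max_size    = max([800, 1333]) = 1333
#   image_shape = [800, 1333]

# Fixed-size arch (e.g. a YOLO model), assumed image_shape = [3, 608, 608]:
#   target_size = image_shape[1] = 608
#   max_size    = 0
#   image_shape = [608, 608]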