未验证 提交 59b70495 编写于 作者: W wangguanzhong 提交者: GitHub

fix cpp_infer in SSD (#231)

上级 ef1ed933
...@@ -78,7 +78,9 @@ def get_extra_info(im, arch, shape, scale): ...@@ -78,7 +78,9 @@ def get_extra_info(im, arch, shape, scale):
logger.info('Extra info: im_size') logger.info('Extra info: im_size')
info.append(im_size) info.append(im_size)
elif 'SSD' in arch: elif 'SSD' in arch:
pass im_shape = np.array([shape[:2]]).astype('int32')
logger.info('Extra info: im_shape')
info.append([im_shape])
elif 'RetinaNet' in arch: elif 'RetinaNet' in arch:
input_shape.extend(im.shape[2:]) input_shape.extend(im.shape[2:])
im_info = np.array([input_shape + [scale]]).astype('float32') im_info = np.array([input_shape + [scale]]).astype('float32')
...@@ -190,6 +192,7 @@ def Preprocess(img_path, arch, config): ...@@ -190,6 +192,7 @@ def Preprocess(img_path, arch, config):
def infer(): def infer():
model_path = FLAGS.model_path model_path = FLAGS.model_path
config_path = FLAGS.config_path config_path = FLAGS.config_path
res = {}
assert model_path is not None, "Model path: {} does not exist!".format( assert model_path is not None, "Model path: {} does not exist!".format(
model_path) model_path)
assert config_path is not None, "Config path: {} does not exist!".format( assert config_path is not None, "Config path: {} does not exist!".format(
...@@ -198,6 +201,9 @@ def infer(): ...@@ -198,6 +201,9 @@ def infer():
conf = yaml.safe_load(f) conf = yaml.safe_load(f)
img_data = Preprocess(FLAGS.infer_img, conf['arch'], conf['Preprocess']) img_data = Preprocess(FLAGS.infer_img, conf['arch'], conf['Preprocess'])
if 'SSD' in conf['arch']:
img_data, res['im_shape'] = img_data
img_data = [img_data]
if conf['use_python_inference']: if conf['use_python_inference']:
place = fluid.CUDAPlace(0) place = fluid.CUDAPlace(0)
...@@ -253,7 +259,6 @@ def infer(): ...@@ -253,7 +259,6 @@ def infer():
is_bbox_normalized = True if 'SSD' in conf['arch'] else False is_bbox_normalized = True if 'SSD' in conf['arch'] else False
out = outs[-1] out = outs[-1]
res = {}
lod = out.lod() if conf['use_python_inference'] else out.lod lod = out.lod() if conf['use_python_inference'] else out.lod
lengths = offset_to_lengths(lod) lengths = offset_to_lengths(lod)
np_data = np.array(out) if conf[ np_data = np.array(out) if conf[
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册