提交 0b3f8b68 编写于 作者: K Kaipeng Deng 提交者: GitHub

fix save_inference_model cannot find feed vars (#2842)

* fix save_inference_model cannot find feed vars

* remove comment

* fix format

* add comment for pruned var

* add log for prune
上级 7dbffcd6
......@@ -82,13 +82,35 @@ def get_test_images(infer_dir, infer_img):
return images
def prune_feed_vars(feeded_var_names, target_vars, prog):
"""
Filter out feed variables which are not in program,
pruned feed variables are only used in post processing
on model output, which are not used in program, such
as im_id to identify image order, im_shape to clip bbox
in image.
"""
exist_var_names = []
prog = prog.clone()
prog = prog._prune(targets=target_vars)
global_block = prog.global_block()
for name in feeded_var_names:
try:
v = global_block.var(name)
exist_var_names.append(v.name)
except Exception:
logger.info('save_inference_model pruned unused feed '
'variables {}'.format(name))
pass
return exist_var_names
def save_infer_model(FLAGS, exe, feed_vars, test_fetches, infer_prog):
cfg_name = os.path.basename(FLAGS.config).split('.')[0]
save_dir = os.path.join(FLAGS.output_dir, cfg_name)
feeded_var_names = [var.name for var in feed_vars.values()]
# im_id is only used for visualize, not used in inference model
feeded_var_names.remove('im_id')
target_vars = test_fetches.values()
feeded_var_names = prune_feed_vars(feeded_var_names, target_vars, infer_prog)
logger.info("Save inference model to {}, input: {}, output: "
"{}...".format(save_dir, feeded_var_names,
[var.name for var in target_vars]))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册