提交 28e19603 编写于 作者: W wangguanzhong 提交者: GitHub

Update prune interface (#3336)

上级 0be30344
...@@ -106,12 +106,12 @@ def prune_feed_vars(feeded_var_names, target_vars, prog): ...@@ -106,12 +106,12 @@ def prune_feed_vars(feeded_var_names, target_vars, prog):
""" """
exist_var_names = [] exist_var_names = []
prog = prog.clone() prog = prog.clone()
prog = prog._prune(targets=target_vars) prog = prog._prune(feeded_var_names, targets=target_vars)
global_block = prog.global_block() global_block = prog.global_block()
for name in feeded_var_names: for name in feeded_var_names:
try: try:
v = global_block.var(name) v = global_block.var(name)
exist_var_names.append(v.name) exist_var_names.append(v.name.encode('utf-8'))
except Exception: except Exception:
logger.info('save_inference_model pruned unused feed ' logger.info('save_inference_model pruned unused feed '
'variables {}'.format(name)) 'variables {}'.format(name))
...@@ -127,8 +127,9 @@ def save_infer_model(FLAGS, exe, feed_vars, test_fetches, infer_prog): ...@@ -127,8 +127,9 @@ def save_infer_model(FLAGS, exe, feed_vars, test_fetches, infer_prog):
feeded_var_names = prune_feed_vars(feeded_var_names, target_vars, feeded_var_names = prune_feed_vars(feeded_var_names, target_vars,
infer_prog) infer_prog)
logger.info("Save inference model to {}, input: {}, output: " logger.info("Save inference model to {}, input: {}, output: "
"{}...".format(save_dir, feeded_var_names, "{}...".format(save_dir, feeded_var_names, [
[var.name for var in target_vars])) var.name.encode('utf-8') for var in target_vars
]))
fluid.io.save_inference_model( fluid.io.save_inference_model(
save_dir, save_dir,
feeded_var_names=feeded_var_names, feeded_var_names=feeded_var_names,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册