Commit 283b2581 authored by wangguanzhong, committed by GitHub

fix bug in py3 infer (#2913)

Parent a48e7e91
...
@@ -109,17 +109,19 @@ def save_infer_model(FLAGS, exe, feed_vars, test_fetches, infer_prog):
     cfg_name = os.path.basename(FLAGS.config).split('.')[0]
     save_dir = os.path.join(FLAGS.output_dir, cfg_name)
     feeded_var_names = [var.name for var in feed_vars.values()]
-    target_vars = test_fetches.values()
-    feeded_var_names = prune_feed_vars(feeded_var_names, target_vars, infer_prog)
+    target_vars = list(test_fetches.values())
+    feeded_var_names = prune_feed_vars(feeded_var_names, target_vars,
+                                       infer_prog)
     logger.info("Save inference model to {}, input: {}, output: "
                 "{}...".format(save_dir, feeded_var_names,
                                [var.name for var in target_vars]))
-    fluid.io.save_inference_model(save_dir,
-                                  feeded_var_names=feeded_var_names,
-                                  target_vars=target_vars,
-                                  executor=exe,
-                                  main_program=infer_prog,
-                                  params_filename="__params__")
+    fluid.io.save_inference_model(
+        save_dir,
+        feeded_var_names=feeded_var_names,
+        target_vars=target_vars,
+        executor=exe,
+        main_program=infer_prog,
+        params_filename="__params__")
 
 
 def main():
...
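For context, the bug being fixed is a Python 2/3 difference: in Python 3, dict.values() returns a dict_values view rather than a list, so downstream code that expects a plain list of variables (for example, code that indexes or type-checks target_vars) can fail. The sketch below is a minimal, standalone illustration of that difference; the fetches dict is a hypothetical stand-in for test_fetches and is not code from this repository. Wrapping the view in list() restores the Python 2 behavior.

    # Hypothetical stand-in for test_fetches: a dict of fetch names to variables.
    fetches = {"bbox": "bbox_var", "mask": "mask_var"}

    values = fetches.values()
    print(type(values))  # Python 2: <type 'list'>; Python 3: <class 'dict_values'>

    # Code that assumes a real list breaks on a dict view in Python 3:
    try:
        values[0]
    except TypeError as e:
        print(e)  # 'dict_values' object is not subscriptable

    # The fix applied in this commit: materialize the view into a list.
    target_vars = list(fetches.values())
    print(target_vars[0])  # indexing works again, matching Python 2 behavior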