提交 a6a79c35 编写于 作者: W wanghaoshuang

More general implementation.

上级 5ecbba46
...@@ -186,12 +186,16 @@ def load_persistables(executor, dirname, main_program=None): ...@@ -186,12 +186,16 @@ def load_persistables(executor, dirname, main_program=None):
def get_inference_program(target_vars, main_program=None): def get_inference_program(target_vars, main_program=None):
if main_program is None: if main_program is None:
main_program = default_main_program() main_program = default_main_program()
if isinstance(target_vars, Evaluator):
target_vars = target_vars.states + target_vars.metrics
if not isinstance(target_vars, list): if not isinstance(target_vars, list):
target_vars = [target_vars] target_vars = [target_vars]
vars = []
pruned_program = main_program.prune(targets=target_vars) for var in target_vars:
if isinstance(var, Evaluator):
vars.append(var.states)
vars.append(var.metrics)
else:
vars.append(var)
pruned_program = main_program.prune(targets=vars)
inference_program = pruned_program.inference_optimize() inference_program = pruned_program.inference_optimize()
return inference_program return inference_program
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册