未验证 提交 ef8cb8f6 编写于 作者: W whs 提交者: GitHub

Merge pull request #7816 from wanghaoshuang/infer_prog

 Make get_inference_program support for Evaluator.
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
import os import os
import cPickle as pickle import cPickle as pickle
from paddle.v2.fluid.evaluator import Evaluator
from paddle.v2.fluid.framework import Program, Parameter, default_main_program, Variable from paddle.v2.fluid.framework import Program, Parameter, default_main_program, Variable
from . import core from . import core
...@@ -187,8 +188,14 @@ def get_inference_program(target_vars, main_program=None): ...@@ -187,8 +188,14 @@ def get_inference_program(target_vars, main_program=None):
main_program = default_main_program() main_program = default_main_program()
if not isinstance(target_vars, list): if not isinstance(target_vars, list):
target_vars = [target_vars] target_vars = [target_vars]
vars = []
pruned_program = main_program.prune(targets=target_vars) for var in target_vars:
if isinstance(var, Evaluator):
vars.append(var.states)
vars.append(var.metrics)
else:
vars.append(var)
pruned_program = main_program.prune(targets=vars)
inference_program = pruned_program.inference_optimize() inference_program = pruned_program.inference_optimize()
return inference_program return inference_program
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册