diff --git a/python/paddle/v2/fluid/io.py b/python/paddle/v2/fluid/io.py index 376d6013a38923014fa35e964e58d7f56bf80546..5b02d2495d1ebe9e82e7f847e5bd07548901c7fc 100644 --- a/python/paddle/v2/fluid/io.py +++ b/python/paddle/v2/fluid/io.py @@ -15,6 +15,7 @@ import os import cPickle as pickle +from paddle.v2.fluid.evaluator import Evaluator from paddle.v2.fluid.framework import Program, Parameter, default_main_program, Variable from . import core @@ -187,8 +188,14 @@ def get_inference_program(target_vars, main_program=None): main_program = default_main_program() if not isinstance(target_vars, list): target_vars = [target_vars] - - pruned_program = main_program.prune(targets=target_vars) + vars = [] + for var in target_vars: + if isinstance(var, Evaluator): + vars.append(var.states) + vars.append(var.metrics) + else: + vars.append(var) + pruned_program = main_program.prune(targets=vars) inference_program = pruned_program.inference_optimize() return inference_program