From 3ba7e74db2fa80d6130c1dec5a5f64dec9d5a784 Mon Sep 17 00:00:00 2001 From: luotao1 Date: Thu, 20 Sep 2018 10:32:25 +0800 Subject: [PATCH] use clone(for_test=True) replace get_inference_program --- paddle/fluid/API.spec | 1 - python/paddle/fluid/io.py | 20 +------------------ .../fluid/tests/unittests/dist_transformer.py | 4 +--- 3 files changed, 2 insertions(+), 23 deletions(-) diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index f61d1254f..61a99885c 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -73,7 +73,6 @@ paddle.fluid.io.load_params ArgSpec(args=['executor', 'dirname', 'main_program', paddle.fluid.io.load_persistables ArgSpec(args=['executor', 'dirname', 'main_program', 'filename'], varargs=None, keywords=None, defaults=(None, None)) paddle.fluid.io.save_inference_model ArgSpec(args=['dirname', 'feeded_var_names', 'target_vars', 'executor', 'main_program', 'model_filename', 'params_filename', 'export_for_deployment'], varargs=None, keywords=None, defaults=(None, None, None, True)) paddle.fluid.io.load_inference_model ArgSpec(args=['dirname', 'executor', 'model_filename', 'params_filename', 'pserver_endpoints'], varargs=None, keywords=None, defaults=(None, None, None)) -paddle.fluid.io.get_inference_program ArgSpec(args=['target_vars', 'main_program'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.initializer.ConstantInitializer.__init__ ArgSpec(args=['self', 'value', 'force_cpu'], varargs=None, keywords=None, defaults=(0.0, False)) paddle.fluid.initializer.UniformInitializer.__init__ ArgSpec(args=['self', 'low', 'high', 'seed'], varargs=None, keywords=None, defaults=(-1.0, 1.0, 0)) paddle.fluid.initializer.NormalInitializer.__init__ ArgSpec(args=['self', 'loc', 'scale', 'seed'], varargs=None, keywords=None, defaults=(0.0, 1.0, 0)) diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py index 78bb8a1a0..e703e5ac7 100644 --- a/python/paddle/fluid/io.py +++ b/python/paddle/fluid/io.py @@ -27,8 +27,7 @@ from . import core __all__ = [ 'save_vars', 'save_params', 'save_persistables', 'load_vars', 'load_params', - 'load_persistables', 'save_inference_model', 'load_inference_model', - 'get_inference_program' + 'load_persistables', 'save_inference_model', 'load_inference_model' ] @@ -504,23 +503,6 @@ def load_persistables(executor, dirname, main_program=None, filename=None): filename=filename) -def get_inference_program(target_vars, main_program=None): - if main_program is None: - main_program = default_main_program() - if not isinstance(target_vars, list): - target_vars = [target_vars] - vars = [] - for var in target_vars: - if isinstance(var, Evaluator): - vars.extend(var.states) - vars.extend(var.metrics) - else: - vars.append(var) - pruned_program = main_program._prune(targets=vars) - inference_program = pruned_program._inference_optimize() - return inference_program - - def prepend_feed_ops(inference_program, feed_target_names, feed_holder_name='feed'): diff --git a/python/paddle/fluid/tests/unittests/dist_transformer.py b/python/paddle/fluid/tests/unittests/dist_transformer.py index 3ec79f8ef..c652d6076 100644 --- a/python/paddle/fluid/tests/unittests/dist_transformer.py +++ b/python/paddle/fluid/tests/unittests/dist_transformer.py @@ -440,9 +440,7 @@ def split_data(data, num_part): def test_context(train_progm, avg_cost, train_exe, dev_count, data_input_names, sum_cost, token_num): # Context to do validation. - test_program = train_progm.clone() - with fluid.program_guard(test_program): - test_program = fluid.io.get_inference_program([avg_cost]) + test_program = train_progm.clone(for_test=True) val_data = DataReader( src_vocab_fpath=TrainTaskConfig.src_vocab_fpath, -- GitLab