From 598035f9859240b468a2296e8f155d605164dcaf Mon Sep 17 00:00:00 2001 From: Yiqun Liu Date: Thu, 19 Apr 2018 01:57:25 +0800 Subject: [PATCH] Fix a bug in save_inference_model and prune when the program is initailized by load_inference_model (#10011) * Fix bug in save_inference_model and prune when the program is initialized by load_inference_program. * Save the transpiled program instead. --- paddle/fluid/framework/op_desc.h | 2 +- paddle/fluid/pybind/protobuf.cc | 3 ++ paddle/fluid/pybind/pybind.cc | 2 +- python/paddle/fluid/framework.py | 6 ++++ python/paddle/fluid/io.py | 29 ++++++------------- .../tests/book/test_image_classification.py | 4 +++ 6 files changed, 24 insertions(+), 22 deletions(-) diff --git a/paddle/fluid/framework/op_desc.h b/paddle/fluid/framework/op_desc.h index 614dd8cd00..cd6777e60a 100644 --- a/paddle/fluid/framework/op_desc.h +++ b/paddle/fluid/framework/op_desc.h @@ -119,7 +119,7 @@ class OpDesc { void InferVarType(BlockDesc *block) const; - void MarkAsTarget() { desc_.set_is_target(true); } + void SetIsTarget(bool is_target) { desc_.set_is_target(is_target); } void Flush(); diff --git a/paddle/fluid/pybind/protobuf.cc b/paddle/fluid/pybind/protobuf.cc index 93533e5c9d..7de7f84a3d 100644 --- a/paddle/fluid/pybind/protobuf.cc +++ b/paddle/fluid/pybind/protobuf.cc @@ -127,6 +127,8 @@ void BindProgramDesc(pybind11::module *m) { .def("block", &pd::ProgramDesc::MutableBlock, pybind11::return_value_policy::reference) .def("num_blocks", &pd::ProgramDesc::Size) + .def("get_feed_target_names", &pd::ProgramDesc::GetFeedTargetNames) + .def("get_fetch_target_names", &pd::ProgramDesc::GetFetchTargetNames) .def("serialize_to_string", SerializeMessage) .def("parse_from_string", [](pd::ProgramDesc &program_desc, const std::string &data) { @@ -299,6 +301,7 @@ void BindOpDesc(pybind11::module *m) { .def("check_attrs", &pd::OpDesc::CheckAttrs) .def("infer_shape", &pd::OpDesc::InferShape) .def("infer_var_type", &pd::OpDesc::InferVarType) + .def("set_is_target", &pd::OpDesc::SetIsTarget) .def("serialize_to_string", SerializeMessage) .def("block", &pd::OpDesc::Block, pybind11::return_value_policy::reference); diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 19bd30d966..64d92cac7e 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -294,7 +294,7 @@ All parameter, weight, gradient are variables in Paddle. const std::vector> &targets) { ProgramDesc prog_with_targets(origin); for (const auto &t : targets) { - prog_with_targets.MutableBlock(t[0])->Op(t[1])->MarkAsTarget(); + prog_with_targets.MutableBlock(t[0])->Op(t[1])->SetIsTarget(true); } proto::ProgramDesc pruned_desc; Prune(*prog_with_targets.Proto(), &pruned_desc); diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 4b841ef31d..5e6c6204c5 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -1070,6 +1070,12 @@ class Program(object): for t in targets: if not isinstance(t, Operator): if isinstance(t, Variable): + if t.op is None: + global_block = self.global_block() + for op in global_block.ops: + if t.name in op.output_arg_names: + t.op = op + break t = t.op else: raise ValueError(("All targets of prune() can only be " diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py index 1c0f1f6eb4..bf4d81233d 100644 --- a/python/paddle/fluid/io.py +++ b/python/paddle/fluid/io.py @@ -340,6 +340,13 @@ def save_inference_model(dirname, if not os.path.isdir(dirname): os.makedirs(dirname) + # Clear the is_target information and remove the existed feed and fetch op + global_block = main_program.global_block() + for i, op in enumerate(global_block.ops): + op.desc.set_is_target(False) + if op.type == "feed" or op.type == "fetch": + global_block.remove_op(i) + pruned_program = main_program.prune(targets=target_vars) inference_program = pruned_program.inference_optimize() fetch_var_names = [v.name for v in target_vars] @@ -362,24 +369,6 @@ def save_inference_model(dirname, save_persistables(executor, dirname, inference_program, params_filename) -def get_feed_targets_names(program): - feed_targets_names = [] - global_block = program.global_block() - for op in global_block.ops: - if op.desc.type() == 'feed': - feed_targets_names.insert(0, op.desc.output('Out')[0]) - return feed_targets_names - - -def get_fetch_targets_names(program): - fetch_targets_names = [] - global_block = program.global_block() - for op in global_block.ops: - if op.desc.type() == 'fetch': - fetch_targets_names.append(op.desc.input('X')[0]) - return fetch_targets_names - - def load_inference_model(dirname, executor, model_filename=None, @@ -418,8 +407,8 @@ def load_inference_model(dirname, program = Program.parse_from_string(program_desc_str) load_persistables(executor, dirname, program, params_filename) - feed_target_names = get_feed_targets_names(program) - fetch_target_names = get_fetch_targets_names(program) + feed_target_names = program.desc.get_feed_target_names() + fetch_target_names = program.desc.get_fetch_target_names() fetch_targets = [ program.global_block().var(name) for name in fetch_target_names ] diff --git a/python/paddle/fluid/tests/book/test_image_classification.py b/python/paddle/fluid/tests/book/test_image_classification.py index 0027b651e8..d3c14b83fa 100644 --- a/python/paddle/fluid/tests/book/test_image_classification.py +++ b/python/paddle/fluid/tests/book/test_image_classification.py @@ -248,6 +248,10 @@ def infer(use_cuda, save_dirname=None): print("infer results: ", results[0]) + fluid.io.save_inference_model(save_dirname, feed_target_names, + fetch_targets, exe, + inference_transpiler_program) + def main(net_type, use_cuda, is_local=True): if use_cuda and not fluid.core.is_compiled_with_cuda(): -- GitLab