diff --git a/paddle/fluid/framework/block_desc.cc b/paddle/fluid/framework/block_desc.cc index b8847e4b909cbab67b2ddb6885b45b73d402de19..9f753478d8ecf12441d4b1745a9f6750a1038e31 100644 --- a/paddle/fluid/framework/block_desc.cc +++ b/paddle/fluid/framework/block_desc.cc @@ -146,6 +146,7 @@ void BlockDesc::RemoveOp(size_t s, size_t e) { if (ops_.begin() + s == ops_.end() || ops_.begin() + e == ops_.end()) { return; } + need_update_ = true; ops_.erase(ops_.begin() + s, ops_.begin() + e); } diff --git a/paddle/fluid/framework/program_desc.cc b/paddle/fluid/framework/program_desc.cc index 77d17fbbccca0292e21acd5e8fa90448527b95c0..16694bcf76486a9603c41dc19a58dd0a7cb2b719 100644 --- a/paddle/fluid/framework/program_desc.cc +++ b/paddle/fluid/framework/program_desc.cc @@ -27,10 +27,14 @@ BlockDesc *ProgramDesc::AppendBlock(const BlockDesc &parent) { return blocks_.back().get(); } -proto::ProgramDesc *ProgramDesc::Proto() { +void ProgramDesc::Flush() { for (auto &block : blocks_) { block->Flush(); } +} + +proto::ProgramDesc *ProgramDesc::Proto() { + Flush(); return &desc_; } diff --git a/paddle/fluid/framework/program_desc.h b/paddle/fluid/framework/program_desc.h index 4288081be72c44c0fc3584b50c41a270eac9e204..65fa0a0cfd5ba6d9b8765cee1309e118cb74348a 100644 --- a/paddle/fluid/framework/program_desc.h +++ b/paddle/fluid/framework/program_desc.h @@ -51,6 +51,8 @@ class ProgramDesc { size_t Size() const { return blocks_.size(); } + void Flush(); + proto::ProgramDesc *Proto(); // The output variable of feed_op is referenced as feed_target. diff --git a/paddle/fluid/pybind/protobuf.cc b/paddle/fluid/pybind/protobuf.cc index 7de7f84a3dc76195d0098d7bb9baf0461aff3575..6471eb3ab7bf05365c0bb2bf68bb74ef9044c527 100644 --- a/paddle/fluid/pybind/protobuf.cc +++ b/paddle/fluid/pybind/protobuf.cc @@ -127,6 +127,7 @@ void BindProgramDesc(pybind11::module *m) { .def("block", &pd::ProgramDesc::MutableBlock, pybind11::return_value_policy::reference) .def("num_blocks", &pd::ProgramDesc::Size) + .def("flush", &pd::ProgramDesc::Flush) .def("get_feed_target_names", &pd::ProgramDesc::GetFeedTargetNames) .def("get_fetch_target_names", &pd::ProgramDesc::GetFetchTargetNames) .def("serialize_to_string", SerializeMessage) diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py index bf4d81233d619f368deeb6a5418bf1293ef35c6e..f7f1ca2598a3e679b24fa8d62c52e4f4de788fe2 100644 --- a/python/paddle/fluid/io.py +++ b/python/paddle/fluid/io.py @@ -336,18 +336,20 @@ def save_inference_model(dirname, if main_program is None: main_program = default_main_program() + copy_program = main_program if not os.path.isdir(dirname): os.makedirs(dirname) # Clear the is_target information and remove the existed feed and fetch op - global_block = main_program.global_block() + global_block = copy_program.global_block() for i, op in enumerate(global_block.ops): op.desc.set_is_target(False) if op.type == "feed" or op.type == "fetch": global_block.remove_op(i) + copy_program.desc.flush() - pruned_program = main_program.prune(targets=target_vars) + pruned_program = copy_program.prune(targets=target_vars) inference_program = pruned_program.inference_optimize() fetch_var_names = [v.name for v in target_vars]