From 7ffbcbcaf0e9414920cf1a4a2a14cdfb45fadd65 Mon Sep 17 00:00:00 2001 From: Liu Yiqun Date: Thu, 19 Apr 2018 11:55:23 +0000 Subject: [PATCH] Add flush of program desc to update the proto information. --- paddle/fluid/framework/block_desc.cc | 1 + paddle/fluid/framework/program_desc.cc | 6 +++++- paddle/fluid/framework/program_desc.h | 2 ++ paddle/fluid/pybind/protobuf.cc | 1 + python/paddle/fluid/io.py | 6 ++++-- 5 files changed, 13 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/framework/block_desc.cc b/paddle/fluid/framework/block_desc.cc index b8847e4b90..9f753478d8 100644 --- a/paddle/fluid/framework/block_desc.cc +++ b/paddle/fluid/framework/block_desc.cc @@ -146,6 +146,7 @@ void BlockDesc::RemoveOp(size_t s, size_t e) { if (ops_.begin() + s == ops_.end() || ops_.begin() + e == ops_.end()) { return; } + need_update_ = true; ops_.erase(ops_.begin() + s, ops_.begin() + e); } diff --git a/paddle/fluid/framework/program_desc.cc b/paddle/fluid/framework/program_desc.cc index 77d17fbbcc..16694bcf76 100644 --- a/paddle/fluid/framework/program_desc.cc +++ b/paddle/fluid/framework/program_desc.cc @@ -27,10 +27,14 @@ BlockDesc *ProgramDesc::AppendBlock(const BlockDesc &parent) { return blocks_.back().get(); } -proto::ProgramDesc *ProgramDesc::Proto() { +void ProgramDesc::Flush() { for (auto &block : blocks_) { block->Flush(); } +} + +proto::ProgramDesc *ProgramDesc::Proto() { + Flush(); return &desc_; } diff --git a/paddle/fluid/framework/program_desc.h b/paddle/fluid/framework/program_desc.h index 4288081be7..65fa0a0cfd 100644 --- a/paddle/fluid/framework/program_desc.h +++ b/paddle/fluid/framework/program_desc.h @@ -51,6 +51,8 @@ class ProgramDesc { size_t Size() const { return blocks_.size(); } + void Flush(); + proto::ProgramDesc *Proto(); // The output variable of feed_op is referenced as feed_target. diff --git a/paddle/fluid/pybind/protobuf.cc b/paddle/fluid/pybind/protobuf.cc index 7de7f84a3d..6471eb3ab7 100644 --- a/paddle/fluid/pybind/protobuf.cc +++ b/paddle/fluid/pybind/protobuf.cc @@ -127,6 +127,7 @@ void BindProgramDesc(pybind11::module *m) { .def("block", &pd::ProgramDesc::MutableBlock, pybind11::return_value_policy::reference) .def("num_blocks", &pd::ProgramDesc::Size) + .def("flush", &pd::ProgramDesc::Flush) .def("get_feed_target_names", &pd::ProgramDesc::GetFeedTargetNames) .def("get_fetch_target_names", &pd::ProgramDesc::GetFetchTargetNames) .def("serialize_to_string", SerializeMessage) diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py index bf4d81233d..f7f1ca2598 100644 --- a/python/paddle/fluid/io.py +++ b/python/paddle/fluid/io.py @@ -336,18 +336,20 @@ def save_inference_model(dirname, if main_program is None: main_program = default_main_program() + copy_program = main_program if not os.path.isdir(dirname): os.makedirs(dirname) # Clear the is_target information and remove the existed feed and fetch op - global_block = main_program.global_block() + global_block = copy_program.global_block() for i, op in enumerate(global_block.ops): op.desc.set_is_target(False) if op.type == "feed" or op.type == "fetch": global_block.remove_op(i) + copy_program.desc.flush() - pruned_program = main_program.prune(targets=target_vars) + pruned_program = copy_program.prune(targets=target_vars) inference_program = pruned_program.inference_optimize() fetch_var_names = [v.name for v in target_vars] -- GitLab