You need to sign in or sign up before continuing.
未验证 提交 8e005407 编写于 作者: T Tao Luo 提交者: GitHub

Merge pull request #10058 from Xreki/core_fix_flush

Add flush of program desc to update the proto information.
...@@ -146,6 +146,7 @@ void BlockDesc::RemoveOp(size_t s, size_t e) { ...@@ -146,6 +146,7 @@ void BlockDesc::RemoveOp(size_t s, size_t e) {
if (ops_.begin() + s == ops_.end() || ops_.begin() + e == ops_.end()) { if (ops_.begin() + s == ops_.end() || ops_.begin() + e == ops_.end()) {
return; return;
} }
need_update_ = true;
ops_.erase(ops_.begin() + s, ops_.begin() + e); ops_.erase(ops_.begin() + s, ops_.begin() + e);
} }
......
...@@ -27,10 +27,14 @@ BlockDesc *ProgramDesc::AppendBlock(const BlockDesc &parent) { ...@@ -27,10 +27,14 @@ BlockDesc *ProgramDesc::AppendBlock(const BlockDesc &parent) {
return blocks_.back().get(); return blocks_.back().get();
} }
proto::ProgramDesc *ProgramDesc::Proto() { void ProgramDesc::Flush() {
for (auto &block : blocks_) { for (auto &block : blocks_) {
block->Flush(); block->Flush();
} }
}
proto::ProgramDesc *ProgramDesc::Proto() {
Flush();
return &desc_; return &desc_;
} }
......
...@@ -51,6 +51,8 @@ class ProgramDesc { ...@@ -51,6 +51,8 @@ class ProgramDesc {
size_t Size() const { return blocks_.size(); } size_t Size() const { return blocks_.size(); }
void Flush();
proto::ProgramDesc *Proto(); proto::ProgramDesc *Proto();
// The output variable of feed_op is referenced as feed_target. // The output variable of feed_op is referenced as feed_target.
......
...@@ -127,6 +127,7 @@ void BindProgramDesc(pybind11::module *m) { ...@@ -127,6 +127,7 @@ void BindProgramDesc(pybind11::module *m) {
.def("block", &pd::ProgramDesc::MutableBlock, .def("block", &pd::ProgramDesc::MutableBlock,
pybind11::return_value_policy::reference) pybind11::return_value_policy::reference)
.def("num_blocks", &pd::ProgramDesc::Size) .def("num_blocks", &pd::ProgramDesc::Size)
.def("flush", &pd::ProgramDesc::Flush)
.def("get_feed_target_names", &pd::ProgramDesc::GetFeedTargetNames) .def("get_feed_target_names", &pd::ProgramDesc::GetFeedTargetNames)
.def("get_fetch_target_names", &pd::ProgramDesc::GetFetchTargetNames) .def("get_fetch_target_names", &pd::ProgramDesc::GetFetchTargetNames)
.def("serialize_to_string", SerializeMessage<pd::ProgramDesc>) .def("serialize_to_string", SerializeMessage<pd::ProgramDesc>)
......
...@@ -336,18 +336,20 @@ def save_inference_model(dirname, ...@@ -336,18 +336,20 @@ def save_inference_model(dirname,
if main_program is None: if main_program is None:
main_program = default_main_program() main_program = default_main_program()
copy_program = main_program
if not os.path.isdir(dirname): if not os.path.isdir(dirname):
os.makedirs(dirname) os.makedirs(dirname)
# Clear the is_target information and remove the existed feed and fetch op # Clear the is_target information and remove the existed feed and fetch op
global_block = main_program.global_block() global_block = copy_program.global_block()
for i, op in enumerate(global_block.ops): for i, op in enumerate(global_block.ops):
op.desc.set_is_target(False) op.desc.set_is_target(False)
if op.type == "feed" or op.type == "fetch": if op.type == "feed" or op.type == "fetch":
global_block.remove_op(i) global_block.remove_op(i)
copy_program.desc.flush()
pruned_program = main_program.prune(targets=target_vars) pruned_program = copy_program.prune(targets=target_vars)
inference_program = pruned_program.inference_optimize() inference_program = pruned_program.inference_optimize()
fetch_var_names = [v.name for v in target_vars] fetch_var_names = [v.name for v in target_vars]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册