未验证 提交 8e005407 编写于 作者: T Tao Luo 提交者: GitHub

Merge pull request #10058 from Xreki/core_fix_flush

Add flush of program desc to update the proto information.
......@@ -146,6 +146,7 @@ void BlockDesc::RemoveOp(size_t s, size_t e) {
if (ops_.begin() + s == ops_.end() || ops_.begin() + e == ops_.end()) {
return;
}
need_update_ = true;
ops_.erase(ops_.begin() + s, ops_.begin() + e);
}
......
......@@ -27,10 +27,14 @@ BlockDesc *ProgramDesc::AppendBlock(const BlockDesc &parent) {
return blocks_.back().get();
}
proto::ProgramDesc *ProgramDesc::Proto() {
void ProgramDesc::Flush() {
for (auto &block : blocks_) {
block->Flush();
}
}
proto::ProgramDesc *ProgramDesc::Proto() {
Flush();
return &desc_;
}
......
......@@ -51,6 +51,8 @@ class ProgramDesc {
size_t Size() const { return blocks_.size(); }
void Flush();
proto::ProgramDesc *Proto();
// The output variable of feed_op is referenced as feed_target.
......
......@@ -127,6 +127,7 @@ void BindProgramDesc(pybind11::module *m) {
.def("block", &pd::ProgramDesc::MutableBlock,
pybind11::return_value_policy::reference)
.def("num_blocks", &pd::ProgramDesc::Size)
.def("flush", &pd::ProgramDesc::Flush)
.def("get_feed_target_names", &pd::ProgramDesc::GetFeedTargetNames)
.def("get_fetch_target_names", &pd::ProgramDesc::GetFetchTargetNames)
.def("serialize_to_string", SerializeMessage<pd::ProgramDesc>)
......
......@@ -336,18 +336,20 @@ def save_inference_model(dirname,
if main_program is None:
main_program = default_main_program()
copy_program = main_program
if not os.path.isdir(dirname):
os.makedirs(dirname)
# Clear the is_target information and remove the existed feed and fetch op
global_block = main_program.global_block()
global_block = copy_program.global_block()
for i, op in enumerate(global_block.ops):
op.desc.set_is_target(False)
if op.type == "feed" or op.type == "fetch":
global_block.remove_op(i)
copy_program.desc.flush()
pruned_program = main_program.prune(targets=target_vars)
pruned_program = copy_program.prune(targets=target_vars)
inference_program = pruned_program.inference_optimize()
fetch_var_names = [v.name for v in target_vars]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册