Commit 598035f9 authored by Yiqun Liu, committed by Kexin Zhao

Fix a bug in save_inference_model and prune when the program is initialized by load_inference_model (#10011)

* Fix bug in save_inference_model and prune when the program is initialized by load_inference_model.

* Save the transpiled program instead.
Parent 9ca578d4
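For context, a minimal sketch of the round trip this commit fixes, assuming the fluid.io API of this era; the directory names are illustrative:

```python
import paddle.fluid as fluid

exe = fluid.Executor(fluid.CPUPlace())

# Rebuild a previously saved inference program from disk.
[program, feed_target_names, fetch_targets] = fluid.io.load_inference_model(
    "model_dir", exe)

# Re-saving the loaded program used to fail: prune() assumed every target
# Variable had its `op` back-reference set (not true after
# Program.parse_from_string), and the stale feed/fetch ops from the first
# save leaked into the pruned result.
fluid.io.save_inference_model("model_dir_resaved", feed_target_names,
                              fetch_targets, exe, main_program=program)
```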
```diff
@@ -119,7 +119,7 @@ class OpDesc {
   void InferVarType(BlockDesc *block) const;

-  void MarkAsTarget() { desc_.set_is_target(true); }
+  void SetIsTarget(bool is_target) { desc_.set_is_target(is_target); }

   void Flush();
...
```
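The rename from MarkAsTarget() to SetIsTarget(bool is_target) matters because the new save path must be able to clear the flag as well as set it; the Python-side cleanup further below calls set_is_target(False) on every op before re-pruning.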
```diff
@@ -127,6 +127,8 @@ void BindProgramDesc(pybind11::module *m) {
       .def("block", &pd::ProgramDesc::MutableBlock,
            pybind11::return_value_policy::reference)
       .def("num_blocks", &pd::ProgramDesc::Size)
+      .def("get_feed_target_names", &pd::ProgramDesc::GetFeedTargetNames)
+      .def("get_fetch_target_names", &pd::ProgramDesc::GetFetchTargetNames)
       .def("serialize_to_string", SerializeMessage<pd::ProgramDesc>)
       .def("parse_from_string",
            [](pd::ProgramDesc &program_desc, const std::string &data) {
@@ -299,6 +301,7 @@ void BindOpDesc(pybind11::module *m) {
       .def("check_attrs", &pd::OpDesc::CheckAttrs)
       .def("infer_shape", &pd::OpDesc::InferShape)
       .def("infer_var_type", &pd::OpDesc::InferVarType)
+      .def("set_is_target", &pd::OpDesc::SetIsTarget)
       .def("serialize_to_string", SerializeMessage<pd::OpDesc>)
       .def("block", &pd::OpDesc::Block,
            pybind11::return_value_policy::reference);
...
```
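A minimal sketch of what the new ProgramDesc bindings expose on the Python side, assuming the fluid framework API of this era; a freshly built program has no feed/fetch ops yet, so both lists come back empty:

```python
import paddle.fluid as fluid
from paddle.fluid.framework import Program

# Build a tiny program and round-trip it through serialization, as
# load_inference_model does internally.
main = fluid.Program()
with fluid.program_guard(main, fluid.Program()):
    x = fluid.layers.data(name='x', shape=[4], dtype='float32')
    y = fluid.layers.fc(input=x, size=2)

loaded = Program.parse_from_string(main.desc.serialize_to_string())

# The new desc-level accessors; for a model written by save_inference_model
# these return the names wired through its feed and fetch ops.
print(loaded.desc.get_feed_target_names())   # -> [] here
print(loaded.desc.get_fetch_target_names())  # -> [] here
```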
```diff
@@ -294,7 +294,7 @@ All parameter, weight, gradient are variables in Paddle.
          const std::vector<std::array<size_t, 2>> &targets) {
         ProgramDesc prog_with_targets(origin);
         for (const auto &t : targets) {
-          prog_with_targets.MutableBlock(t[0])->Op(t[1])->MarkAsTarget();
+          prog_with_targets.MutableBlock(t[0])->Op(t[1])->SetIsTarget(true);
         }
         proto::ProgramDesc pruned_desc;
         Prune(*prog_with_targets.Proto(), &pruned_desc);
...
```
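This lambda backs core.prune, which Program.prune drives with [block_index, op_index] pairs; a hedged sketch of calling it directly (the return-value handling is an assumption based on how Program.prune consumes it):

```python
import paddle.fluid as fluid
from paddle.fluid import core

main = fluid.Program()
with fluid.program_guard(main, fluid.Program()):
    x = fluid.layers.data(name='x', shape=[4], dtype='float32')
    y = fluid.layers.fc(input=x, size=2)

# Mark the last op of block 0 as the prune target; the C++ side calls
# SetIsTarget(true) on it and then runs the Prune pass.
last_op_idx = len(main.global_block().ops) - 1
pruned_desc = core.prune(main.desc, [[0, last_op_idx]])
```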
```diff
@@ -1070,6 +1070,12 @@ class Program(object):
         for t in targets:
             if not isinstance(t, Operator):
                 if isinstance(t, Variable):
+                    if t.op is None:
+                        global_block = self.global_block()
+                        for op in global_block.ops:
+                            if t.name in op.output_arg_names:
+                                t.op = op
+                                break
                     t = t.op
                 else:
                     raise ValueError(("All targets of prune() can only be "
...
```
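The new `t.op is None` branch exists because Program.parse_from_string rebuilds Variables without the op back-references that layer code normally records; a sketch of the case that used to break, assuming the fluid API of this era:

```python
import paddle.fluid as fluid
from paddle.fluid.framework import Program

main = fluid.Program()
with fluid.program_guard(main, fluid.Program()):
    x = fluid.layers.data(name='x', shape=[4], dtype='float32')
    out = fluid.layers.fc(input=x, size=2)

# Rebuild the program the way load_inference_model does; the recovered
# Variable has the same name, but its `op` attribute is None.
loaded = Program.parse_from_string(main.desc.serialize_to_string())
target = loaded.global_block().var(out.name)

# With this fix, prune() locates the producing op by scanning the global
# block for an op whose output_arg_names contain the variable's name.
pruned = loaded.prune(targets=[target])
```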
```diff
@@ -340,6 +340,13 @@ def save_inference_model(dirname,
     if not os.path.isdir(dirname):
         os.makedirs(dirname)

+    # Clear the is_target information and remove the existing feed and fetch ops.
+    global_block = main_program.global_block()
+    for i, op in enumerate(global_block.ops):
+        op.desc.set_is_target(False)
+        if op.type == "feed" or op.type == "fetch":
+            global_block.remove_op(i)
+
     pruned_program = main_program.prune(targets=target_vars)
     inference_program = pruned_program.inference_optimize()
     fetch_var_names = [v.name for v in target_vars]
@@ -362,24 +369,6 @@ def save_inference_model(dirname,
     save_persistables(executor, dirname, inference_program, params_filename)


-def get_feed_targets_names(program):
-    feed_targets_names = []
-    global_block = program.global_block()
-    for op in global_block.ops:
-        if op.desc.type() == 'feed':
-            feed_targets_names.insert(0, op.desc.output('Out')[0])
-    return feed_targets_names
-
-
-def get_fetch_targets_names(program):
-    fetch_targets_names = []
-    global_block = program.global_block()
-    for op in global_block.ops:
-        if op.desc.type() == 'fetch':
-            fetch_targets_names.append(op.desc.input('X')[0])
-    return fetch_targets_names


 def load_inference_model(dirname,
                          executor,
                          model_filename=None,
@@ -418,8 +407,8 @@ def load_inference_model(dirname,
     program = Program.parse_from_string(program_desc_str)
     load_persistables(executor, dirname, program, params_filename)

-    feed_target_names = get_feed_targets_names(program)
-    fetch_target_names = get_fetch_targets_names(program)
+    feed_target_names = program.desc.get_feed_target_names()
+    fetch_target_names = program.desc.get_fetch_target_names()
     fetch_targets = [
         program.global_block().var(name) for name in fetch_target_names
     ]
...
```
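The two Python helpers are replaced by the desc-level bindings added above; calling code is unchanged. A usage sketch, with the path and input shape as illustrative assumptions:

```python
import numpy as np
import paddle.fluid as fluid

exe = fluid.Executor(fluid.CPUPlace())

# feed_target_names/fetch_targets are now recovered via
# program.desc.get_feed_target_names() / get_fetch_target_names().
[program, feed_target_names, fetch_targets] = fluid.io.load_inference_model(
    "saved_model", exe)

x = np.random.rand(1, 4).astype('float32')  # shape assumed for illustration
results = exe.run(program,
                  feed={feed_target_names[0]: x},
                  fetch_list=fetch_targets)
print(results[0])
```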
```diff
@@ -248,6 +248,10 @@ def infer(use_cuda, save_dirname=None):
         print("infer results: ", results[0])

+        fluid.io.save_inference_model(save_dirname, feed_target_names,
+                                      fetch_targets, exe,
+                                      inference_transpiler_program)
+

 def main(net_type, use_cuda, is_local=True):
     if use_cuda and not fluid.core.is_compiled_with_cuda():
...
```
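The test now re-saves the transpiled program, exercising exactly the load-then-save path this commit fixes. A hedged sketch of how such a program is produced, assuming the InferenceTranspiler API of this era; save_dirname is illustrative:

```python
import paddle.fluid as fluid

place = fluid.CPUPlace()
exe = fluid.Executor(place)
save_dirname = "image_classification.inference.model"  # illustrative

[inference_program, feed_target_names,
 fetch_targets] = fluid.io.load_inference_model(save_dirname, exe)

# Optimize the loaded program for inference (e.g. fold batch_norm into the
# preceding conv), then persist the optimized version in place of the original.
inference_transpiler_program = inference_program.clone()
t = fluid.InferenceTranspiler()
t.transpile(inference_transpiler_program, place)

fluid.io.save_inference_model(save_dirname, feed_target_names,
                              fetch_targets, exe,
                              inference_transpiler_program)
```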