Commit 598035f9 authored by Yiqun Liu, committed by Kexin Zhao

Fix a bug in save_inference_model and prune when the program is initialized by load_inference_model (#10011)

* Fix a bug in save_inference_model and prune when the program is initialized by load_inference_model.

* Save the transpiled program instead.
Parent 9ca578d4
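
In plain terms: a program returned by load_inference_model still carries the feed/fetch ops and is_target flags written at the previous save, so pruning and re-saving it produced a broken model. A minimal sketch of the failing round trip (directory names are hypothetical):

import paddle.fluid as fluid

exe = fluid.Executor(fluid.CPUPlace())

# The loaded program still contains feed/fetch ops and stale is_target
# flags from the first save; "./model_v1" is a hypothetical path.
[program, feed_names, fetch_targets] = fluid.io.load_inference_model(
    "./model_v1", exe)

# Before this fix, re-saving such a program failed: prune() could not
# resolve the fetch variables (their .op links were None) and the old
# feed/fetch ops leaked into the new model.
fluid.io.save_inference_model("./model_v2", feed_names, fetch_targets,
                              exe, main_program=program)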
@@ -119,7 +119,7 @@ class OpDesc {
   void InferVarType(BlockDesc *block) const;
-  void MarkAsTarget() { desc_.set_is_target(true); }
+  void SetIsTarget(bool is_target) { desc_.set_is_target(is_target); }
   void Flush();
......
@@ -127,6 +127,8 @@ void BindProgramDesc(pybind11::module *m) {
       .def("block", &pd::ProgramDesc::MutableBlock,
            pybind11::return_value_policy::reference)
       .def("num_blocks", &pd::ProgramDesc::Size)
+      .def("get_feed_target_names", &pd::ProgramDesc::GetFeedTargetNames)
+      .def("get_fetch_target_names", &pd::ProgramDesc::GetFetchTargetNames)
       .def("serialize_to_string", SerializeMessage<pd::ProgramDesc>)
       .def("parse_from_string",
            [](pd::ProgramDesc &program_desc, const std::string &data) {
@@ -299,6 +301,7 @@ void BindOpDesc(pybind11::module *m) {
       .def("check_attrs", &pd::OpDesc::CheckAttrs)
       .def("infer_shape", &pd::OpDesc::InferShape)
       .def("infer_var_type", &pd::OpDesc::InferVarType)
+      .def("set_is_target", &pd::OpDesc::SetIsTarget)
       .def("serialize_to_string", SerializeMessage<pd::OpDesc>)
       .def("block", &pd::OpDesc::Block,
            pybind11::return_value_policy::reference);
......
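
Together, the two pybind hunks above expose the new methods to Python. A hedged sketch of how they are used from the fluid side (program stands for any fluid Program instance):

# Read feed/fetch target names straight from the C++ ProgramDesc instead
# of re-scanning the global block in Python.
feed_names = program.desc.get_feed_target_names()
fetch_names = program.desc.get_fetch_target_names()

# Clear stale target marks before pruning a loaded program again.
for op in program.global_block().ops:
    op.desc.set_is_target(False)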
@@ -294,7 +294,7 @@ All parameter, weight, gradient are variables in Paddle.
            const std::vector<std::array<size_t, 2>> &targets) {
           ProgramDesc prog_with_targets(origin);
           for (const auto &t : targets) {
-            prog_with_targets.MutableBlock(t[0])->Op(t[1])->MarkAsTarget();
+            prog_with_targets.MutableBlock(t[0])->Op(t[1])->SetIsTarget(true);
           }
           proto::ProgramDesc pruned_desc;
           Prune(*prog_with_targets.Proto(), &pruned_desc);
......
@@ -1070,6 +1070,12 @@ class Program(object):
         for t in targets:
             if not isinstance(t, Operator):
                 if isinstance(t, Variable):
+                    if t.op is None:
+                        global_block = self.global_block()
+                        for op in global_block.ops:
+                            if t.name in op.output_arg_names:
+                                t.op = op
+                                break
                     t = t.op
                 else:
                     raise ValueError(("All targets of prune() can only be "
......
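
Why t.op can be None here: Program.parse_from_string, which load_inference_model uses, rebuilds Variables without their producer-op back-links, so prune() now recovers the link by scanning the global block for an op whose outputs include the target's name. A short usage sketch under that assumption:

# fetch_targets obtained from a loaded program have .op == None;
# prune() now fills in the link by name before marking targets.
pruned_program = program.prune(targets=fetch_targets)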
@@ -340,6 +340,13 @@ def save_inference_model(dirname,
     if not os.path.isdir(dirname):
         os.makedirs(dirname)

+    # Clear the is_target information and remove the existing feed and fetch ops
+    global_block = main_program.global_block()
+    for i, op in enumerate(global_block.ops):
+        op.desc.set_is_target(False)
+        if op.type == "feed" or op.type == "fetch":
+            global_block.remove_op(i)
+
     pruned_program = main_program.prune(targets=target_vars)
     inference_program = pruned_program.inference_optimize()
     fetch_var_names = [v.name for v in target_vars]
@@ -362,24 +369,6 @@ def save_inference_model(dirname,
     save_persistables(executor, dirname, inference_program, params_filename)

-def get_feed_targets_names(program):
-    feed_targets_names = []
-    global_block = program.global_block()
-    for op in global_block.ops:
-        if op.desc.type() == 'feed':
-            feed_targets_names.insert(0, op.desc.output('Out')[0])
-    return feed_targets_names
-
-def get_fetch_targets_names(program):
-    fetch_targets_names = []
-    global_block = program.global_block()
-    for op in global_block.ops:
-        if op.desc.type() == 'fetch':
-            fetch_targets_names.append(op.desc.input('X')[0])
-    return fetch_targets_names
-
 def load_inference_model(dirname,
                          executor,
                          model_filename=None,
@@ -418,8 +407,8 @@ def load_inference_model(dirname,
         program = Program.parse_from_string(program_desc_str)
     load_persistables(executor, dirname, program, params_filename)

-    feed_target_names = get_feed_targets_names(program)
-    fetch_target_names = get_fetch_targets_names(program)
+    feed_target_names = program.desc.get_feed_target_names()
+    fetch_target_names = program.desc.get_fetch_target_names()
     fetch_targets = [
         program.global_block().var(name) for name in fetch_target_names
     ]
......
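
A minimal end-to-end sketch of the updated loader; the ./infer_model directory and the input shape are hypothetical:

import numpy as np
import paddle.fluid as fluid

exe = fluid.Executor(fluid.CPUPlace())
[program, feed_names, fetch_targets] = fluid.io.load_inference_model(
    "./infer_model", exe)

# feed_names and fetch_targets now come from the ProgramDesc itself,
# so they survive a save/load/save round trip.
x = np.random.random((1, 13)).astype("float32")
results = exe.run(program,
                  feed={feed_names[0]: x},
                  fetch_list=fetch_targets)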
@@ -248,6 +248,10 @@ def infer(use_cuda, save_dirname=None):
         print("infer results: ", results[0])

+        fluid.io.save_inference_model(save_dirname, feed_target_names,
+                                      fetch_targets, exe,
+                                      inference_transpiler_program)

 def main(net_type, use_cuda, is_local=True):
     if use_cuda and not fluid.core.is_compiled_with_cuda():
......
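
For context on inference_transpiler_program (a reconstruction, not part of the diff): it is the inference program after the inference transpiler has rewritten it, and the test now saves that transpiled program, per the commit note "Save the transpiled program instead". Assuming the fluid transpiler API of that era:

# Hypothetical sketch of how the transpiled program is produced.
inference_transpiler_program = inference_program.clone()
t = fluid.InferenceTranspiler()
t.transpile(inference_transpiler_program, place)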