Unverified · Commit 103deb11 authored by Xin Pan, committed by GitHub

Merge pull request #13484 from panyx0718/ir4

To support full model saving.
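For orientation, here is a minimal sketch of the API this change touches, in Fluid-era Python. The tiny network, variable names, and output directory are illustrative only, not part of the commit:

```python
# Minimal, assumed usage of fluid.io.save_inference_model (Fluid-era API).
# The network here is hypothetical; only the save call is what this PR reworks.
import paddle.fluid as fluid

x = fluid.layers.data(name='x', shape=[13], dtype='float32')
y = fluid.layers.fc(input=x, size=1)

exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())

# export_for_deployment=True (the default) prunes the program down to an
# inference-only graph with feed/fetch ops; this PR adds a branch that can
# keep the full program instead.
fluid.io.save_inference_model(dirname='./infer_model',
                              feeded_var_names=['x'],
                              target_vars=[y],
                              executor=exe,
                              export_for_deployment=True)
```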
@@ -1647,7 +1647,7 @@ class Program(object):
         The two code snippets above will generate same programs.
         """
         if for_test:
-            p = self._inference_optimize(export_for_deployment=False)
+            p = self._inference_optimize(prune_read_op=False)
         else:
             p = Program()
             p.current_block_idx = self.current_block_idx
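The hunk above sits inside Program.clone. A short, hedged sketch of the public call that reaches this branch:

```python
# Program.clone(for_test=True) is the public entry point into the branch
# above: the cloned evaluation program keeps py_reader's read ops
# (prune_read_op=False) because it still has to pull input data itself.
test_program = fluid.default_main_program().clone(for_test=True)
```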
@@ -1717,7 +1717,7 @@ class Program(object):
         res._sync_with_cpp()
         return res

-    def _inference_optimize(self, export_for_deployment=True):
+    def _inference_optimize(self, prune_read_op=True):
         """
         This method will create a new program and do following adjustments on it:
         1. Remove all reader variables and their creator ops if exist.
@@ -1729,7 +1729,7 @@ class Program(object):
         information will be lost.

         Args:
-            export_for_deployment(bool): remove the read ops that are added by py_reader
+            prune_read_op(bool): remove the read ops that are added by py_reader
                 for cpp inference library

         Notes: This API is a very low level API. Use
@@ -1744,7 +1744,7 @@ class Program(object):
         # remove all readers and the read_op if exist
         read_op_idx = 0
         root_block = res.desc.block(0)
-        if export_for_deployment:
+        if prune_read_op:
             while True:
                 if read_op_idx >= root_block.op_size() or root_block.op(
                         read_op_idx).type() == 'read':
......
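The truncated loop in the last hunk scans block 0 for ops of type 'read' (inserted by py_reader) and strips them. Below is a simplified standalone sketch of the idea; `prune_read_ops` is a hypothetical helper name, and the real loop in `_inference_optimize` is more involved than this:

```python
def prune_read_ops(root_block):
    """Simplified sketch: drop every op of type 'read' from a block desc.

    Hypothetical helper; not the exact loop used by _inference_optimize.
    """
    idx = 0
    while idx < root_block.op_size():
        if root_block.op(idx).type() == 'read':
            # BlockDesc._remove_op removes the half-open range [idx, idx + 1).
            root_block._remove_op(idx, idx + 1)
        else:
            idx += 1
```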
@@ -20,6 +20,7 @@ import time
 import shutil
 import six
 from paddle.fluid.executor import Executor
+from paddle.fluid.evaluator import Evaluator
 from paddle.fluid.framework import Program, Parameter, default_main_program, default_startup_program, Variable
 from . import core
@@ -587,8 +588,11 @@ def save_inference_model(dirname,
         params_filename(str|None): The name of file to save all related parameters.
                                    If it is setted None, parameters will be saved
                                    in separate files .
-        export_for_deployment(bool): remove the read ops that are added by py_reader
-                                     for cpp inference lib. Default True
+        export_for_deployment(bool): If True, programs are modified to only support
+                                     direct inference deployment. Otherwise,
+                                     more information will be stored for flexible
+                                     optimization and re-training. Currently, only
+                                     True is supported.

     Returns:
         None
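Reusing the names from the sketch near the top, the two documented modes look like this; the docstring above says only True is supported for now, so the False call is speculative usage of this PR's new branch:

```python
# Deployment export (default): the program is pruned to the target vars,
# read ops are removed, and feed/fetch ops are added.
fluid.io.save_inference_model('./deploy_model', ['x'], [y], exe,
                              export_for_deployment=True)

# Full-model saving (this PR's new else-branch): the copied program is
# serialized as-is; the docstring still marks this mode as unsupported.
fluid.io.save_inference_model('./full_model', ['x'], [y], exe,
                              export_for_deployment=False)
```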
@@ -636,7 +640,11 @@ def save_inference_model(dirname,
     if not os.path.isdir(dirname):
         os.makedirs(dirname)

-    # Clear the is_target information and remove the existed feed and fetch op
+    # When export_for_deployment is true, we modify the program online so that
+    # it can only be loaded for inference directly. If it's false, the whole
+    # original program and related meta are saved so that future usage can be
+    # more flexible.
+    if export_for_deployment:
         global_block = copy_program.global_block()
         for i, op in enumerate(global_block.ops):
             op.desc.set_is_target(False)
@@ -645,12 +653,15 @@ def save_inference_model(dirname,
         copy_program.desc.flush()

         pruned_program = copy_program._prune(targets=target_vars)
-        inference_program = pruned_program._inference_optimize(
-            export_for_deployment=export_for_deployment)
+        saved_program = pruned_program._inference_optimize(prune_read_op=True)
         fetch_var_names = [v.name for v in target_vars]

-        prepend_feed_ops(inference_program, feeded_var_names)
-        append_fetch_ops(inference_program, fetch_var_names)
+        prepend_feed_ops(saved_program, feeded_var_names)
+        append_fetch_ops(saved_program, fetch_var_names)
+    else:
+        # TODO(panyx0718): Save more information so that it can also be used
+        # for training and more flexible post-processing.
+        saved_program = copy_program

     if model_filename is not None:
         model_filename = os.path.basename(model_filename)
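For readers unfamiliar with the feed/fetch instrumentation used on the deployment path, here is a simplified illustration of what prepend_feed_ops and append_fetch_ops do. The helper name below is hypothetical and the holder-variable details are elided; 'feed' and 'fetch' are real op types:

```python
# Simplified sketch, not the exact io.py implementation.
def add_feed_fetch_ops(program, feed_names, fetch_names):
    block = program.global_block()
    feed_var = block.create_var(name='feed', persistable=True)
    for i, name in enumerate(feed_names):
        # One 'feed' op per input, reading column i of the feed holder.
        block._prepend_op(type='feed',
                          inputs={'X': [feed_var]},
                          outputs={'Out': [block.var(name)]},
                          attrs={'col': i})
    fetch_var = block.create_var(name='fetch', persistable=True)
    for i, name in enumerate(fetch_names):
        # One 'fetch' op per target, writing column i of the fetch holder.
        block.append_op(type='fetch',
                        inputs={'X': [block.var(name)]},
                        outputs={'Out': [fetch_var]},
                        attrs={'col': i})
```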
@@ -662,9 +673,9 @@ def save_inference_model(dirname,
         params_filename = os.path.basename(params_filename)

     with open(model_filename, "wb") as f:
-        f.write(inference_program.desc.serialize_to_string())
+        f.write(saved_program.desc.serialize_to_string())

-    save_persistables(executor, dirname, inference_program, params_filename)
+    save_persistables(executor, dirname, saved_program, params_filename)

     # if there is lookup table, the trainer 0 will notify all pserver to save.
     if main_program._is_distributed and main_program._is_chief and main_program._distributed_lookup_table:
......
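Round-tripping: a model saved with export_for_deployment=True is read back with the standard load_inference_model; the directory and input shape below match the illustrative sketch at the top:

```python
import numpy as np
import paddle.fluid as fluid

exe = fluid.Executor(fluid.CPUPlace())
# The returned feed names and fetch targets come from the feed/fetch ops
# written into the program at save time.
[program, feed_names, fetch_targets] = fluid.io.load_inference_model(
    dirname='./infer_model', executor=exe)

x = np.random.random((1, 13)).astype('float32')  # matches the sketch above
results = exe.run(program,
                  feed={feed_names[0]: x},
                  fetch_list=fetch_targets)
```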
@@ -122,7 +122,7 @@ class TestProgram(unittest.TestCase):
             net()
         no_read_program = main_program._inference_optimize()
         keep_read_program = main_program._inference_optimize(
-            export_for_deployment=False)
+            prune_read_op=False)
         no_read_ops = no_read_program.global_block().ops
         keep_read_ops = keep_read_program.global_block().ops
         self.assertEqual(len(keep_read_ops) - len(no_read_ops), 2)
......