Unverified commit 103deb11, authored by Xin Pan, committed by GitHub

Merge pull request #13484 from panyx0718/ir4

To support full model saving.
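For context, the flag touched below is the export_for_deployment argument of fluid.io.save_inference_model. A minimal usage sketch, assuming the paddle.fluid API of this era; the toy network, variable names, and output directory are illustrative and not part of this PR:

    import paddle.fluid as fluid

    # Build a tiny network; the exact layers are only for illustration.
    image = fluid.layers.data(name='img', shape=[784], dtype='float32')
    label = fluid.layers.data(name='label', shape=[1], dtype='int64')
    prediction = fluid.layers.fc(input=image, size=10, act='softmax')
    loss = fluid.layers.mean(
        fluid.layers.cross_entropy(input=prediction, label=label))

    exe = fluid.Executor(fluid.CPUPlace())
    exe.run(fluid.default_startup_program())

    # export_for_deployment=True (the default) prunes reader ops and adds
    # feed/fetch ops so the saved program can only be used for inference;
    # this PR reserves False for saving the full, re-trainable program.
    fluid.io.save_inference_model(
        dirname='./inference_model',
        feeded_var_names=['img'],
        target_vars=[prediction],
        executor=exe,
        main_program=fluid.default_main_program(),
        export_for_deployment=True)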
@@ -1647,7 +1647,7 @@ class Program(object):
         The two code snippets above will generate same programs.
         """
         if for_test:
-            p = self._inference_optimize(export_for_deployment=False)
+            p = self._inference_optimize(prune_read_op=False)
         else:
             p = Program()
             p.current_block_idx = self.current_block_idx
@@ -1717,7 +1717,7 @@ class Program(object):
         res._sync_with_cpp()
         return res

-    def _inference_optimize(self, export_for_deployment=True):
+    def _inference_optimize(self, prune_read_op=True):
         """
         This method will create a new program and do following adjustments on it:
         1. Remove all reader variables and their creator ops if exist.
@@ -1729,7 +1729,7 @@ class Program(object):
        information will be lost.

         Args:
-            export_for_deployment(bool): remove the read ops that are added by py_reader
+            prune_read_op(bool): remove the read ops that are added by py_reader
                                          for cpp inference library

         Notes: This API is a very low level API. Use
@@ -1744,7 +1744,7 @@ class Program(object):
         # remove all readers and the read_op if exist
         read_op_idx = 0
         root_block = res.desc.block(0)
-        if export_for_deployment:
+        if prune_read_op:
             while True:
                 if read_op_idx >= root_block.op_size() or root_block.op(
                         read_op_idx).type() == 'read':
......
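The last hunk above gates the read-op removal on the renamed prune_read_op flag: the loop scans block 0 for ops of type 'read' and removes them together with their reader variables. A standalone sketch of the same scan-and-remove pattern, assuming plain dicts instead of the real BlockDesc/OpDesc API (purely illustrative):

    def prune_read_ops(ops, prune_read_op=True):
        # ops is a list of {'type': ...} dicts standing in for block-0 op descs.
        if not prune_read_op:
            return ops
        read_op_idx = 0
        # Walk forward until a 'read' op is found or the block is exhausted,
        # mirroring the while-loop in _inference_optimize.
        while read_op_idx < len(ops) and ops[read_op_idx]['type'] != 'read':
            read_op_idx += 1
        if read_op_idx < len(ops):
            del ops[read_op_idx]  # drop the read op itself
        return ops

    ops = [{'type': 'create_py_reader'}, {'type': 'read'}, {'type': 'mul'}]
    print(prune_read_ops(list(ops), prune_read_op=True))
    # -> [{'type': 'create_py_reader'}, {'type': 'mul'}]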
@@ -20,6 +20,7 @@ import time
 import shutil
 import six

+from paddle.fluid.executor import Executor
 from paddle.fluid.evaluator import Evaluator
 from paddle.fluid.framework import Program, Parameter, default_main_program, default_startup_program, Variable
 from . import core
@@ -587,8 +588,11 @@ def save_inference_model(dirname,
         params_filename(str|None): The name of file to save all related parameters.
                                    If it is setted None, parameters will be saved
                                    in separate files .
-        export_for_deployment(bool): remove the read ops that are added by py_reader
-                                     for cpp inference lib. Default True
+        export_for_deployment(bool): If True, programs are modified to only support
+                                     direct inference deployment. Otherwise,
+                                     more information will be stored for flexible
+                                     optimization and re-training. Currently, only
+                                     True is supported.

     Returns:
         None
@@ -636,7 +640,11 @@ def save_inference_model(dirname,
     if not os.path.isdir(dirname):
         os.makedirs(dirname)

-    # Clear the is_target information and remove the existed feed and fetch op
+    # When export_for_deployment is true, we modify the program online so that
+    # it can only be loaded for inference directly. If it's false, the whole
+    # original program and related meta are saved so that future usage can be
+    # more flexible.
+    if export_for_deployment:
         global_block = copy_program.global_block()
         for i, op in enumerate(global_block.ops):
             op.desc.set_is_target(False)
@@ -645,12 +653,15 @@ def save_inference_model(dirname,
         copy_program.desc.flush()

         pruned_program = copy_program._prune(targets=target_vars)
-        inference_program = pruned_program._inference_optimize(
-            export_for_deployment=export_for_deployment)
+        saved_program = pruned_program._inference_optimize(prune_read_op=True)
         fetch_var_names = [v.name for v in target_vars]

-        prepend_feed_ops(inference_program, feeded_var_names)
-        append_fetch_ops(inference_program, fetch_var_names)
+        prepend_feed_ops(saved_program, feeded_var_names)
+        append_fetch_ops(saved_program, fetch_var_names)
+    else:
+        # TODO(panyx0718): Save more information so that it can also be used
+        # for training and more flexible post-processing.
+        saved_program = copy_program

     if model_filename is not None:
         model_filename = os.path.basename(model_filename)
@@ -662,9 +673,9 @@ def save_inference_model(dirname,
         params_filename = os.path.basename(params_filename)

     with open(model_filename, "wb") as f:
-        f.write(inference_program.desc.serialize_to_string())
+        f.write(saved_program.desc.serialize_to_string())

-    save_persistables(executor, dirname, inference_program, params_filename)
+    save_persistables(executor, dirname, saved_program, params_filename)

     # if there is lookup table, the trainer 0 will notify all pserver to save.
     if main_program._is_distributed and main_program._is_chief and main_program._distributed_lookup_table:
......
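To round-trip what save_inference_model writes when export_for_deployment is True, the saved program can be loaded back for direct inference. A minimal sketch, assuming the fluid.io.load_inference_model API of this era and the illustrative ./inference_model directory from the earlier example:

    import numpy as np
    import paddle.fluid as fluid

    exe = fluid.Executor(fluid.CPUPlace())
    # Returns the pruned program plus the feed names and fetch targets that the
    # prepended feed ops and appended fetch ops expect.
    infer_program, feed_names, fetch_targets = fluid.io.load_inference_model(
        dirname='./inference_model', executor=exe)

    img = np.random.rand(1, 784).astype('float32')
    results = exe.run(infer_program,
                      feed={feed_names[0]: img},
                      fetch_list=fetch_targets)
    print(results[0].shape)  # (1, 10) for the toy softmax net sketched earlier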
@@ -122,7 +122,7 @@ class TestProgram(unittest.TestCase):
             net()
         no_read_program = main_program._inference_optimize()
         keep_read_program = main_program._inference_optimize(
-            export_for_deployment=False)
+            prune_read_op=False)
         no_read_ops = no_read_program.global_block().ops
         keep_read_ops = keep_read_program.global_block().ops
         self.assertEqual(len(keep_read_ops) - len(no_read_ops), 2)
......