Commit 10f7d004 authored by Xin Pan

To support full model saving.

In the future, we'd like to encourage users to save
everything during training.
This allows us to:
1. Do more flexible optimization passes
2. Re-train and fine-tune
Parent a69a584b
python/paddle/fluid/framework.py
@@ -1647,7 +1647,7 @@ class Program(object):
         The two code snippets above will generate same programs.
         """
         if for_test:
-            p = self._inference_optimize(export_for_deployment=False)
+            p = self._inference_optimize(prune_read_op=False)
         else:
             p = Program()
             p.current_block_idx = self.current_block_idx
@@ -1717,7 +1717,7 @@ class Program(object):
         res._sync_with_cpp()
         return res

-    def _inference_optimize(self, export_for_deployment=True):
+    def _inference_optimize(self, prune_read_op=True):
         """
         This method will create a new program and do following adjustments on it:
         1. Remove all reader variables and their creator ops if exist.
@@ -1729,8 +1729,8 @@ class Program(object):
         information will be lost.

         Args:
-            export_for_deployment(bool): remove the read ops that are added by py_reader
-                                         for cpp inference library
+            prune_read_op(bool): remove the read ops that are added by py_reader
+                                 for cpp inference library

         Notes: This API is a very low level API. Use
         :code:`Program.clone(for_test=True)` instead.
@@ -1744,7 +1744,7 @@ class Program(object):
         # remove all readers and the read_op if exist
         read_op_idx = 0
         root_block = res.desc.block(0)
-        if export_for_deployment:
+        if prune_read_op:
             while True:
                 if read_op_idx >= root_block.op_size() or root_block.op(
                         read_op_idx).type() == 'read':
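As the docstring above notes, `_inference_optimize` is a very low-level API that is normally reached through `Program.clone(for_test=True)`, which after this rename passes `prune_read_op=False` so a reader-fed test program keeps its read ops. A minimal sketch of the two entry points (illustrative only, not from this commit):

import paddle.fluid as fluid

prog = fluid.default_main_program()

# High-level path: clone for evaluation. Read ops are kept because
# clone(for_test=True) now calls _inference_optimize(prune_read_op=False).
test_prog = prog.clone(for_test=True)

# Low-level path: additionally strip py_reader read ops, e.g. when the
# program is destined for the C++ inference library.
infer_prog = prog._inference_optimize(prune_read_op=True)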
python/paddle/fluid/io.py
@@ -20,6 +20,7 @@ import time
 import shutil
 import six

+from paddle.fluid.executor import Executor
 from paddle.fluid.evaluator import Evaluator
 from paddle.fluid.framework import Program, Parameter, default_main_program, default_startup_program, Variable
 from . import core
@@ -587,8 +588,11 @@ def save_inference_model(dirname,
         params_filename(str|None): The name of file to save all related parameters.
                                    If it is setted None, parameters will be saved
                                    in separate files .
-        export_for_deployment(bool): remove the read ops that are added by py_reader
-                                     for cpp inference lib. Default True
+        export_for_deployment(bool): If True, programs are modified to only support
+                                     direct inference deployment. Otherwise,
+                                     more information will be stored for flexible
+                                     optimization and re-training. Currently, only
+                                     True is supported.

     Returns:
         None
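For context, a hedged usage sketch of the updated API; the tiny network, the variable names, and the output directory are assumptions for illustration, not part of this commit. `export_for_deployment=True` is the default and, per the docstring above, currently the only supported value:

import paddle.fluid as fluid

# Minimal network so the example is self-contained (assumed, not from the commit).
x = fluid.layers.data(name='x', shape=[13], dtype='float32')
pred = fluid.layers.fc(input=x, size=1)

exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())

fluid.io.save_inference_model(
    dirname='./infer_model',
    feeded_var_names=['x'],
    target_vars=[pred],
    executor=exe,
    export_for_deployment=True)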
@@ -636,21 +640,28 @@ def save_inference_model(dirname,
     if not os.path.isdir(dirname):
         os.makedirs(dirname)

-    # Clear the is_target information and remove the existed feed and fetch op
-    global_block = copy_program.global_block()
-    for i, op in enumerate(global_block.ops):
-        op.desc.set_is_target(False)
-        if op.type == "feed" or op.type == "fetch":
-            global_block._remove_op(i)
-    copy_program.desc.flush()
-
-    pruned_program = copy_program._prune(targets=target_vars)
-    inference_program = pruned_program._inference_optimize(
-        export_for_deployment=export_for_deployment)
-    fetch_var_names = [v.name for v in target_vars]
-
-    prepend_feed_ops(inference_program, feeded_var_names)
-    append_fetch_ops(inference_program, fetch_var_names)
+    # When export_for_deployment is true, we modify the program online so that
+    # it can only be loaded for inference directly. If it's false, the whole
+    # original program and related meta are saved so that future usage can be
+    # more flexible.
+    if export_for_deployment:
+        global_block = copy_program.global_block()
+        for i, op in enumerate(global_block.ops):
+            op.desc.set_is_target(False)
+            if op.type == "feed" or op.type == "fetch":
+                global_block._remove_op(i)
+        copy_program.desc.flush()
+
+        pruned_program = copy_program._prune(targets=target_vars)
+        saved_program = pruned_program._inference_optimize(prune_read_op=True)
+        fetch_var_names = [v.name for v in target_vars]
+
+        prepend_feed_ops(saved_program, feeded_var_names)
+        append_fetch_ops(saved_program, fetch_var_names)
+    else:
+        # TODO(panyx0718): Save more information so that it can also be used
+        # for training and more flexible post-processing.
+        saved_program = copy_program

     if model_filename is not None:
         model_filename = os.path.basename(model_filename)
@@ -662,9 +673,9 @@ def save_inference_model(dirname,
         params_filename = os.path.basename(params_filename)

     with open(model_filename, "wb") as f:
-        f.write(inference_program.desc.serialize_to_string())
+        f.write(saved_program.desc.serialize_to_string())

-    save_persistables(executor, dirname, inference_program, params_filename)
+    save_persistables(executor, dirname, saved_program, params_filename)

     # if there is lookup table, the trainer 0 will notify all pserver to save.
     if main_program._is_distributed and main_program._is_chief and main_program._distributed_lookup_table:
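The feed ops prepended and fetch ops appended in the deployment branch are what make the saved program directly runnable after loading. A minimal round-trip sketch (the path and input shape are assumptions matching the example above):

import numpy as np
import paddle.fluid as fluid

exe = fluid.Executor(fluid.CPUPlace())
program, feed_names, fetch_targets = fluid.io.load_inference_model(
    dirname='./infer_model', executor=exe)

# Feed a batch matching the shape the model was saved with (assumed here).
batch = np.random.random((1, 13)).astype('float32')
results = exe.run(program,
                  feed={feed_names[0]: batch},
                  fetch_list=fetch_targets)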
python/paddle/fluid/tests/unittests/test_program.py
@@ -122,7 +122,7 @@ class TestProgram(unittest.TestCase):
         net()
         no_read_program = main_program._inference_optimize()
         keep_read_program = main_program._inference_optimize(
-            export_for_deployment=False)
+            prune_read_op=False)
         no_read_ops = no_read_program.global_block().ops
         keep_read_ops = keep_read_program.global_block().ops
         self.assertEqual(len(keep_read_ops) - len(no_read_ops), 2)
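For reference, a sketch of the kind of reader-fed network such a test exercises (the layer choice, shapes, and names are assumptions, not the test's actual net()): py_reader is what contributes the reader-related ops that prune_read_op=True removes, which is the op-count difference the assertion checks.

import paddle.fluid as fluid

main_program = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(main_program, startup_program):
    # py_reader adds reader state plus a 'read' op to the block.
    reader = fluid.layers.py_reader(
        capacity=8, shapes=[[-1, 13], [-1, 1]], dtypes=['float32', 'float32'])
    x, y = fluid.layers.read_file(reader)
    pred = fluid.layers.fc(input=x, size=1)
    loss = fluid.layers.mean(fluid.layers.square_error_cost(input=pred, label=y))

no_read = main_program._inference_optimize()  # read ops pruned (default)
keep_read = main_program._inference_optimize(prune_read_op=False)
# keep_read retains the reader-related ops that no_read drops.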