Commit 10f7d004 authored by Xin Pan

To support full model saving.

In the future, we'd like to encourage users to save
everything during training.
This allows us to
1. Do more flexible optimization passes
2. Re-train and fine-tune (see the usage sketch below)
Parent a69a584b
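A hedged usage sketch of the public-facing change (the toy network, variable names, and output directories below are illustrative only; `fluid.io.save_inference_model` and its `export_for_deployment` argument come from this diff):

```python
import paddle.fluid as fluid

# Toy network; names and shapes are illustrative.
image = fluid.layers.data(name='image', shape=[784], dtype='float32')
prediction = fluid.layers.fc(input=image, size=10, act='softmax')

exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())

# Default: prune the program so it can only be loaded for direct inference.
fluid.io.save_inference_model(
    dirname='./deploy_model',
    feeded_var_names=['image'],
    target_vars=[prediction],
    executor=exe,
    export_for_deployment=True)

# With this commit, export_for_deployment=False keeps the whole original
# program so it can later be re-trained, fine-tuned, or optimized further.
fluid.io.save_inference_model(
    dirname='./full_model',
    feeded_var_names=['image'],
    target_vars=[prediction],
    executor=exe,
    export_for_deployment=False)
```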
@@ -1647,7 +1647,7 @@ class Program(object):
The two code snippets above will generate the same programs.
"""
if for_test:
p = self._inference_optimize(export_for_deployment=False)
p = self._inference_optimize(prune_read_op=False)
else:
p = Program()
p.current_block_idx = self.current_block_idx
@@ -1717,7 +1717,7 @@ class Program(object):
res._sync_with_cpp()
return res
def _inference_optimize(self, export_for_deployment=True):
def _inference_optimize(self, prune_read_op=True):
"""
This method will create a new program and do the following adjustments on it:
1. Remove all reader variables and their creator ops if exist.
@@ -1729,8 +1729,8 @@ class Program(object):
information will be lost.
Args:
export_for_deployment(bool): remove the read ops that are added by py_reader
for cpp inference library
prune_read_op(bool): remove the read ops that are added by py_reader
for cpp inference library
Notes: This API is a very low level API. Use
:code:`Program.clone(for_test=True)` instead.
@@ -1744,7 +1744,7 @@ class Program(object):
# remove all readers and the read_op if exist
read_op_idx = 0
root_block = res.desc.block(0)
if export_for_deployment:
if prune_read_op:
while True:
if read_op_idx >= root_block.op_size() or root_block.op(
read_op_idx).type() == 'read':
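A minimal sketch of the renamed internal argument (it mirrors the unit test further below; the py_reader-based toy program is illustrative and assumes `fluid.layers.py_reader`/`read_file` as available in contemporaneous Fluid releases):

```python
import paddle.fluid as fluid

main_program = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(main_program, startup_program):
    # py_reader plus read_file add a reader variable and a 'read' op.
    reader = fluid.layers.py_reader(
        capacity=64, shapes=[[-1, 784], [-1, 1]], dtypes=['float32', 'int64'])
    img, label = fluid.layers.read_file(reader)
    prediction = fluid.layers.fc(input=img, size=10, act='softmax')

# Preferred public entry point; per this diff it now calls
# _inference_optimize(prune_read_op=False) internally.
test_program = main_program.clone(for_test=True)

# Low-level calls with the renamed argument.
no_read_program = main_program._inference_optimize()                       # prunes 'read' ops
keep_read_program = main_program._inference_optimize(prune_read_op=False)  # keeps them
```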
@@ -20,6 +20,7 @@ import time
import shutil
import six
from paddle.fluid.executor import Executor
from paddle.fluid.evaluator import Evaluator
from paddle.fluid.framework import Program, Parameter, default_main_program, default_startup_program, Variable
from . import core
@@ -587,8 +588,11 @@ def save_inference_model(dirname,
params_filename(str|None): The name of the file to save all related parameters.
If it is set to None, parameters will be saved
in separate files.
export_for_deployment(bool): remove the read ops that are added by py_reader
for cpp inference lib. Default True
export_for_deployment(bool): If True, programs are modified to only support
direct inference deployment. Otherwise,
more information will be stored for flexible
optimization and re-training. Currently, only
True is supported.
Returns:
None
@@ -636,21 +640,28 @@ def save_inference_model(dirname,
if not os.path.isdir(dirname):
os.makedirs(dirname)
# Clear the is_target information and remove the existed feed and fetch op
global_block = copy_program.global_block()
for i, op in enumerate(global_block.ops):
op.desc.set_is_target(False)
if op.type == "feed" or op.type == "fetch":
global_block._remove_op(i)
copy_program.desc.flush()
pruned_program = copy_program._prune(targets=target_vars)
inference_program = pruned_program._inference_optimize(
export_for_deployment=export_for_deployment)
fetch_var_names = [v.name for v in target_vars]
prepend_feed_ops(inference_program, feeded_var_names)
append_fetch_ops(inference_program, fetch_var_names)
# When export_for_deployment is true, we modify the program online so that
# it can only be loaded for inference directly. If it's false, the whole
# original program and related meta are saved so that future usage can be
# more flexible.
if export_for_deployment:
global_block = copy_program.global_block()
for i, op in enumerate(global_block.ops):
op.desc.set_is_target(False)
if op.type == "feed" or op.type == "fetch":
global_block._remove_op(i)
copy_program.desc.flush()
pruned_program = copy_program._prune(targets=target_vars)
saved_program = pruned_program._inference_optimize(prune_read_op=True)
fetch_var_names = [v.name for v in target_vars]
prepend_feed_ops(saved_program, feeded_var_names)
append_fetch_ops(saved_program, fetch_var_names)
else:
# TODO(panyx0718): Save more information so that it can also be used
# for training and more flexible post-processing.
saved_program = copy_program
if model_filename is not None:
model_filename = os.path.basename(model_filename)
@@ -662,9 +673,9 @@ def save_inference_model(dirname,
params_filename = os.path.basename(params_filename)
with open(model_filename, "wb") as f:
f.write(inference_program.desc.serialize_to_string())
f.write(saved_program.desc.serialize_to_string())
save_persistables(executor, dirname, inference_program, params_filename)
save_persistables(executor, dirname, saved_program, params_filename)
# if there is lookup table, the trainer 0 will notify all pserver to save.
if main_program._is_distributed and main_program._is_chief and main_program._distributed_lookup_table:
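For completeness, a hedged sketch of reading the saved artifacts back with the long-standing loader API (`fluid.io.load_inference_model`); the directory name matches the hypothetical one used in the earlier sketch:

```python
import paddle.fluid as fluid

exe = fluid.Executor(fluid.CPUPlace())

# Load the deployment-oriented copy saved with export_for_deployment=True.
[program, feed_names, fetch_targets] = fluid.io.load_inference_model(
    dirname='./deploy_model', executor=exe)

# The saved program already contains feed/fetch ops, so it can be run directly:
# results = exe.run(program,
#                   feed={feed_names[0]: some_numpy_batch},
#                   fetch_list=fetch_targets)
```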
@@ -122,7 +122,7 @@ class TestProgram(unittest.TestCase):
net()
no_read_program = main_program._inference_optimize()
keep_read_program = main_program._inference_optimize(
export_for_deployment=False)
prune_read_op=False)
no_read_ops = no_read_program.global_block().ops
keep_read_ops = keep_read_program.global_block().ops
self.assertEqual(len(keep_read_ops) - len(no_read_ops), 2)