diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py
index a5a3a70828abf87594dfa2b90f6bf6dab6b9fe8b..0abbb6815123f8ba65b637b3f3accef91fe66ef8 100644
--- a/python/paddle/fluid/framework.py
+++ b/python/paddle/fluid/framework.py
@@ -1647,7 +1647,7 @@ class Program(object):
         The two code snippets above will generate same programs.
         """
         if for_test:
-            p = self._inference_optimize(export_for_deployment=False)
+            p = self._inference_optimize(prune_read_op=False)
         else:
             p = Program()
             p.current_block_idx = self.current_block_idx
@@ -1717,7 +1717,7 @@ class Program(object):
         res._sync_with_cpp()
         return res
 
-    def _inference_optimize(self, export_for_deployment=True):
+    def _inference_optimize(self, prune_read_op=True):
         """
         This method will create a new program and do following adjustments on it:
         1. Remove all reader variables and their creator ops if exist.
@@ -1729,8 +1729,8 @@ class Program(object):
         information will be lost.
 
         Args:
-            export_for_deployment(bool): remove the read ops that are added by py_reader
-                for cpp inference library
+            prune_read_op(bool): remove the read ops that are added by py_reader
+                for cpp inference library
 
         Notes: This API is a very low level API. Use
         :code:`Program.clone(for_test=True)` instead.
@@ -1744,7 +1744,7 @@ class Program(object):
         # remove all readers and the read_op if exist
         read_op_idx = 0
         root_block = res.desc.block(0)
-        if export_for_deployment:
+        if prune_read_op:
             while True:
                 if read_op_idx >= root_block.op_size() or root_block.op(
                         read_op_idx).type() == 'read':
diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py
index af653970418275548f0810afcf1dae173d9cb171..78bb8a1a0a64631cbe2adc11b1494ceed6d14908 100644
--- a/python/paddle/fluid/io.py
+++ b/python/paddle/fluid/io.py
@@ -20,6 +20,7 @@ import time
 import shutil
 import six
 
+from paddle.fluid.executor import Executor
 from paddle.fluid.evaluator import Evaluator
 from paddle.fluid.framework import Program, Parameter, default_main_program, default_startup_program, Variable
 from . import core
@@ -587,8 +588,11 @@ def save_inference_model(dirname,
        params_filename(str|None): The name of file to save all related parameters.
                                   If it is setted None, parameters will be saved
                                   in separate files .
-        export_for_deployment(bool): remove the read ops that are added by py_reader
-                                     for cpp inference lib. Default True
+        export_for_deployment(bool): If True, programs are modified to only support
+                                     direct inference deployment. Otherwise,
+                                     more information will be stored for flexible
+                                     optimization and re-training. Currently, only
+                                     True is supported.
 
     Returns:
         None
@@ -636,21 +640,28 @@ def save_inference_model(dirname,
     if not os.path.isdir(dirname):
         os.makedirs(dirname)
 
-    # Clear the is_target information and remove the existed feed and fetch op
-    global_block = copy_program.global_block()
-    for i, op in enumerate(global_block.ops):
-        op.desc.set_is_target(False)
-        if op.type == "feed" or op.type == "fetch":
-            global_block._remove_op(i)
-    copy_program.desc.flush()
-
-    pruned_program = copy_program._prune(targets=target_vars)
-    inference_program = pruned_program._inference_optimize(
-        export_for_deployment=export_for_deployment)
-    fetch_var_names = [v.name for v in target_vars]
-
-    prepend_feed_ops(inference_program, feeded_var_names)
-    append_fetch_ops(inference_program, fetch_var_names)
+    # When export_for_deployment is true, we modify the program online so that
+    # it can only be loaded for inference directly. If it's false, the whole
+    # original program and related meta are saved so that future usage can be
+    # more flexible.
+    if export_for_deployment:
+        global_block = copy_program.global_block()
+        for i, op in enumerate(global_block.ops):
+            op.desc.set_is_target(False)
+            if op.type == "feed" or op.type == "fetch":
+                global_block._remove_op(i)
+        copy_program.desc.flush()
+
+        pruned_program = copy_program._prune(targets=target_vars)
+        saved_program = pruned_program._inference_optimize(prune_read_op=True)
+        fetch_var_names = [v.name for v in target_vars]
+
+        prepend_feed_ops(saved_program, feeded_var_names)
+        append_fetch_ops(saved_program, fetch_var_names)
+    else:
+        # TODO(panyx0718): Save more information so that it can also be used
+        # for training and more flexible post-processing.
+        saved_program = copy_program
 
     if model_filename is not None:
         model_filename = os.path.basename(model_filename)
@@ -662,9 +673,9 @@ def save_inference_model(dirname,
         params_filename = os.path.basename(params_filename)
 
     with open(model_filename, "wb") as f:
-        f.write(inference_program.desc.serialize_to_string())
+        f.write(saved_program.desc.serialize_to_string())
 
-    save_persistables(executor, dirname, inference_program, params_filename)
+    save_persistables(executor, dirname, saved_program, params_filename)
 
     # if there is lookup table, the trainer 0 will notify all pserver to save.
     if main_program._is_distributed and main_program._is_chief and main_program._distributed_lookup_table:
diff --git a/python/paddle/fluid/tests/unittests/test_program.py b/python/paddle/fluid/tests/unittests/test_program.py
index 0b9fba5fe376474b084fd233ace41c9c0cd53547..cb1d94809b4ba99fa9077f99b93689504415b71d 100644
--- a/python/paddle/fluid/tests/unittests/test_program.py
+++ b/python/paddle/fluid/tests/unittests/test_program.py
@@ -122,7 +122,7 @@ class TestProgram(unittest.TestCase):
             net()
         no_read_program = main_program._inference_optimize()
         keep_read_program = main_program._inference_optimize(
-            export_for_deployment=False)
+            prune_read_op=False)
         no_read_ops = no_read_program.global_block().ops
         keep_read_ops = keep_read_program.global_block().ops
        self.assertEqual(len(keep_read_ops) - len(no_read_ops), 2)
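
For context, a minimal sketch of how the changed save_inference_model entry point would be called after this patch. The toy network, variable names, and output directory below are hypothetical and not part of the diff; only the export_for_deployment flag and the save_inference_model signature come from the patched code.

    # A minimal usage sketch, assuming a trained fluid program; the network and
    # names here are placeholders for illustration only.
    import paddle.fluid as fluid

    img = fluid.layers.data(name='img', shape=[1, 28, 28], dtype='float32')
    label = fluid.layers.data(name='label', shape=[1], dtype='int64')
    prediction = fluid.layers.fc(input=img, size=10, act='softmax')
    loss = fluid.layers.mean(
        fluid.layers.cross_entropy(input=prediction, label=label))

    exe = fluid.Executor(fluid.CPUPlace())
    exe.run(fluid.default_startup_program())

    # export_for_deployment=True (the default) keeps the existing behavior:
    # feed/fetch ops are inserted and read ops pruned, so the saved program can
    # only be loaded for direct inference. False is reserved for saving the full
    # program (see the TODO in the diff) and is not yet supported.
    fluid.io.save_inference_model(
        dirname="./inference_model",
        feeded_var_names=['img'],
        target_vars=[prediction],
        executor=exe,
        export_for_deployment=True)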