diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py
index 6e5f7fd035acfeab975f63b0794829d57f9bb239..912721f49decea4d15627efe159908b0bb5373b8 100644
--- a/python/paddle/fluid/io.py
+++ b/python/paddle/fluid/io.py
@@ -1161,15 +1161,7 @@ def append_fetch_ops(inference_program,
 
 
 @dygraph_not_support
-def save_inference_model(dirname,
-                         feeded_var_names,
-                         target_vars,
-                         executor,
-                         main_program=None,
-                         model_filename=None,
-                         params_filename=None,
-                         export_for_deployment=True,
-                         program_only=False):
+def save_inference_model(dirname, feed_vars, fetch_vars, executor):
     """
     :api_attr: Static Graph
 
@@ -1180,48 +1172,26 @@ def save_inference_model(dirname,
     for more details.
 
     Note:
-        The :code:`dirname` is used to specify the folder where inference model
-        structure and parameters are going to be saved. If you would like to save params of
-        Program in separate files, set `params_filename` None; if you would like to save all
-        params of Program in a single file, use `params_filename` to specify the file name.
+        The :code:`dirname` is used to specify the folder where the inference model
+        structure and parameters are going to be saved. The computation graph is stored
+        in ``model.pdinfer`` and the weights are stored in ``params.pdinfer``. Please do
+        not rename these two files, otherwise the model cannot be loaded.
 
     Args:
         dirname(str): The directory path to save the inference model.
-        feeded_var_names(list[str]): list of string. Names of variables that need to be fed
+        feed_vars(list[Variable]): list of Variable. Variables that need to be fed
                                      data during inference.
-        target_vars(list[Variable]): list of Variable. Variables from which we can get
+        fetch_vars(list[Variable]): list of Variable. Variables from which we can get
                                      inference results.
         executor(Executor): The executor that saves the inference model. You can refer
                             to :ref:`api_guide_executor_en` for more details.
-        main_program(Program, optional): The original program, which will be pruned to
-                                         build the inference model. If is set None,
-                                         the global default :code:`_main_program_` will be used.
-                                         Default: None.
-        model_filename(str, optional): The name of file to save the inference program
-                                       itself. If is set None, a default filename
-                                       :code:`__model__` will be used.
-        params_filename(str, optional): The name of file to save all related parameters.
-                                        If it is set None, parameters will be saved
-                                        in separate files .
-        export_for_deployment(bool): If True, programs are modified to only support
-                                     direct inference deployment. Otherwise,
-                                     more information will be stored for flexible
-                                     optimization and re-training. Currently, only
-                                     True is supported.
-                                     Default: True.
-        program_only(bool, optional): If True, It will save inference program only, and do not
-                                      save params of Program.
-                                      Default: False.
 
     Returns:
-        The fetch variables' name list
-
-        Return Type:
-          list
+        None
 
     Raises:
-        ValueError: If `feed_var_names` is not a list of basestring, an exception is thrown.
-        ValueError: If `target_vars` is not a list of Variable, an exception is thrown.
+        ValueError: If `feed_vars` is not a list of Variable, an exception is thrown.
+        ValueError: If `fetch_vars` is not a list of Variable, an exception is thrown.
 
     Examples:
         .. code-block:: python
 
@@ -1246,35 +1216,28 @@ def save_inference_model(dirname,
 
             # Save inference model. Note we don't save label and loss in this example
            fluid.io.save_inference_model(dirname=path,
-                                          feeded_var_names=['img'],
-                                          target_vars=[predict],
+                                          feed_vars=[image],
+                                          fetch_vars=[predict],
                                           executor=exe)
 
-            # In this example, the save_inference_mode inference will prune the default
+            # In this example, save_inference_model will prune the default
             # main program according to the network's input node (img) and output node(predict).
-            # The pruned inference program is going to be saved in the "./infer_model/__model__"
-            # and parameters are going to be saved in separate files under folder
-            # "./infer_model".
+            # The pruned inference program is going to be saved in the "./infer_model/model.pdinfer"
+            # and parameters are going to be saved in the "./infer_model/params.pdinfer".
     """
-    if isinstance(feeded_var_names, six.string_types):
-        feeded_var_names = [feeded_var_names]
-    elif export_for_deployment:
-        if len(feeded_var_names) > 0:
-            # TODO(paddle-dev): polish these code blocks
-            if not (bool(feeded_var_names) and all(
-                    isinstance(name, six.string_types)
-                    for name in feeded_var_names)):
-                raise ValueError("'feed_var_names' should be a list of str.")
-
-    if isinstance(target_vars, Variable):
-        target_vars = [target_vars]
-    elif export_for_deployment:
-        if not (bool(target_vars) and
-                all(isinstance(var, Variable) for var in target_vars)):
-            raise ValueError("'target_vars' should be a list of Variable.")
+    if not (bool(feed_vars) and
+            all(isinstance(var, Variable) for var in feed_vars)):
+        raise ValueError("'feed_vars' should be a list of Variable.")
 
-    main_program = _get_valid_program(main_program)
+    if not (bool(fetch_vars) and
+            all(isinstance(var, Variable) for var in fetch_vars)):
+        raise ValueError("'fetch_vars' should be a list of Variable.")
+
+    feeded_var_names = [var.name for var in feed_vars]
+    target_vars = fetch_vars
+
+    main_program = _get_valid_program(None)
 
     # remind user to set auc_states to zeros if the program contains auc op
     all_ops = main_program.global_block().ops
@@ -1310,72 +1273,46 @@ def save_inference_model(dirname,
         if e.errno != errno.EEXIST:
             raise
 
-    if model_filename is not None:
-        model_basename = os.path.basename(model_filename)
-    else:
-        model_basename = "__model__"
+    model_basename = "model.pdinfer"
     model_basename = os.path.join(save_dirname, model_basename)
 
-    # When export_for_deployment is true, we modify the program online so that
-    # it can only be loaded for inference directly. If it's false, the whole
-    # original program and related meta are saved so that future usage can be
-    # more flexible.
     origin_program = main_program.clone()
-    if export_for_deployment:
-        main_program = main_program.clone()
-        global_block = main_program.global_block()
-        need_to_remove_op_index = []
-        for i, op in enumerate(global_block.ops):
-            op.desc.set_is_target(False)
-            if op.type == "feed" or op.type == "fetch":
-                need_to_remove_op_index.append(i)
+    main_program = main_program.clone()
+    global_block = main_program.global_block()
+    need_to_remove_op_index = []
+    for i, op in enumerate(global_block.ops):
+        op.desc.set_is_target(False)
+        if op.type == "feed" or op.type == "fetch":
+            need_to_remove_op_index.append(i)
 
-        for index in need_to_remove_op_index[::-1]:
-            global_block._remove_op(index)
+    for index in need_to_remove_op_index[::-1]:
+        global_block._remove_op(index)
 
-        main_program.desc.flush()
+    main_program.desc.flush()
 
-        main_program = main_program._prune_with_input(
-            feeded_var_names=feeded_var_names, targets=target_vars)
-        main_program = main_program._inference_optimize(prune_read_op=True)
-        fetch_var_names = [v.name for v in target_vars]
+    main_program = main_program._prune_with_input(
+        feeded_var_names=feeded_var_names, targets=target_vars)
+    main_program = main_program._inference_optimize(prune_read_op=True)
+    fetch_var_names = [v.name for v in target_vars]
 
-        prepend_feed_ops(main_program, feeded_var_names)
-        append_fetch_ops(main_program, fetch_var_names)
+    prepend_feed_ops(main_program, feeded_var_names)
+    append_fetch_ops(main_program, fetch_var_names)
 
-        main_program.desc._set_version()
-        paddle.fluid.core.save_op_compatible_info(main_program.desc)
-        with open(model_basename, "wb") as f:
-            f.write(main_program.desc.serialize_to_string())
-    else:
-        # TODO(panyx0718): Save more information so that it can also be used
-        # for training and more flexible post-processing.
-        with open(model_basename + ".main_program", "wb") as f:
-            f.write(main_program.desc.serialize_to_string())
-
-    if program_only:
-        warnings.warn(
-            "save_inference_model specified the param `program_only` to True, It will not save params of Program."
-        )
-        return target_var_name_list
+    main_program.desc._set_version()
+    paddle.fluid.core.save_op_compatible_info(main_program.desc)
+    with open(model_basename, "wb") as f:
+        f.write(main_program.desc.serialize_to_string())
 
     main_program._copy_dist_param_info_from(origin_program)
 
-    if params_filename is not None:
-        params_filename = os.path.basename(params_filename)
+    params_filename = os.path.basename('params.pdinfer')
 
     save_persistables(executor, save_dirname, main_program, params_filename)
-    return target_var_name_list
 
 
 @dygraph_not_support
-def load_inference_model(dirname,
-                         executor,
-                         model_filename=None,
-                         params_filename=None,
-                         pserver_endpoints=None):
+def load_inference_model(dirname, executor):
     """
     :api_attr: Static Graph
 
@@ -1385,34 +1322,16 @@ def load_inference_model(dirname,
     You can refer to :ref:`api_guide_model_save_reader_en` for more details.
 
     Args:
-        dirname(str): One of the following:
-                      - The given directory path.
-                      - Set to None when reading the model from memory.
+        dirname(str): The given directory path. This directory must contain two files,
+                      `model.pdinfer` and `params.pdinfer`.
         executor(Executor): The executor to run for loading inference model.
                             See :ref:`api_guide_executor_en` for more details about it.
-        model_filename(str, optional): One of the following:
-                                       - The name of file to load the inference program.
-                                       - If it is None, the default filename ``__model__`` will be used.
-                                       - When ``dirname`` is ``None``, it must be set to a string containing model.
-                                       Default: ``None``.
-        params_filename(str, optional): It is only used for the case that all
-                                        parameters were saved in a single binary file. One of the following:
-                                        - The name of file to load all parameters.
-                                        - When ``dirname`` is ``None``, it must be set to a string containing all the parameters.
-                                        - If parameters were saved in separate files, set it as ``None``.
-                                        Default: ``None``.
-
-        pserver_endpoints(list, optional): It is only needed by the distributed inference.
-                                           If using a distributed look up table during the training,
-                                           this table is also needed by the inference process. Its value is
-                                           a list of pserver endpoints.
-
     Returns:
         list: The return of this API is a list with three elements:
-        (program, feed_target_names, fetch_targets). The `program` is a
+        (program, feed_var_names, fetch_vars). The `program` is a
         ``Program`` (refer to :ref:`api_guide_Program_en`), which is used for inference.
-        The `feed_target_names` is a list of ``str``, which contains names of variables
-        that need to feed data in the inference program. The `fetch_targets` is a list of
+        The `feed_var_names` is a list of ``str``, which contains names of variables
+        that need to feed data in the inference program. The `fetch_vars` is a list of
         ``Variable`` (refer to :ref:`api_guide_Program_en`). It contains variables from which
         we can get inference results.
 
@@ -1440,60 +1359,48 @@ def load_inference_model(dirname,
 
             # Save the inference model
             path = "./infer_model"
-            fluid.io.save_inference_model(dirname=path, feeded_var_names=['img'],
-                         target_vars=[hidden_b], executor=exe, main_program=main_prog)
+            fluid.io.save_inference_model(dirname=path, feed_vars=[data],
+                         fetch_vars=[hidden_b], executor=exe)
 
             # Demo one. Not need to set the distributed look up table, because the
             # training doesn't use a distributed look up table.
-            [inference_program, feed_target_names, fetch_targets] = (
+            [inference_program, feed_var_names, fetch_vars] = (
                 fluid.io.load_inference_model(dirname=path, executor=exe))
             tensor_img = np.array(np.random.random((1, 64, 784)), dtype=np.float32)
             results = exe.run(inference_program,
-                              feed={feed_target_names[0]: tensor_img},
-                              fetch_list=fetch_targets)
+                              feed={feed_var_names[0]: tensor_img},
+                              fetch_list=fetch_vars)
 
             # Demo two. If the training uses a distributed look up table, the pserver
             # endpoints list should be supported when loading the inference model.
             # The below is just an example.
             endpoints = ["127.0.0.1:2023","127.0.0.1:2024"]
-            [dist_inference_program, dist_feed_target_names, dist_fetch_targets] = (
+            [dist_inference_program, dist_feed_var_names, dist_fetch_vars] = (
                 fluid.io.load_inference_model(dirname=path,
-                                              executor=exe,
-                                              pserver_endpoints=endpoints))
+                                              executor=exe))
+            fluid.io.endpoints_replacement(dist_inference_program, endpoints)
 
             # In this example, the inference program was saved in the file
-            # "./infer_model/__model__" and parameters were saved in
-            # separate files under the directory "./infer_model".
-            # By the inference program, feed_target_names and
-            # fetch_targets, we can use an executor to run the inference
-            # program for getting the inference result.
+            # "./infer_model/model.pdinfer" and parameters were saved in
+            # the "./infer_model/params.pdinfer". By the inference program,
+            # feed_var_names and fetch_vars, we can use an executor to run
+            # the inference program for getting the inference result.
""" - load_from_memory = False - if dirname is not None: - load_dirname = os.path.normpath(dirname) - if not os.path.isdir(load_dirname): - raise ValueError("There is no directory named '%s'" % dirname) - - if model_filename is None: - model_filename = '__model__' + load_dirname = os.path.normpath(dirname) + if not os.path.isdir(load_dirname): + raise ValueError("There is no directory named '%s'" % dirname) - model_filename = os.path.join(load_dirname, - os.path.basename(model_filename)) + model_abs_path = os.path.join(load_dirname, 'model.pdinfer') + params_filename = 'params.pdinfer' + params_abs_path = os.path.join(load_dirname, params_filename) - if params_filename is not None: - params_filename = os.path.basename(params_filename) + if not (os.path.exists(model_abs_path) and os.path.exists(params_abs_path)): + raise ValueError( + "Please check model.pdinfer and params.infer exist in '%s'" % + dirname) - with open(model_filename, "rb") as f: - program_desc_str = f.read() - else: - load_from_memory = True - if params_filename is None: - raise ValueError( - "The path of params cannot be None when the directory path is None." - ) - load_dirname = dirname - program_desc_str = model_filename - params_filename = params_filename + with open(model_abs_path, "rb") as f: + program_desc_str = f.read() program = Program.parse_from_string(program_desc_str) if not core._is_program_version_supported(program._version()): @@ -1502,19 +1409,14 @@ def load_inference_model(dirname, # Binary data also need versioning. load_persistables(executor, load_dirname, program, params_filename) - if pserver_endpoints: - program = _endpoints_replacement(program, pserver_endpoints) - - feed_target_names = program.desc.get_feed_target_names() - fetch_target_names = program.desc.get_fetch_target_names() - fetch_targets = [ - program.global_block().var(name) for name in fetch_target_names - ] + feed_var_names = program.desc.get_feed_target_names() + fetch_var_names = program.desc.get_fetch_target_names() + fetch_vars = [program.global_block().var(name) for name in fetch_var_names] - return [program, feed_target_names, fetch_targets] + return [program, feed_var_names, fetch_vars] -def _endpoints_replacement(program, endpoints): +def endpoints_replacement(program, endpoints): ENDPOINT_MAP = "epmap" for op in program.global_block().ops: if op.has_attr(ENDPOINT_MAP):