Commit ae9d4958 authored by Shixiaowei02

upgrade io.inference_model interfaces, test=develop

Parent 39546aa2
@@ -1161,15 +1161,7 @@ def append_fetch_ops(inference_program,
@dygraph_not_support
def save_inference_model(dirname,
feeded_var_names,
target_vars,
executor,
main_program=None,
model_filename=None,
params_filename=None,
export_for_deployment=True,
program_only=False):
def save_inference_model(dirname, feed_vars, fetch_vars, executor):
"""
:api_attr: Static Graph
@@ -1180,48 +1172,26 @@ def save_inference_model(dirname,
for more details.
Note:
The :code:`dirname` is used to specify the folder where inference model
structure and parameters are going to be saved. If you would like to save params of
Program in separate files, set `params_filename` None; if you would like to save all
params of Program in a single file, use `params_filename` to specify the file name.
The :code:`dirname` is used to specify the folder where inference model structure
and parameters are going to be saved. The calculation graph and weights of the model
will be stored in the model.pdinfer and params.pdinfer files respectively. Please do
not modify the default names of these two files to avoid affecting model loading.
Args:
dirname(str): The directory path to save the inference model.
feeded_var_names(list[str]): list of string. Names of variables that need to be fed
feed_vars(list[Variable]): list of Variable. Variables that need to be fed
data during inference.
target_vars(list[Variable]): list of Variable. Variables from which we can get
fetch_vars(list[Variable]): list of Variable. Variables from which we can get
inference results.
executor(Executor): The executor that saves the inference model. You can refer
to :ref:`api_guide_executor_en` for more details.
main_program(Program, optional): The original program, which will be pruned to
build the inference model. If it is set None,
the global default :code:`_main_program_` will be used.
Default: None.
model_filename(str, optional): The name of file to save the inference program
itself. If it is set None, a default filename
:code:`__model__` will be used.
params_filename(str, optional): The name of file to save all related parameters.
If it is set None, parameters will be saved
in separate files .
export_for_deployment(bool): If True, programs are modified to only support
direct inference deployment. Otherwise,
more information will be stored for flexible
optimization and re-training. Currently, only
True is supported.
Default: True.
program_only(bool, optional): If True, It will save inference program only, and do not
save params of Program.
Default: False.
Returns:
The fetch variables' name list
Return Type:
list
None
Raises:
ValueError: If `feed_var_names` is not a list of basestring, an exception is thrown.
ValueError: If `target_vars` is not a list of Variable, an exception is thrown.
ValueError: If `feed_vars` is not a list of Variable, an exception is thrown.
ValueError: If `fetch_vars` is not a list of Variable, an exception is thrown.
Examples:
.. code-block:: python
@@ -1246,35 +1216,28 @@ def save_inference_model(dirname,
# Save inference model. Note we don't save label and loss in this example
fluid.io.save_inference_model(dirname=path,
feeded_var_names=['img'],
target_vars=[predict],
feed_vars=[image],
fetch_vars=[predict],
executor=exe)
# In this example, the save_inference_mode inference will prune the default
# In this example, the save_inference_model inference will prune the default
# main program according to the network's input node (img) and output node(predict).
# The pruned inference program is going to be saved in the "./infer_model/__model__"
# and parameters are going to be saved in separate files under folder
# "./infer_model".
# The pruned inference program is going to be saved in the "./infer_model/model.pdinfer"
# and parameters are going to be saved in the "./infer_model/params.pdinfer".
"""
if isinstance(feeded_var_names, six.string_types):
feeded_var_names = [feeded_var_names]
elif export_for_deployment:
if len(feeded_var_names) > 0:
# TODO(paddle-dev): polish these code blocks
if not (bool(feeded_var_names) and all(
isinstance(name, six.string_types)
for name in feeded_var_names)):
raise ValueError("'feed_var_names' should be a list of str.")
if isinstance(target_vars, Variable):
target_vars = [target_vars]
elif export_for_deployment:
if not (bool(target_vars) and
all(isinstance(var, Variable) for var in target_vars)):
raise ValueError("'target_vars' should be a list of Variable.")
if not (bool(feed_vars) and
all(isinstance(var, Variable) for var in feed_vars)):
raise ValueError("'feed_vars' should be a list of Variable.")
main_program = _get_valid_program(main_program)
if not (bool(fetch_vars) and
all(isinstance(var, Variable) for var in fetch_vars)):
raise ValueError("'fetch_vars' should be a list of Variable.")
feeded_var_names = [var.name for var in feed_vars]
target_vars = fetch_vars
main_program = _get_valid_program(None)
# remind user to set auc_states to zeros if the program contains auc op
all_ops = main_program.global_block().ops
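For orientation, here is a minimal migration sketch for the new signature introduced in this hunk. It is illustrative only: the tiny fc network, the `./infer_model` path, and the names `image`/`predict` are assumptions, not part of this change.

```python
import paddle.fluid as fluid

# A tiny network so the snippet is self-contained.
image = fluid.data(name='img', shape=[None, 784], dtype='float32')
predict = fluid.layers.fc(input=image, size=10, act='softmax')

exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())

path = "./infer_model"

# Old interface: feeds were passed by name, and a pruned main_program,
# custom file names, etc. could be supplied.
# fluid.io.save_inference_model(dirname=path, feeded_var_names=['img'],
#                               target_vars=[predict], executor=exe)

# New interface: feeds and fetches are both Variable lists; the global default
# main program is always used and the file names are fixed to
# model.pdinfer / params.pdinfer.
fluid.io.save_inference_model(dirname=path,
                              feed_vars=[image],
                              fetch_vars=[predict],
                              executor=exe)
```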
@@ -1310,72 +1273,46 @@ def save_inference_model(dirname,
if e.errno != errno.EEXIST:
raise
if model_filename is not None:
model_basename = os.path.basename(model_filename)
else:
model_basename = "__model__"
model_basename = "model.pdinfer"
model_basename = os.path.join(save_dirname, model_basename)
# When export_for_deployment is true, we modify the program online so that
# it can only be loaded for inference directly. If it's false, the whole
# original program and related meta are saved so that future usage can be
# more flexible.
origin_program = main_program.clone()
if export_for_deployment:
main_program = main_program.clone()
global_block = main_program.global_block()
need_to_remove_op_index = []
for i, op in enumerate(global_block.ops):
op.desc.set_is_target(False)
if op.type == "feed" or op.type == "fetch":
need_to_remove_op_index.append(i)
main_program = main_program.clone()
global_block = main_program.global_block()
need_to_remove_op_index = []
for i, op in enumerate(global_block.ops):
op.desc.set_is_target(False)
if op.type == "feed" or op.type == "fetch":
need_to_remove_op_index.append(i)
for index in need_to_remove_op_index[::-1]:
global_block._remove_op(index)
for index in need_to_remove_op_index[::-1]:
global_block._remove_op(index)
main_program.desc.flush()
main_program.desc.flush()
main_program = main_program._prune_with_input(
feeded_var_names=feeded_var_names, targets=target_vars)
main_program = main_program._inference_optimize(prune_read_op=True)
fetch_var_names = [v.name for v in target_vars]
main_program = main_program._prune_with_input(
feeded_var_names=feeded_var_names, targets=target_vars)
main_program = main_program._inference_optimize(prune_read_op=True)
fetch_var_names = [v.name for v in target_vars]
prepend_feed_ops(main_program, feeded_var_names)
append_fetch_ops(main_program, fetch_var_names)
prepend_feed_ops(main_program, feeded_var_names)
append_fetch_ops(main_program, fetch_var_names)
main_program.desc._set_version()
paddle.fluid.core.save_op_compatible_info(main_program.desc)
with open(model_basename, "wb") as f:
f.write(main_program.desc.serialize_to_string())
else:
# TODO(panyx0718): Save more information so that it can also be used
# for training and more flexible post-processing.
with open(model_basename + ".main_program", "wb") as f:
f.write(main_program.desc.serialize_to_string())
if program_only:
warnings.warn(
"save_inference_model specified the param `program_only` to True, It will not save params of Program."
)
return target_var_name_list
main_program.desc._set_version()
paddle.fluid.core.save_op_compatible_info(main_program.desc)
with open(model_basename, "wb") as f:
f.write(main_program.desc.serialize_to_string())
main_program._copy_dist_param_info_from(origin_program)
if params_filename is not None:
params_filename = os.path.basename(params_filename)
params_filename = os.path.basename('params.pdinfer')
save_persistables(executor, save_dirname, main_program, params_filename)
return target_var_name_list
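As a quick way to see what the new save path writes to disk, a small check sketch; the helper name and the error message are made up for illustration, and only the two fixed file names come from this change.

```python
import os

def check_inference_dir(dirname):
    # save_inference_model now always writes exactly these two files.
    model_path = os.path.join(dirname, "model.pdinfer")
    params_path = os.path.join(dirname, "params.pdinfer")
    for path in (model_path, params_path):
        if not os.path.exists(path):
            raise ValueError("missing inference artifact: %s" % path)
    return model_path, params_path

# Example: check_inference_dir("./infer_model")
```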
@dygraph_not_support
def load_inference_model(dirname,
executor,
model_filename=None,
params_filename=None,
pserver_endpoints=None):
def load_inference_model(dirname, executor):
"""
:api_attr: Static Graph
@@ -1385,34 +1322,16 @@ def load_inference_model(dirname,
You can refer to :ref:`api_guide_model_save_reader_en` for more details.
Args:
dirname(str): One of the following:
- The given directory path.
- Set to None when reading the model from memory.
dirname(str): The given directory path. This directory must contain two files,
`model.pdinfer` and `params.pdinfer`.
executor(Executor): The executor to run for loading inference model.
See :ref:`api_guide_executor_en` for more details about it.
model_filename(str, optional): One of the following:
- The name of file to load the inference program.
- If it is None, the default filename ``__model__`` will be used.
- When ``dirname`` is ``None``, it must be set to a string containing model.
Default: ``None``.
params_filename(str, optional): It is only used for the case that all
parameters were saved in a single binary file. One of the following:
- The name of file to load all parameters.
- When ``dirname`` is ``None``, it must be set to a string containing all the parameters.
- If parameters were saved in separate files, set it as ``None``.
Default: ``None``.
pserver_endpoints(list, optional): It is only needed by the distributed inference.
If using a distributed look up table during the training,
this table is also needed by the inference process. Its value is
a list of pserver endpoints.
Returns:
list: The return of this API is a list with three elements:
(program, feed_target_names, fetch_targets). The `program` is a
(program, feed_var_names, fetch_vars). The `program` is a
``Program`` (refer to :ref:`api_guide_Program_en`), which is used for inference.
The `feed_target_names` is a list of ``str``, which contains names of variables
that need to feed data in the inference program. The `fetch_targets` is a list of
The `feed_var_names` is a list of ``str``, which contains names of variables
that need to feed data in the inference program. The `fetch_vars` is a list of
``Variable`` (refer to :ref:`api_guide_Program_en`). It contains variables from which
we can get inference results.
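A short consumption sketch for the triple described above, assuming a model saved with the earlier save sketch (a single 784-wide input named 'img'); the input shape and directory path are assumptions.

```python
import numpy as np
import paddle.fluid as fluid

exe = fluid.Executor(fluid.CPUPlace())
program, feed_var_names, fetch_vars = fluid.io.load_inference_model(
    dirname="./infer_model", executor=exe)

# feed_var_names is recovered from the feed targets recorded in the program
# desc, so it holds the names of the Variables passed as feed_vars at save time.
tensor_img = np.random.random((1, 784)).astype('float32')
results = exe.run(program,
                  feed={feed_var_names[0]: tensor_img},
                  fetch_list=fetch_vars)
```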
@@ -1440,60 +1359,48 @@ def load_inference_model(dirname,
# Save the inference model
path = "./infer_model"
fluid.io.save_inference_model(dirname=path, feeded_var_names=['img'],
target_vars=[hidden_b], executor=exe, main_program=main_prog)
fluid.io.save_inference_model(dirname=path, feed_vars=[data],
fetch_vars=[hidden_b], executor=exe)
# Demo one. No need to set the distributed look up table, because the
# training doesn't use a distributed look up table.
[inference_program, feed_target_names, fetch_targets] = (
[inference_program, feed_var_names, fetch_vars] = (
fluid.io.load_inference_model(dirname=path, executor=exe))
tensor_img = np.array(np.random.random((1, 64, 784)), dtype=np.float32)
results = exe.run(inference_program,
feed={feed_target_names[0]: tensor_img},
fetch_list=fetch_targets)
feed={feed_var_names[0]: tensor_img},
fetch_list=fetch_vars)
# Demo two. If the training uses a distributed look up table, the pserver
# endpoints list should be supplied when loading the inference model.
# The below is just an example.
endpoints = ["127.0.0.1:2023","127.0.0.1:2024"]
[dist_inference_program, dist_feed_target_names, dist_fetch_targets] = (
[dist_inference_program, dist_feed_var_names, dist_fetch_vars] = (
fluid.io.load_inference_model(dirname=path,
executor=exe,
pserver_endpoints=endpoints))
executor=exe)
fluid.io.endpoints_replacement(dist_inference_program, endpoints)
# In this example, the inference program was saved in the file
# "./infer_model/__model__" and parameters were saved in
# separate files under the directory "./infer_model".
# By the inference program, feed_target_names and
# fetch_targets, we can use an executor to run the inference
# program for getting the inference result.
# "./infer_model/model.pdinfer" and parameters were saved in
# the "./infer_model/params.pdinfer". By the inference program,
# feed_var_names and fetch_vars, we can use an executor to run
# the inference program for getting the inference result.
"""
load_from_memory = False
if dirname is not None:
load_dirname = os.path.normpath(dirname)
if not os.path.isdir(load_dirname):
raise ValueError("There is no directory named '%s'" % dirname)
if model_filename is None:
model_filename = '__model__'
load_dirname = os.path.normpath(dirname)
if not os.path.isdir(load_dirname):
raise ValueError("There is no directory named '%s'" % dirname)
model_filename = os.path.join(load_dirname,
os.path.basename(model_filename))
model_abs_path = os.path.join(load_dirname, 'model.pdinfer')
params_filename = 'params.pdinfer'
params_abs_path = os.path.join(load_dirname, params_filename)
if params_filename is not None:
params_filename = os.path.basename(params_filename)
if not (os.path.exists(model_abs_path) and os.path.exists(params_abs_path)):
raise ValueError(
"Please check model.pdinfer and params.infer exist in '%s'" %
dirname)
with open(model_filename, "rb") as f:
program_desc_str = f.read()
else:
load_from_memory = True
if params_filename is None:
raise ValueError(
"The path of params cannot be None when the directory path is None."
)
load_dirname = dirname
program_desc_str = model_filename
params_filename = params_filename
with open(model_abs_path, "rb") as f:
program_desc_str = f.read()
program = Program.parse_from_string(program_desc_str)
if not core._is_program_version_supported(program._version()):
@@ -1502,19 +1409,14 @@ def load_inference_model(dirname,
# Binary data also need versioning.
load_persistables(executor, load_dirname, program, params_filename)
if pserver_endpoints:
program = _endpoints_replacement(program, pserver_endpoints)
feed_target_names = program.desc.get_feed_target_names()
fetch_target_names = program.desc.get_fetch_target_names()
fetch_targets = [
program.global_block().var(name) for name in fetch_target_names
]
feed_var_names = program.desc.get_feed_target_names()
fetch_var_names = program.desc.get_fetch_target_names()
fetch_vars = [program.global_block().var(name) for name in fetch_var_names]
return [program, feed_target_names, fetch_targets]
return [program, feed_var_names, fetch_vars]
def _endpoints_replacement(program, endpoints):
def endpoints_replacement(program, endpoints):
ENDPOINT_MAP = "epmap"
for op in program.global_block().ops:
if op.has_attr(ENDPOINT_MAP):
......
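Since the pserver_endpoints argument was removed from load_inference_model, distributed callers now apply the newly public endpoints_replacement themselves, as in the docstring example above. A sketch, with the endpoint list assumed to come from the training job:

```python
import paddle.fluid as fluid

exe = fluid.Executor(fluid.CPUPlace())

# Endpoints of the pservers that hosted the distributed look up table at
# training time (placeholder values).
endpoints = ["127.0.0.1:2023", "127.0.0.1:2024"]

program, feed_var_names, fetch_vars = fluid.io.load_inference_model(
    dirname="./infer_model", executor=exe)

# Rewrites the "epmap" attribute of the relevant ops so the loaded program
# talks to the pserver endpoints.
fluid.io.endpoints_replacement(program, endpoints)
```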