未验证 提交 4ceedec3 编写于 作者: S Shibo Tao 提交者: GitHub

enhance doc. add kwargs for backward compatibility. test=develop (#29143)

上级 28280647
...@@ -204,7 +204,7 @@ def is_persistable(var): ...@@ -204,7 +204,7 @@ def is_persistable(var):
@static_only @static_only
def serialize_program(feed_vars, fetch_vars): def serialize_program(feed_vars, fetch_vars, **kwargs):
""" """
:api_attr: Static Graph :api_attr: Static Graph
...@@ -213,6 +213,10 @@ def serialize_program(feed_vars, fetch_vars): ...@@ -213,6 +213,10 @@ def serialize_program(feed_vars, fetch_vars):
Args: Args:
feed_vars(Variable | list[Variable]): Variables needed by inference. feed_vars(Variable | list[Variable]): Variables needed by inference.
fetch_vars(Variable | list[Variable]): Variables returned by inference. fetch_vars(Variable | list[Variable]): Variables returned by inference.
kwargs: Supported keys include 'program'.
Note: kwargs is mainly used for backward compatibility.
- program(Program): specify a program if you don't want to use the default main program.
Returns: Returns:
bytes: serialized program. bytes: serialized program.
...@@ -235,7 +239,6 @@ def serialize_program(feed_vars, fetch_vars): ...@@ -235,7 +239,6 @@ def serialize_program(feed_vars, fetch_vars):
predict = paddle.static.nn.fc(image, 10, activation='softmax') predict = paddle.static.nn.fc(image, 10, activation='softmax')
loss = paddle.nn.functional.cross_entropy(predict, label) loss = paddle.nn.functional.cross_entropy(predict, label)
avg_loss = paddle.tensor.stat.mean(loss)
exe = paddle.static.Executor(paddle.CPUPlace()) exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(paddle.static.default_startup_program()) exe.run(paddle.static.default_startup_program())
...@@ -252,7 +255,7 @@ def serialize_program(feed_vars, fetch_vars): ...@@ -252,7 +255,7 @@ def serialize_program(feed_vars, fetch_vars):
# verify fetch_vars # verify fetch_vars
_check_vars('fetch_vars', fetch_vars) _check_vars('fetch_vars', fetch_vars)
program = _get_valid_program() program = _get_valid_program(kwargs.get('program', None))
program = _normalize_program(program, feed_vars, fetch_vars) program = _normalize_program(program, feed_vars, fetch_vars)
return _serialize_program(program) return _serialize_program(program)
...@@ -265,7 +268,7 @@ def _serialize_program(program): ...@@ -265,7 +268,7 @@ def _serialize_program(program):
@static_only @static_only
def serialize_persistables(feed_vars, fetch_vars, executor): def serialize_persistables(feed_vars, fetch_vars, executor, **kwargs):
""" """
:api_attr: Static Graph :api_attr: Static Graph
...@@ -274,6 +277,10 @@ def serialize_persistables(feed_vars, fetch_vars, executor): ...@@ -274,6 +277,10 @@ def serialize_persistables(feed_vars, fetch_vars, executor):
Args: Args:
feed_vars(Variable | list[Variable]): Variables needed by inference. feed_vars(Variable | list[Variable]): Variables needed by inference.
fetch_vars(Variable | list[Variable]): Variables returned by inference. fetch_vars(Variable | list[Variable]): Variables returned by inference.
kwargs: Supported keys include 'program'.
Note: kwargs is mainly used for backward compatibility.
- program(Program): specify a program if you don't want to use the default main program.
Returns: Returns:
bytes: serialized program. bytes: serialized program.
...@@ -296,7 +303,6 @@ def serialize_persistables(feed_vars, fetch_vars, executor): ...@@ -296,7 +303,6 @@ def serialize_persistables(feed_vars, fetch_vars, executor):
predict = paddle.static.nn.fc(image, 10, activation='softmax') predict = paddle.static.nn.fc(image, 10, activation='softmax')
loss = paddle.nn.functional.cross_entropy(predict, label) loss = paddle.nn.functional.cross_entropy(predict, label)
avg_loss = paddle.tensor.stat.mean(loss)
exe = paddle.static.Executor(paddle.CPUPlace()) exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(paddle.static.default_startup_program()) exe.run(paddle.static.default_startup_program())
...@@ -314,7 +320,7 @@ def serialize_persistables(feed_vars, fetch_vars, executor): ...@@ -314,7 +320,7 @@ def serialize_persistables(feed_vars, fetch_vars, executor):
# verify fetch_vars # verify fetch_vars
_check_vars('fetch_vars', fetch_vars) _check_vars('fetch_vars', fetch_vars)
program = _get_valid_program() program = _get_valid_program(kwargs.get('program', None))
program = _normalize_program(program, feed_vars, fetch_vars) program = _normalize_program(program, feed_vars, fetch_vars)
return _serialize_persistables(program, executor) return _serialize_persistables(program, executor)
...@@ -380,7 +386,8 @@ def save_to_file(path, content): ...@@ -380,7 +386,8 @@ def save_to_file(path, content):
@static_only @static_only
def save_inference_model(path_prefix, feed_vars, fetch_vars, executor): def save_inference_model(path_prefix, feed_vars, fetch_vars, executor,
**kwargs):
""" """
:api_attr: Static Graph :api_attr: Static Graph
...@@ -396,6 +403,9 @@ def save_inference_model(path_prefix, feed_vars, fetch_vars, executor): ...@@ -396,6 +403,9 @@ def save_inference_model(path_prefix, feed_vars, fetch_vars, executor):
fetch_vars(Variable | list[Variable]): Variables returned by inference. fetch_vars(Variable | list[Variable]): Variables returned by inference.
executor(Executor): The executor that saves the inference model. You can refer executor(Executor): The executor that saves the inference model. You can refer
to :ref:`api_guide_executor_en` for more details. to :ref:`api_guide_executor_en` for more details.
kwargs: Supported keys include 'program'.
Note: kwargs is mainly used for backward compatibility.
- program(Program): specify a program if you don't want to use the default main program.
Returns: Returns:
None None
...@@ -418,7 +428,6 @@ def save_inference_model(path_prefix, feed_vars, fetch_vars, executor): ...@@ -418,7 +428,6 @@ def save_inference_model(path_prefix, feed_vars, fetch_vars, executor):
predict = paddle.static.nn.fc(image, 10, activation='softmax') predict = paddle.static.nn.fc(image, 10, activation='softmax')
loss = paddle.nn.functional.cross_entropy(predict, label) loss = paddle.nn.functional.cross_entropy(predict, label)
avg_loss = paddle.tensor.stat.mean(loss)
exe = paddle.static.Executor(paddle.CPUPlace()) exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(paddle.static.default_startup_program()) exe.run(paddle.static.default_startup_program())
...@@ -456,7 +465,7 @@ def save_inference_model(path_prefix, feed_vars, fetch_vars, executor): ...@@ -456,7 +465,7 @@ def save_inference_model(path_prefix, feed_vars, fetch_vars, executor):
# verify fetch_vars # verify fetch_vars
_check_vars('fetch_vars', fetch_vars) _check_vars('fetch_vars', fetch_vars)
program = _get_valid_program() program = _get_valid_program(kwargs.get('program', None))
program = _normalize_program(program, feed_vars, fetch_vars) program = _normalize_program(program, feed_vars, fetch_vars)
# serialize and save program # serialize and save program
program_bytes = _serialize_program(program) program_bytes = _serialize_program(program)
...@@ -475,8 +484,35 @@ def deserialize_program(data): ...@@ -475,8 +484,35 @@ def deserialize_program(data):
Args: Args:
data(bytes): serialized program. data(bytes): serialized program.
Returns: Returns:
Program: deserialized program. Program: deserialized program.
Examples:
.. code-block:: python
import paddle
paddle.enable_static()
path_prefix = "./infer_model"
# User defined network, here a softmax regression example
image = paddle.static.data(name='img', shape=[None, 28, 28], dtype='float32')
label = paddle.static.data(name='label', shape=[None, 1], dtype='int64')
predict = paddle.static.nn.fc(image, 10, activation='softmax')
loss = paddle.nn.functional.cross_entropy(predict, label)
exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(paddle.static.default_startup_program())
# serialize the default main program to bytes.
serialized_program = paddle.static.serialize_program([image], [predict])
# deserialize bytes to program
deserialized_program = paddle.static.deserialize_program(serialized_program)
""" """
program = Program.parse_from_string(data) program = Program.parse_from_string(data)
if not core._is_program_version_supported(program._version()): if not core._is_program_version_supported(program._version()):
...@@ -496,8 +532,37 @@ def deserialize_persistables(program, data, executor): ...@@ -496,8 +532,37 @@ def deserialize_persistables(program, data, executor):
program(Program): program that contains parameter names (to deserialize). program(Program): program that contains parameter names (to deserialize).
data(bytes): serialized parameters. data(bytes): serialized parameters.
executor(Executor): executor used to run load op. executor(Executor): executor used to run load op.
Returns: Returns:
Program: deserialized program. Program: deserialized program.
Examples:
.. code-block:: python
import paddle
paddle.enable_static()
path_prefix = "./infer_model"
# User defined network, here a softmax regression example
image = paddle.static.data(name='img', shape=[None, 28, 28], dtype='float32')
label = paddle.static.data(name='label', shape=[None, 1], dtype='int64')
predict = paddle.static.nn.fc(image, 10, activation='softmax')
loss = paddle.nn.functional.cross_entropy(predict, label)
exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(paddle.static.default_startup_program())
# serialize parameters to bytes.
serialized_params = paddle.static.serialize_persistables([image], [predict], exe)
# deserialize bytes to parameters.
main_program = paddle.static.default_main_program()
deserialized_params = paddle.static.deserialize_persistables(main_program, serialized_params, exe)
""" """
if not isinstance(program, Program): if not isinstance(program, Program):
raise TypeError( raise TypeError(
...@@ -567,7 +632,7 @@ def load_from_file(path): ...@@ -567,7 +632,7 @@ def load_from_file(path):
@static_only @static_only
def load_inference_model(path_prefix, executor, **configs): def load_inference_model(path_prefix, executor, **kwargs):
""" """
:api_attr: Static Graph :api_attr: Static Graph
...@@ -580,6 +645,10 @@ def load_inference_model(path_prefix, executor, **configs): ...@@ -580,6 +645,10 @@ def load_inference_model(path_prefix, executor, **configs):
- Set to None when reading the model from memory. - Set to None when reading the model from memory.
executor(Executor): The executor to run for loading inference model. executor(Executor): The executor to run for loading inference model.
See :ref:`api_guide_executor_en` for more details about it. See :ref:`api_guide_executor_en` for more details about it.
kwargs: Supported keys include 'model_filename' and 'params_filename'.
Note: kwargs is mainly used for backward compatibility.
- model_filename(str): specify model_filename if you don't want to use the default name.
- params_filename(str): specify params_filename if you don't want to use the default name.
Returns: Returns:
list: The return of this API is a list with three elements: list: The return of this API is a list with three elements:
...@@ -631,17 +700,17 @@ def load_inference_model(path_prefix, executor, **configs): ...@@ -631,17 +700,17 @@ def load_inference_model(path_prefix, executor, **configs):
# fetch_targets, we can use an executor to run the inference # fetch_targets, we can use an executor to run the inference
# program to get the inference result. # program to get the inference result.
""" """
# check configs # check kwargs
supported_args = ('model_filename', 'params_filename') supported_args = ('model_filename', 'params_filename')
deprecated_args = ('pserver_endpoints', ) deprecated_args = ('pserver_endpoints', )
caller = inspect.currentframe().f_code.co_name caller = inspect.currentframe().f_code.co_name
_check_args(caller, configs, supported_args, deprecated_args) _check_args(caller, kwargs, supported_args, deprecated_args)
# load from memory # load from memory
if path_prefix is None: if path_prefix is None:
_logger.warning("Load inference model from memory is deprecated.") _logger.warning("Load inference model from memory is deprecated.")
model_filename = configs.get('model_filename', None) model_filename = kwargs.get('model_filename', None)
params_filename = configs.get('params_filename', None) params_filename = kwargs.get('params_filename', None)
if params_filename is None: if params_filename is None:
raise ValueError( raise ValueError(
"params_filename cannot be None when path_prefix is None.") "params_filename cannot be None when path_prefix is None.")
...@@ -655,14 +724,14 @@ def load_inference_model(path_prefix, executor, **configs): ...@@ -655,14 +724,14 @@ def load_inference_model(path_prefix, executor, **configs):
# set model_path and params_path in new way, # set model_path and params_path in new way,
# path_prefix represents a file path without suffix in this case. # path_prefix represents a file path without suffix in this case.
if not configs: if not kwargs:
model_path = path_prefix + ".pdmodel" model_path = path_prefix + ".pdmodel"
params_path = path_prefix + ".pdiparams" params_path = path_prefix + ".pdiparams"
# set model_path and params_path in old way for compatible, # set model_path and params_path in old way for compatible,
# path_prefix represents a directory path. # path_prefix represents a directory path.
else: else:
model_filename = configs.get('model_filename', None) model_filename = kwargs.get('model_filename', None)
params_filename = configs.get('params_filename', None) params_filename = kwargs.get('params_filename', None)
# set model_path # set model_path
if model_filename is None: if model_filename is None:
model_path = os.path.join(path_prefix, "__model__") model_path = os.path.join(path_prefix, "__model__")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册