未验证 提交 4ceedec3 编写于 作者: S Shibo Tao 提交者: GitHub

enhance doc. add kwargs for backward compatibility. test=develop (#29143)

上级 28280647
......@@ -204,7 +204,7 @@ def is_persistable(var):
@static_only
def serialize_program(feed_vars, fetch_vars):
def serialize_program(feed_vars, fetch_vars, **kwargs):
"""
:api_attr: Static Graph
......@@ -213,6 +213,10 @@ def serialize_program(feed_vars, fetch_vars):
Args:
feed_vars(Variable | list[Variable]): Variables needed by inference.
fetch_vars(Variable | list[Variable]): Variables returned by inference.
kwargs: Supported keys include 'program'.
Note that kwargs is mainly provided for backward compatibility.
- program(Program): specify a program if you do not want to use the default main program.
Returns:
bytes: serialized program.
......@@ -235,7 +239,6 @@ def serialize_program(feed_vars, fetch_vars):
predict = paddle.static.nn.fc(image, 10, activation='softmax')
loss = paddle.nn.functional.cross_entropy(predict, label)
avg_loss = paddle.tensor.stat.mean(loss)
exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(paddle.static.default_startup_program())
......@@ -252,7 +255,7 @@ def serialize_program(feed_vars, fetch_vars):
# verify fetch_vars
_check_vars('fetch_vars', fetch_vars)
program = _get_valid_program()
program = _get_valid_program(kwargs.get('program', None))
program = _normalize_program(program, feed_vars, fetch_vars)
return _serialize_program(program)
......@@ -265,7 +268,7 @@ def _serialize_program(program):
@static_only
def serialize_persistables(feed_vars, fetch_vars, executor):
def serialize_persistables(feed_vars, fetch_vars, executor, **kwargs):
"""
:api_attr: Static Graph
......@@ -274,6 +277,10 @@ def serialize_persistables(feed_vars, fetch_vars, executor):
Args:
feed_vars(Variable | list[Variable]): Variables needed by inference.
fetch_vars(Variable | list[Variable]): Variables returned by inference.
kwargs: Supported keys include 'program'.
Note that kwargs is mainly provided for backward compatibility.
- program(Program): specify a program if you do not want to use the default main program.
Returns:
bytes: serialized program.
......@@ -296,7 +303,6 @@ def serialize_persistables(feed_vars, fetch_vars, executor):
predict = paddle.static.nn.fc(image, 10, activation='softmax')
loss = paddle.nn.functional.cross_entropy(predict, label)
avg_loss = paddle.tensor.stat.mean(loss)
exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(paddle.static.default_startup_program())
......@@ -314,7 +320,7 @@ def serialize_persistables(feed_vars, fetch_vars, executor):
# verify fetch_vars
_check_vars('fetch_vars', fetch_vars)
program = _get_valid_program()
program = _get_valid_program(kwargs.get('program', None))
program = _normalize_program(program, feed_vars, fetch_vars)
return _serialize_persistables(program, executor)
......@@ -380,7 +386,8 @@ def save_to_file(path, content):
@static_only
def save_inference_model(path_prefix, feed_vars, fetch_vars, executor):
def save_inference_model(path_prefix, feed_vars, fetch_vars, executor,
**kwargs):
"""
:api_attr: Static Graph
......@@ -396,6 +403,9 @@ def save_inference_model(path_prefix, feed_vars, fetch_vars, executor):
fetch_vars(Variable | list[Variable]): Variables returned by inference.
executor(Executor): The executor that saves the inference model. You can refer
to :ref:`api_guide_executor_en` for more details.
kwargs: Supported keys include 'program'.
Note that kwargs is mainly provided for backward compatibility.
- program(Program): specify a program if you do not want to use the default main program.
Returns:
None
......@@ -418,7 +428,6 @@ def save_inference_model(path_prefix, feed_vars, fetch_vars, executor):
predict = paddle.static.nn.fc(image, 10, activation='softmax')
loss = paddle.nn.functional.cross_entropy(predict, label)
avg_loss = paddle.tensor.stat.mean(loss)
exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(paddle.static.default_startup_program())
......@@ -456,7 +465,7 @@ def save_inference_model(path_prefix, feed_vars, fetch_vars, executor):
# verify fetch_vars
_check_vars('fetch_vars', fetch_vars)
program = _get_valid_program()
program = _get_valid_program(kwargs.get('program', None))
program = _normalize_program(program, feed_vars, fetch_vars)
# serialize and save program
program_bytes = _serialize_program(program)
......@@ -475,8 +484,35 @@ def deserialize_program(data):
Args:
data(bytes): serialized program.
Returns:
Program: deserialized program.
Examples:
.. code-block:: python
import paddle
paddle.enable_static()
path_prefix = "./infer_model"
# User-defined network; here, a softmax regression example
image = paddle.static.data(name='img', shape=[None, 28, 28], dtype='float32')
label = paddle.static.data(name='label', shape=[None, 1], dtype='int64')
predict = paddle.static.nn.fc(image, 10, activation='softmax')
loss = paddle.nn.functional.cross_entropy(predict, label)
exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(paddle.static.default_startup_program())
# serialize the default main program to bytes.
serialized_program = paddle.static.serialize_program([image], [predict])
# deserialize bytes to program
deserialized_program = paddle.static.deserialize_program(serialized_program)
"""
program = Program.parse_from_string(data)
if not core._is_program_version_supported(program._version()):
......@@ -496,8 +532,37 @@ def deserialize_persistables(program, data, executor):
program(Program): program that contains parameter names (to deserialize).
data(bytes): serialized parameters.
executor(Executor): executor used to run load op.
Returns:
The deserialized persistables (parameters) — note: this line previously said "Program: deserialized program", which was copy-pasted from ``deserialize_program``; this API deserializes parameters, not a program.
Examples:
.. code-block:: python
import paddle
paddle.enable_static()
path_prefix = "./infer_model"
# User-defined network; here, a softmax regression example
image = paddle.static.data(name='img', shape=[None, 28, 28], dtype='float32')
label = paddle.static.data(name='label', shape=[None, 1], dtype='int64')
predict = paddle.static.nn.fc(image, 10, activation='softmax')
loss = paddle.nn.functional.cross_entropy(predict, label)
exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(paddle.static.default_startup_program())
# serialize parameters to bytes.
serialized_params = paddle.static.serialize_persistables([image], [predict], exe)
# deserialize bytes to parameters.
main_program = paddle.static.default_main_program()
deserialized_params = paddle.static.deserialize_persistables(main_program, serialized_params, exe)
"""
if not isinstance(program, Program):
raise TypeError(
......@@ -567,7 +632,7 @@ def load_from_file(path):
@static_only
def load_inference_model(path_prefix, executor, **configs):
def load_inference_model(path_prefix, executor, **kwargs):
"""
:api_attr: Static Graph
......@@ -580,6 +645,10 @@ def load_inference_model(path_prefix, executor, **configs):
- Set to None when reading the model from memory.
executor(Executor): The executor to run for loading inference model.
See :ref:`api_guide_executor_en` for more details about it.
kwargs: Supported keys include 'model_filename' and 'params_filename'.
Note that kwargs is mainly provided for backward compatibility.
- model_filename(str): specify model_filename if you do not want to use the default name.
- params_filename(str): specify params_filename if you do not want to use the default name.
Returns:
list: The return of this API is a list with three elements:
......@@ -631,17 +700,17 @@ def load_inference_model(path_prefix, executor, **configs):
# fetch_targets, we can use an executor to run the inference
# program to get the inference result.
"""
# check configs
# check kwargs
supported_args = ('model_filename', 'params_filename')
deprecated_args = ('pserver_endpoints', )
caller = inspect.currentframe().f_code.co_name
_check_args(caller, configs, supported_args, deprecated_args)
_check_args(caller, kwargs, supported_args, deprecated_args)
# load from memory
if path_prefix is None:
_logger.warning("Load inference model from memory is deprecated.")
model_filename = configs.get('model_filename', None)
params_filename = configs.get('params_filename', None)
model_filename = kwargs.get('model_filename', None)
params_filename = kwargs.get('params_filename', None)
if params_filename is None:
raise ValueError(
"params_filename cannot be None when path_prefix is None.")
......@@ -655,14 +724,14 @@ def load_inference_model(path_prefix, executor, **configs):
# set model_path and params_path in new way,
# path_prefix represents a file path without suffix in this case.
if not configs:
if not kwargs:
model_path = path_prefix + ".pdmodel"
params_path = path_prefix + ".pdiparams"
# set model_path and params_path in old way for compatible,
# path_prefix represents a directory path.
else:
model_filename = configs.get('model_filename', None)
params_filename = configs.get('params_filename', None)
model_filename = kwargs.get('model_filename', None)
params_filename = kwargs.get('params_filename', None)
# set model_path
if model_filename is None:
model_path = os.path.join(path_prefix, "__model__")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册