diff --git a/python/paddle/static/io.py b/python/paddle/static/io.py index cfaa6d9470439ea4c05c12892ad7db4c0d887c60..e88a052730414192cd4bc99b3abc2c6b46377c3e 100644 --- a/python/paddle/static/io.py +++ b/python/paddle/static/io.py @@ -204,7 +204,7 @@ def is_persistable(var): @static_only -def serialize_program(feed_vars, fetch_vars): +def serialize_program(feed_vars, fetch_vars, **kwargs): """ :api_attr: Static Graph @@ -213,6 +213,10 @@ def serialize_program(feed_vars, fetch_vars): Args: feed_vars(Variable | list[Variable]): Variables needed by inference. fetch_vars(Variable | list[Variable]): Variables returned by inference. + kwargs: Supported keys including 'program'. + Attention please, kwargs is used for backward compatibility mainly. + - program(Program): specify a program if you don't want to use default main program. + Returns: bytes: serialized program. @@ -235,7 +239,6 @@ def serialize_program(feed_vars, fetch_vars): predict = paddle.static.nn.fc(image, 10, activation='softmax') loss = paddle.nn.functional.cross_entropy(predict, label) - avg_loss = paddle.tensor.stat.mean(loss) exe = paddle.static.Executor(paddle.CPUPlace()) exe.run(paddle.static.default_startup_program()) @@ -252,7 +255,7 @@ def serialize_program(feed_vars, fetch_vars): # verify fetch_vars _check_vars('fetch_vars', fetch_vars) - program = _get_valid_program() + program = _get_valid_program(kwargs.get('program', None)) program = _normalize_program(program, feed_vars, fetch_vars) return _serialize_program(program) @@ -265,7 +268,7 @@ def _serialize_program(program): @static_only -def serialize_persistables(feed_vars, fetch_vars, executor): +def serialize_persistables(feed_vars, fetch_vars, executor, **kwargs): """ :api_attr: Static Graph @@ -274,6 +277,10 @@ def serialize_persistables(feed_vars, fetch_vars, executor): Args: feed_vars(Variable | list[Variable]): Variables needed by inference. fetch_vars(Variable | list[Variable]): Variables returned by inference. + kwargs: Supported keys including 'program'. + Attention please, kwargs is used for backward compatibility mainly. + - program(Program): specify a program if you don't want to use default main program. + Returns: bytes: serialized program. @@ -296,7 +303,6 @@ def serialize_persistables(feed_vars, fetch_vars, executor): predict = paddle.static.nn.fc(image, 10, activation='softmax') loss = paddle.nn.functional.cross_entropy(predict, label) - avg_loss = paddle.tensor.stat.mean(loss) exe = paddle.static.Executor(paddle.CPUPlace()) exe.run(paddle.static.default_startup_program()) @@ -314,7 +320,7 @@ def serialize_persistables(feed_vars, fetch_vars, executor): # verify fetch_vars _check_vars('fetch_vars', fetch_vars) - program = _get_valid_program() + program = _get_valid_program(kwargs.get('program', None)) program = _normalize_program(program, feed_vars, fetch_vars) return _serialize_persistables(program, executor) @@ -380,7 +386,8 @@ def save_to_file(path, content): @static_only -def save_inference_model(path_prefix, feed_vars, fetch_vars, executor): +def save_inference_model(path_prefix, feed_vars, fetch_vars, executor, + **kwargs): """ :api_attr: Static Graph @@ -396,6 +403,9 @@ def save_inference_model(path_prefix, feed_vars, fetch_vars, executor): fetch_vars(Variable | list[Variable]): Variables returned by inference. executor(Executor): The executor that saves the inference model. You can refer to :ref:`api_guide_executor_en` for more details. + kwargs: Supported keys including 'program'. + Attention please, kwargs is used for backward compatibility mainly. + - program(Program): specify a program if you don't want to use default main program. Returns: None @@ -418,7 +428,6 @@ def save_inference_model(path_prefix, feed_vars, fetch_vars, executor): predict = paddle.static.nn.fc(image, 10, activation='softmax') loss = paddle.nn.functional.cross_entropy(predict, label) - avg_loss = paddle.tensor.stat.mean(loss) exe = paddle.static.Executor(paddle.CPUPlace()) exe.run(paddle.static.default_startup_program()) @@ -456,7 +465,7 @@ def save_inference_model(path_prefix, feed_vars, fetch_vars, executor): # verify fetch_vars _check_vars('fetch_vars', fetch_vars) - program = _get_valid_program() + program = _get_valid_program(kwargs.get('program', None)) program = _normalize_program(program, feed_vars, fetch_vars) # serialize and save program program_bytes = _serialize_program(program) @@ -475,8 +484,35 @@ def deserialize_program(data): Args: data(bytes): serialized program. + Returns: Program: deserialized program. + + Examples: + .. code-block:: python + + import paddle + + paddle.enable_static() + + path_prefix = "./infer_model" + + # User defined network, here a softmax regession example + image = paddle.static.data(name='img', shape=[None, 28, 28], dtype='float32') + label = paddle.static.data(name='label', shape=[None, 1], dtype='int64') + predict = paddle.static.nn.fc(image, 10, activation='softmax') + + loss = paddle.nn.functional.cross_entropy(predict, label) + + exe = paddle.static.Executor(paddle.CPUPlace()) + exe.run(paddle.static.default_startup_program()) + + # serialize the default main program to bytes. + serialized_program = paddle.static.serialize_program([image], [predict]) + + # deserialize bytes to program + deserialized_program = paddle.static.deserialize_program(serialized_program) + """ program = Program.parse_from_string(data) if not core._is_program_version_supported(program._version()): @@ -496,8 +532,37 @@ def deserialize_persistables(program, data, executor): program(Program): program that contains parameter names (to deserialize). data(bytes): serialized parameters. executor(Executor): executor used to run load op. + Returns: Program: deserialized program. + + Examples: + .. code-block:: python + + import paddle + + paddle.enable_static() + + path_prefix = "./infer_model" + + # User defined network, here a softmax regession example + image = paddle.static.data(name='img', shape=[None, 28, 28], dtype='float32') + label = paddle.static.data(name='label', shape=[None, 1], dtype='int64') + predict = paddle.static.nn.fc(image, 10, activation='softmax') + + loss = paddle.nn.functional.cross_entropy(predict, label) + + exe = paddle.static.Executor(paddle.CPUPlace()) + exe.run(paddle.static.default_startup_program()) + + # serialize parameters to bytes. + serialized_params = paddle.static.serialize_persistables([image], [predict], exe) + + # deserialize bytes to parameters. + main_program = paddle.static.default_main_program() + deserialized_params = paddle.static.deserialize_persistables(main_program, serialized_params, exe) + + """ if not isinstance(program, Program): raise TypeError( @@ -567,7 +632,7 @@ def load_from_file(path): @static_only -def load_inference_model(path_prefix, executor, **configs): +def load_inference_model(path_prefix, executor, **kwargs): """ :api_attr: Static Graph @@ -580,6 +645,10 @@ def load_inference_model(path_prefix, executor, **configs): - Set to None when reading the model from memory. executor(Executor): The executor to run for loading inference model. See :ref:`api_guide_executor_en` for more details about it. + kwargs: Supported keys including 'model_filename', 'params_filename'. + Attention please, kwargs is used for backward compatibility mainly. + - model_filename(str): specify model_filename if you don't want to use default name. + - params_filename(str): specify params_filename if you don't want to use default name. Returns: list: The return of this API is a list with three elements: @@ -631,17 +700,17 @@ def load_inference_model(path_prefix, executor, **configs): # fetch_targets, we can use an executor to run the inference # program to get the inference result. """ - # check configs + # check kwargs supported_args = ('model_filename', 'params_filename') deprecated_args = ('pserver_endpoints', ) caller = inspect.currentframe().f_code.co_name - _check_args(caller, configs, supported_args, deprecated_args) + _check_args(caller, kwargs, supported_args, deprecated_args) # load from memory if path_prefix is None: _logger.warning("Load inference model from memory is deprecated.") - model_filename = configs.get('model_filename', None) - params_filename = configs.get('params_filename', None) + model_filename = kwargs.get('model_filename', None) + params_filename = kwargs.get('params_filename', None) if params_filename is None: raise ValueError( "params_filename cannot be None when path_prefix is None.") @@ -655,14 +724,14 @@ def load_inference_model(path_prefix, executor, **configs): # set model_path and params_path in new way, # path_prefix represents a file path without suffix in this case. - if not configs: + if not kwargs: model_path = path_prefix + ".pdmodel" params_path = path_prefix + ".pdiparams" # set model_path and params_path in old way for compatible, # path_prefix represents a directory path. else: - model_filename = configs.get('model_filename', None) - params_filename = configs.get('params_filename', None) + model_filename = kwargs.get('model_filename', None) + params_filename = kwargs.get('params_filename', None) # set model_path if model_filename is None: model_path = os.path.join(path_prefix, "__model__")