未验证 提交 9cb7c839 编写于 作者: J Jiawei Wang 提交者: GitHub

Update paddle_io.py

上级 0d32c5f1
......@@ -519,299 +519,4 @@ def save_inference_model(path_prefix, feed_vars, fetch_vars, executor,
# serialize and save params
params_bytes = _serialize_persistables(program, executor)
save_to_file(params_path, params_bytes)
@static_only
def deserialize_program(data):
"""
:api_attr: Static Graph
Deserialize given data to a program.
Args:
data(bytes): serialized program.
Returns:
Program: deserialized program.
Examples:
.. code-block:: python
import paddle
paddle.enable_static()
path_prefix = "./infer_model"
# User defined network, here a softmax regession example
image = paddle.static.data(name='img', shape=[None, 28, 28], dtype='float32')
label = paddle.static.data(name='label', shape=[None, 1], dtype='int64')
predict = paddle.static.nn.fc(image, 10, activation='softmax')
loss = paddle.nn.functional.cross_entropy(predict, label)
exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(paddle.static.default_startup_program())
# serialize the default main program to bytes.
serialized_program = paddle.static.serialize_program([image], [predict])
# deserialize bytes to program
deserialized_program = paddle.static.deserialize_program(serialized_program)
"""
program = Program.parse_from_string(data)
if not core._is_program_version_supported(program._version()):
raise ValueError("Unsupported program version: %d\n" %
program._version())
return program
@static_only
def deserialize_persistables(program, data, executor):
"""
:api_attr: Static Graph
Deserialize given data to parameters according to given program and executor.
Args:
program(Program): program that contains parameter names (to deserialize).
data(bytes): serialized parameters.
executor(Executor): executor used to run load op.
Returns:
Program: deserialized program.
Examples:
.. code-block:: python
import paddle
paddle.enable_static()
path_prefix = "./infer_model"
# User defined network, here a softmax regession example
image = paddle.static.data(name='img', shape=[None, 28, 28], dtype='float32')
label = paddle.static.data(name='label', shape=[None, 1], dtype='int64')
predict = paddle.static.nn.fc(image, 10, activation='softmax')
loss = paddle.nn.functional.cross_entropy(predict, label)
exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(paddle.static.default_startup_program())
# serialize parameters to bytes.
serialized_params = paddle.static.serialize_persistables([image], [predict], exe)
# deserialize bytes to parameters.
main_program = paddle.static.default_main_program()
deserialized_params = paddle.static.deserialize_persistables(main_program, serialized_params, exe)
"""
if not isinstance(program, Program):
raise TypeError(
"program type must be `fluid.Program`, but received `%s`" %
type(program))
# load params to a tmp program
load_program = Program()
load_block = load_program.global_block()
vars_ = list(filter(is_persistable, program.list_vars()))
origin_shape_map = {}
load_var_map = {}
check_vars = []
sparse_vars = []
for var in vars_:
assert isinstance(var, Variable)
if var.type == core.VarDesc.VarType.RAW:
continue
if isinstance(var, Parameter):
origin_shape_map[var.name] = tuple(var.desc.get_shape())
if var.type == core.VarDesc.VarType.SELECTED_ROWS:
sparse_vars.append(var)
continue
var_copy = _clone_var_in_block(load_block, var)
check_vars.append(var)
load_var_map[var_copy.name] = var_copy
# append load_combine op to load parameters,
load_var_list = []
for name in sorted(load_var_map.keys()):
load_var_list.append(load_var_map[name])
load_block.append_op(
type='load_combine',
inputs={},
outputs={"Out": load_var_list},
# if load from memory, file_path is data
attrs={'file_path': data,
'model_from_memory': True})
executor.run(load_program)
# check var shape
for var in check_vars:
if not isinstance(var, Parameter):
continue
var_tmp = paddle.fluid.global_scope().find_var(var.name)
assert var_tmp != None, "can't not find var: " + var.name
new_shape = (np.array(var_tmp.get_tensor())).shape
assert var.name in origin_shape_map, var.name + " MUST in var list."
origin_shape = origin_shape_map.get(var.name)
if new_shape != origin_shape:
raise RuntimeError(
"Shape mismatch, program needs a parameter with shape ({}), "
"but the loaded parameter ('{}') has a shape of ({}).".format(
origin_shape, var.name, new_shape))
def load_from_file(path):
"""
Load file in binary mode.
Args:
path(str): Path of an existed file.
Returns:
bytes: Content of file.
"""
with open(path, 'rb') as f:
data = f.read()
return data
@static_only
def load_inference_model(path_prefix, executor, **kwargs):
"""
:api_attr: Static Graph
Load inference model from a given path. By this API, you can get the model
structure(Inference Program) and model parameters.
Args:
path_prefix(str | None): One of the following:
- Directory path to save model + model name without suffix.
- Set to None when reading the model from memory.
executor(Executor): The executor to run for loading inference model.
See :ref:`api_guide_executor_en` for more details about it.
kwargs: Supported keys including 'model_filename', 'params_filename'.Attention please, kwargs is used for backward compatibility mainly.
- model_filename(str): specify model_filename if you don't want to use default name.
- params_filename(str): specify params_filename if you don't want to use default name.
Returns:
list: The return of this API is a list with three elements:
(program, feed_target_names, fetch_targets). The `program` is a
``Program`` (refer to :ref:`api_guide_Program_en`), which is used for inference.
The `feed_target_names` is a list of ``str``, which contains names of variables
that need to feed data in the inference program. The `fetch_targets` is a list of
``Variable`` (refer to :ref:`api_guide_Program_en`). It contains variables from which
we can get inference results.
Raises:
ValueError: If `path_prefix.pdmodel` or `path_prefix.pdiparams` doesn't exist.
Examples:
.. code-block:: python
import paddle
import numpy as np
paddle.enable_static()
# Build the model
startup_prog = paddle.static.default_startup_program()
main_prog = paddle.static.default_main_program()
with paddle.static.program_guard(main_prog, startup_prog):
image = paddle.static.data(name="img", shape=[64, 784])
w = paddle.create_parameter(shape=[784, 200], dtype='float32')
b = paddle.create_parameter(shape=[200], dtype='float32')
hidden_w = paddle.matmul(x=image, y=w)
hidden_b = paddle.add(hidden_w, b)
exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(startup_prog)
# Save the inference model
path_prefix = "./infer_model"
paddle.static.save_inference_model(path_prefix, [image], [hidden_b], exe)
[inference_program, feed_target_names, fetch_targets] = (
paddle.static.load_inference_model(path_prefix, exe))
tensor_img = np.array(np.random.random((64, 784)), dtype=np.float32)
results = exe.run(inference_program,
feed={feed_target_names[0]: tensor_img},
fetch_list=fetch_targets)
# In this example, the inference program was saved in file
# "./infer_model.pdmodel" and parameters were saved in file
# " ./infer_model.pdiparams".
# By the inference program, feed_target_names and
# fetch_targets, we can use an executor to run the inference
# program to get the inference result.
"""
# check kwargs
supported_args = ('model_filename', 'params_filename')
deprecated_args = ('pserver_endpoints', )
caller = inspect.currentframe().f_code.co_name
_check_args(caller, kwargs, supported_args, deprecated_args)
# load from memory
if path_prefix is None:
_logger.warning("Load inference model from memory is deprecated.")
model_filename = kwargs.get('model_filename', None)
params_filename = kwargs.get('params_filename', None)
if params_filename is None:
raise ValueError(
"params_filename cannot be None when path_prefix is None.")
load_dirname = ''
program_bytes = model_filename
params_bytes = params_filename
# load from file
else:
# check and norm path_prefix
path_prefix = _normalize_path_prefix(path_prefix)
# set model_path and params_path in new way,
# path_prefix represents a file path without suffix in this case.
if not kwargs:
model_path = path_prefix + ".pdmodel"
params_path = path_prefix + ".pdiparams"
# set model_path and params_path in old way for compatible,
# path_prefix represents a directory path.
else:
model_filename = kwargs.get('model_filename', None)
params_filename = kwargs.get('params_filename', None)
# set model_path
if model_filename is None:
model_path = os.path.join(path_prefix, "__model__")
else:
model_path = os.path.join(path_prefix,
model_filename + ".pdmodel")
if not os.path.exists(model_path):
model_path = os.path.join(path_prefix, model_filename)
# set params_path
if params_filename is None:
params_path = os.path.join(path_prefix, "")
else:
params_path = os.path.join(path_prefix,
params_filename + ".pdiparams")
if not os.path.exists(params_path):
params_path = os.path.join(path_prefix, params_filename)
_logger.warning("The old way to load inference model is deprecated."
" model path: {}, params path: {}".format(
model_path, params_path))
program_bytes = load_from_file(model_path)
load_dirname = os.path.dirname(params_path)
params_filename = os.path.basename(params_path)
# load params data
params_path = os.path.join(load_dirname, params_filename)
params_bytes = load_from_file(params_path)
# deserialize bytes to program
program = deserialize_program(program_bytes)
# deserialize bytes to params
deserialize_persistables(program, params_bytes, executor)
feed_target_names = program.desc.get_feed_target_names()
fetch_target_names = program.desc.get_fetch_target_names()
fetch_targets = [
program.global_block().var(name) for name in fetch_target_names
]
return [program, feed_target_names, fetch_targets]
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册