未验证 提交 db412585 编写于 作者: S Shibo Tao 提交者: GitHub

add API serialize_program, serialize_persistables, save_to_file,...

add API serialize_program, serialize_persistables, save_to_file, deserialize_program, deserialize_persistables, load_from_file. (#29034)
上级 14013a2e
...@@ -281,7 +281,8 @@ void DeserializeFromStream(std::istream &is, LoDTensor *tensor, ...@@ -281,7 +281,8 @@ void DeserializeFromStream(std::istream &is, LoDTensor *tensor,
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
version, 0U, version, 0U,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Tensor version %u is not supported, only version 0 is supported.", "Deserialize to tensor failed, maybe the loaded file is "
"not a paddle model(expected file format: 0, but %u found).",
version)); version));
} }
{ {
...@@ -307,7 +308,8 @@ void DeserializeFromStream(std::istream &is, LoDTensor *tensor, ...@@ -307,7 +308,8 @@ void DeserializeFromStream(std::istream &is, LoDTensor *tensor,
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
version, 0U, version, 0U,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Tensor version %u is not supported, only version 0 is supported.", "Deserialize to tensor failed, maybe the loaded file is "
"not a paddle model(expected file format: 0, but %u found).",
version)); version));
} }
{ {
......
...@@ -226,8 +226,8 @@ class TestSaveInferenceModelNew(unittest.TestCase): ...@@ -226,8 +226,8 @@ class TestSaveInferenceModelNew(unittest.TestCase):
'y': tensor_y}, 'y': tensor_y},
fetch_list=[avg_cost]) fetch_list=[avg_cost])
self.assertRaises(ValueError, paddle.static.save_inference_model, self.assertRaises(ValueError, paddle.static.save_inference_model, None,
None, ['x', 'y'], [avg_cost], exe) ['x', 'y'], [avg_cost], exe)
self.assertRaises(ValueError, paddle.static.save_inference_model, self.assertRaises(ValueError, paddle.static.save_inference_model,
MODEL_DIR + "/", [x, y], [avg_cost], exe) MODEL_DIR + "/", [x, y], [avg_cost], exe)
self.assertRaises(ValueError, paddle.static.save_inference_model, self.assertRaises(ValueError, paddle.static.save_inference_model,
...@@ -251,7 +251,8 @@ class TestSaveInferenceModelNew(unittest.TestCase): ...@@ -251,7 +251,8 @@ class TestSaveInferenceModelNew(unittest.TestCase):
MODEL_DIR + "_isdir", [x, y], [avg_cost], exe) MODEL_DIR + "_isdir", [x, y], [avg_cost], exe)
os.rmdir(params_path) os.rmdir(params_path)
paddle.static.io.save_inference_model(MODEL_DIR, [x, y], [avg_cost], exe) paddle.static.io.save_inference_model(MODEL_DIR, [x, y], [avg_cost],
exe)
self.assertTrue(os.path.exists(MODEL_DIR + ".pdmodel")) self.assertTrue(os.path.exists(MODEL_DIR + ".pdmodel"))
self.assertTrue(os.path.exists(MODEL_DIR + ".pdiparams")) self.assertTrue(os.path.exists(MODEL_DIR + ".pdiparams"))
...@@ -263,20 +264,34 @@ class TestSaveInferenceModelNew(unittest.TestCase): ...@@ -263,20 +264,34 @@ class TestSaveInferenceModelNew(unittest.TestCase):
six.moves.reload_module(executor) # reload to build a new scope six.moves.reload_module(executor) # reload to build a new scope
self.assertRaises(ValueError, paddle.static.load_inference_model, self.assertRaises(ValueError, paddle.static.load_inference_model, None,
None, exe) exe)
self.assertRaises(ValueError, paddle.static.load_inference_model, self.assertRaises(ValueError, paddle.static.load_inference_model,
MODEL_DIR + "/", exe) MODEL_DIR + "/", exe)
self.assertRaises(ValueError, paddle.static.load_inference_model, self.assertRaises(ValueError, paddle.static.load_inference_model,
[MODEL_DIR], exe) [MODEL_DIR], exe)
self.assertRaises(ValueError, paddle.static.load_inference_model, self.assertRaises(
MODEL_DIR, exe, pserver_endpoints=None) ValueError,
self.assertRaises(ValueError, paddle.static.load_inference_model, paddle.static.load_inference_model,
MODEL_DIR, exe, unsupported_param=None) MODEL_DIR,
self.assertRaises((TypeError, ValueError), paddle.static.load_inference_model, exe,
None, exe, model_filename="illegal", params_filename="illegal") pserver_endpoints=None)
self.assertRaises(
model = InferModel(paddle.static.io.load_inference_model(MODEL_DIR, exe)) ValueError,
paddle.static.load_inference_model,
MODEL_DIR,
exe,
unsupported_param=None)
self.assertRaises(
(TypeError, ValueError),
paddle.static.load_inference_model,
None,
exe,
model_filename="illegal",
params_filename="illegal")
model = InferModel(
paddle.static.io.load_inference_model(MODEL_DIR, exe))
outs = exe.run(model.program, outs = exe.run(model.program,
feed={ feed={
...@@ -289,7 +304,57 @@ class TestSaveInferenceModelNew(unittest.TestCase): ...@@ -289,7 +304,57 @@ class TestSaveInferenceModelNew(unittest.TestCase):
self.assertEqual(model.feed_var_names, ["x", "y"]) self.assertEqual(model.feed_var_names, ["x", "y"])
self.assertEqual(len(model.fetch_vars), 1) self.assertEqual(len(model.fetch_vars), 1)
self.assertEqual(expected, actual) self.assertEqual(expected, actual)
# test save_to_file content type should be bytes
self.assertRaises(ValueError, paddle.static.io.save_to_file, '', 123)
# test _get_valid_program
self.assertRaises(TypeError, paddle.static.io._get_valid_program, 0)
p = Program()
cp = CompiledProgram(p)
paddle.static.io._get_valid_program(cp)
self.assertTrue(paddle.static.io._get_valid_program(cp) is p)
cp._program = None
self.assertRaises(TypeError, paddle.static.io._get_valid_program, cp)
def test_serialize_program_and_persistables(self):
init_program = fluid.default_startup_program()
program = fluid.default_main_program()
# fake program without feed/fetch
with program_guard(program, init_program):
x = layers.data(name='x', shape=[2], dtype='float32')
y = layers.data(name='y', shape=[1], dtype='float32')
y_predict = layers.fc(input=x, size=1, act=None)
cost = layers.square_error_cost(input=y_predict, label=y)
avg_cost = layers.mean(cost)
sgd_optimizer = optimizer.SGDOptimizer(learning_rate=0.001)
sgd_optimizer.minimize(avg_cost, init_program)
place = core.CPUPlace()
exe = executor.Executor(place)
exe.run(init_program, feed={}, fetch_list=[])
tensor_x = np.array([[1, 1], [1, 2], [5, 2]]).astype("float32")
tensor_y = np.array([[-2], [-3], [-7]]).astype("float32")
for i in six.moves.xrange(3):
exe.run(program,
feed={'x': tensor_x,
'y': tensor_y},
fetch_list=[avg_cost])
# test if return type of serialize_program is bytes
res1 = paddle.static.io.serialize_program([x, y], [avg_cost])
self.assertTrue(isinstance(res1, bytes))
# test if return type of serialize_persistables is bytes
res2 = paddle.static.io.serialize_persistables([x, y], [avg_cost], exe)
self.assertTrue(isinstance(res2, bytes))
# test if variables in program is empty
res = paddle.static.io._serialize_persistables(Program(), None)
self.assertEqual(res, None)
self.assertRaises(TypeError, paddle.static.io.deserialize_persistables,
None, None, None)
class TestLoadInferenceModelError(unittest.TestCase): class TestLoadInferenceModelError(unittest.TestCase):
......
...@@ -18,28 +18,43 @@ import errno ...@@ -18,28 +18,43 @@ import errno
import inspect import inspect
import logging import logging
import os import os
import warnings
import six import six
import numpy as np
import paddle import paddle
from paddle.fluid import core, Variable, CompiledProgram, program_guard, default_main_program, Program from paddle.fluid import (
from paddle.fluid.framework import static_only core,
from paddle.fluid import layers Variable,
CompiledProgram,
from paddle.fluid.io import _get_valid_program, save_vars, _save_distributed_persistables default_main_program,
from paddle.fluid.io import prepend_feed_ops, append_fetch_ops, save_persistables Program,
from paddle.fluid.io import load_persistables, _endpoints_replacement layers,
unique_name,
program_guard, )
from paddle.fluid.io import prepend_feed_ops, append_fetch_ops
from paddle.fluid.framework import static_only, Parameter
from paddle.fluid.executor import Executor, global_scope
from paddle.fluid.log_helper import get_logger from paddle.fluid.log_helper import get_logger
__all__ = [ __all__ = [
'save_inference_model', 'save_inference_model',
'load_inference_model', 'load_inference_model',
'serialize_program',
'serialize_persistables',
'save_to_file',
'deserialize_program',
'deserialize_persistables',
'load_from_file',
] ]
_logger = get_logger( _logger = get_logger(
__name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s') __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s')
def _check_args(caller, args, supported_args=[], deprecated_args=[]): def _check_args(caller, args, supported_args=None, deprecated_args=None):
supported_args = [] if supported_args is None else supported_args
deprecated_args = [] if deprecated_args is None else deprecated_args
for arg in args: for arg in args:
if arg in deprecated_args: if arg in deprecated_args:
raise ValueError( raise ValueError(
...@@ -51,6 +66,319 @@ def _check_args(caller, args, supported_args=[], deprecated_args=[]): ...@@ -51,6 +66,319 @@ def _check_args(caller, args, supported_args=[], deprecated_args=[]):
format(caller, arg, supported_args)) format(caller, arg, supported_args))
def _check_vars(name, var_list):
if not isinstance(var_list, list):
var_list = [var_list]
if not var_list or not all([isinstance(var, Variable) for var in var_list]):
raise ValueError(
"'{}' should be a Variable or a list of Variable.".format(name))
def _normalize_path_prefix(path_prefix):
"""
convert path_prefix to absolute path.
"""
if not isinstance(path_prefix, six.string_types):
raise ValueError("'path_prefix' should be a string.")
if path_prefix.endswith("/"):
raise ValueError("'path_prefix' should not be a directory")
path_prefix = os.path.normpath(path_prefix)
path_prefix = os.path.abspath(path_prefix)
return path_prefix
def _get_valid_program(program=None):
"""
return default main program if program is None.
"""
if program is None:
program = default_main_program()
elif isinstance(program, CompiledProgram):
program = program._program
if program is None:
raise TypeError(
"The type of input program is invalid, expected tyep is Program, but received None"
)
warnings.warn(
"The input is a CompiledProgram, this is not recommended.")
if not isinstance(program, Program):
raise TypeError(
"The type of input program is invalid, expected type is fluid.Program, but received %s"
% type(program))
return program
def _clone_var_in_block(block, var):
assert isinstance(var, Variable)
if var.desc.type() == core.VarDesc.VarType.LOD_TENSOR:
return block.create_var(
name=var.name,
shape=var.shape,
dtype=var.dtype,
type=var.type,
lod_level=var.lod_level,
persistable=True)
else:
return block.create_var(
name=var.name,
shape=var.shape,
dtype=var.dtype,
type=var.type,
persistable=True)
def _normalize_program(program, feed_vars, fetch_vars):
"""
optimize program according feed_vars and fetch_vars.
"""
# remind users to set auc_states to 0 if auc op were found.
for op in program.global_block().ops:
# clear device of Op
device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName()
op._set_attr(device_attr_name, "")
if op.type == 'auc':
warnings.warn("Be sure that you have set auc states to 0 "
"before saving inference model.")
break
# fix the bug that the activation op's output as target will be pruned.
# will affect the inference performance.
# TODO(Superjomn) add an IR pass to remove 1-scale op.
with program_guard(program):
uniq_fetch_vars = []
for i, var in enumerate(fetch_vars):
var = layers.scale(
var, 1., name="save_infer_model/scale_{}".format(i))
uniq_fetch_vars.append(var)
fetch_vars = uniq_fetch_vars
# serialize program
copy_program = program.clone()
global_block = copy_program.global_block()
remove_op_idx = []
for i, op in enumerate(global_block.ops):
op.desc.set_is_target(False)
if op.type == "feed" or op.type == "fetch":
remove_op_idx.append(i)
for idx in remove_op_idx[::-1]:
global_block._remove_op(idx)
copy_program.desc.flush()
feed_var_names = [var.name for var in feed_vars]
copy_program = copy_program._prune_with_input(
feeded_var_names=feed_var_names, targets=fetch_vars)
copy_program = copy_program._inference_optimize(prune_read_op=True)
fetch_var_names = [var.name for var in fetch_vars]
prepend_feed_ops(copy_program, feed_var_names)
append_fetch_ops(copy_program, fetch_var_names)
copy_program.desc._set_version()
return copy_program
def is_persistable(var):
"""
Check whether the given variable is persistable.
Args:
var(Variable): The variable to be checked.
Returns:
bool: True if the given `var` is persistable
False if not.
Examples:
.. code-block:: python
import paddle
import paddle.fluid as fluid
paddle.enable_static()
param = fluid.default_main_program().global_block().var('fc.b')
res = fluid.io.is_persistable(param)
"""
if var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \
var.desc.type() == core.VarDesc.VarType.FETCH_LIST or \
var.desc.type() == core.VarDesc.VarType.READER:
return False
return var.persistable
@static_only
def serialize_program(feed_vars, fetch_vars):
"""
:api_attr: Static Graph
Serialize default main program according to feed_vars and fetch_vars.
Args:
feed_vars(Variable | list[Variable]): Variables needed by inference.
fetch_vars(Variable | list[Variable]): Variables returned by inference.
Returns:
bytes: serialized program.
Raises:
ValueError: If `feed_vars` is not a Variable or a list of Variable, an exception is thrown.
ValueError: If `fetch_vars` is not a Variable or a list of Variable, an exception is thrown.
Examples:
.. code-block:: python
import paddle
paddle.enable_static()
path_prefix = "./infer_model"
# User defined network, here a softmax regession example
image = paddle.static.data(name='img', shape=[None, 28, 28], dtype='float32')
label = paddle.static.data(name='label', shape=[None, 1], dtype='int64')
predict = paddle.static.nn.fc(image, 10, activation='softmax')
loss = paddle.nn.functional.cross_entropy(predict, label)
avg_loss = paddle.tensor.stat.mean(loss)
exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(paddle.static.default_startup_program())
# serialize the default main program to bytes.
serialized_program = paddle.static.serialize_program([image], [predict])
# deserialize bytes to program
deserialized_program = paddle.static.deserialize_program(serialized_program)
"""
# verify feed_vars
_check_vars('feed_vars', feed_vars)
# verify fetch_vars
_check_vars('fetch_vars', fetch_vars)
program = _get_valid_program()
program = _normalize_program(program, feed_vars, fetch_vars)
return _serialize_program(program)
def _serialize_program(program):
"""
serialize given program to bytes.
"""
return program.desc.serialize_to_string()
@static_only
def serialize_persistables(feed_vars, fetch_vars, executor):
"""
:api_attr: Static Graph
Serialize parameters using given executor and default main program according to feed_vars and fetch_vars.
Args:
feed_vars(Variable | list[Variable]): Variables needed by inference.
fetch_vars(Variable | list[Variable]): Variables returned by inference.
Returns:
bytes: serialized program.
Raises:
ValueError: If `feed_vars` is not a Variable or a list of Variable, an exception is thrown.
ValueError: If `fetch_vars` is not a Variable or a list of Variable, an exception is thrown.
Examples:
.. code-block:: python
import paddle
paddle.enable_static()
path_prefix = "./infer_model"
# User defined network, here a softmax regession example
image = paddle.static.data(name='img', shape=[None, 28, 28], dtype='float32')
label = paddle.static.data(name='label', shape=[None, 1], dtype='int64')
predict = paddle.static.nn.fc(image, 10, activation='softmax')
loss = paddle.nn.functional.cross_entropy(predict, label)
avg_loss = paddle.tensor.stat.mean(loss)
exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(paddle.static.default_startup_program())
# serialize parameters to bytes.
serialized_params = paddle.static.serialize_persistables([image], [predict], exe)
# deserialize bytes to parameters.
main_program = paddle.static.default_main_program()
deserialized_params = paddle.static.deserialize_persistables(main_program, serialized_params, exe)
"""
# verify feed_vars
_check_vars('feed_vars', feed_vars)
# verify fetch_vars
_check_vars('fetch_vars', fetch_vars)
program = _get_valid_program()
program = _normalize_program(program, feed_vars, fetch_vars)
return _serialize_persistables(program, executor)
def _serialize_persistables(program, executor):
"""
Serialize parameters using given program and executor.
"""
vars_ = list(filter(is_persistable, program.list_vars()))
# warn if no variable found in model
if len(vars_) == 0:
warnings.warn("no variable in your model, please ensure there are any "
"variables in your model to save")
return None
# create a new program and clone persitable vars to it
save_program = Program()
save_block = save_program.global_block()
save_var_map = {}
for var in vars_:
if var.type != core.VarDesc.VarType.RAW:
var_copy = _clone_var_in_block(save_block, var)
save_var_map[var_copy.name] = var
# create in_vars and out_var, then append a save_combine op to save_program
in_vars = []
for name in sorted(save_var_map.keys()):
in_vars.append(save_var_map[name])
out_var_name = unique_name.generate("out_var")
out_var = save_block.create_var(
type=core.VarDesc.VarType.RAW, name=out_var_name)
out_var.desc.set_persistable(True)
save_block.append_op(
type='save_combine',
inputs={'X': in_vars},
outputs={'Y': out_var},
attrs={'file_path': '',
'save_to_memory': True})
# run save_program to save vars
# NOTE(zhiqiu): save op will add variable kLookupTablePath to save_program.desc,
# which leads to diff between save_program and its desc. Call _sync_with_cpp
# to keep consistency.
save_program._sync_with_cpp()
executor.run(save_program)
# return serialized bytes in out_var
return global_scope().find_var(out_var_name).get_bytes()
def save_to_file(path, content):
"""
Save content to given path.
Args:
path(str): Path to write content to.
content(bytes): Content to write.
Returns:
None
"""
if not isinstance(content, bytes):
raise ValueError("'content' type should be bytes.")
with open(path, "wb") as f:
f.write(content)
@static_only @static_only
def save_inference_model(path_prefix, feed_vars, fetch_vars, executor): def save_inference_model(path_prefix, feed_vars, fetch_vars, executor):
""" """
...@@ -106,13 +434,9 @@ def save_inference_model(path_prefix, feed_vars, fetch_vars, executor): ...@@ -106,13 +434,9 @@ def save_inference_model(path_prefix, feed_vars, fetch_vars, executor):
# and parameters are going to be saved in file "./infer_model.pdiparams". # and parameters are going to be saved in file "./infer_model.pdiparams".
""" """
# check path_prefix, set model_path and params_path # check path_prefix, set model_path and params_path
if not isinstance(path_prefix, six.string_types): path_prefix = _normalize_path_prefix(path_prefix)
raise ValueError("'path_prefix' should be a string.")
if path_prefix.endswith("/"):
raise ValueError("'path_prefix' should not be a directory")
path_prefix = os.path.normpath(path_prefix)
path_prefix = os.path.abspath(path_prefix)
try: try:
# mkdir may conflict if pserver and trainer are running on the same machine # mkdir may conflict if pserver and trainer are running on the same machine
dirname = os.path.dirname(path_prefix) dirname = os.path.dirname(path_prefix)
...@@ -128,74 +452,118 @@ def save_inference_model(path_prefix, feed_vars, fetch_vars, executor): ...@@ -128,74 +452,118 @@ def save_inference_model(path_prefix, feed_vars, fetch_vars, executor):
raise ValueError("'{}' is an existing directory.".format(params_path)) raise ValueError("'{}' is an existing directory.".format(params_path))
# verify feed_vars # verify feed_vars
if not isinstance(feed_vars, list): _check_vars('feed_vars', feed_vars)
feed_vars = [feed_vars]
if not feed_vars or not all(
[isinstance(var, Variable) for var in feed_vars]):
raise ValueError(
"'feed_vars' should be a Variable or a list of Variable.")
# verify fetch_vars # verify fetch_vars
if not isinstance(fetch_vars, list): _check_vars('fetch_vars', fetch_vars)
fetch_vars = [fetch_vars]
if not fetch_vars or not all(
[isinstance(var, Variable) for var in fetch_vars]):
raise ValueError(
"'fetch_vars' should be a Variable or a list of Variable.")
main_program = _get_valid_program() program = _get_valid_program()
# remind users to set auc_states to 0 if auc op were found. program = _normalize_program(program, feed_vars, fetch_vars)
for op in main_program.global_block().ops: # serialize and save program
# clear device of Op program_bytes = _serialize_program(program)
device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName() save_to_file(model_path, program_bytes)
op._set_attr(device_attr_name, "") # serialize and save params
if op.type == 'auc': params_bytes = _serialize_persistables(program, executor)
warnings.warn( save_to_file(params_path, params_bytes)
"Be sure that you have set auc states to 0 before saving inference model."
)
break
# fix the bug that the activation op's output as target will be pruned.
# will affect the inference performance.
# TODO(Superjomn) add an IR pass to remove 1-scale op.
with program_guard(main_program):
uniq_fetch_vars = []
for i, var in enumerate(fetch_vars):
var = layers.scale(
var, 1., name="save_infer_model/scale_{}".format(i))
uniq_fetch_vars.append(var)
fetch_vars = uniq_fetch_vars
# save model @static_only
origin_program = main_program.clone() def deserialize_program(data):
main_program = main_program.clone() """
global_block = main_program.global_block() :api_attr: Static Graph
remove_op_idx = []
for i, op in enumerate(global_block.ops):
op.desc.set_is_target(False)
if op.type == "feed" or op.type == "fetch":
remove_op_idx.append(i)
for idx in remove_op_idx[::-1]:
global_block._remove_op(idx)
main_program.desc.flush()
feed_var_names = [var.name for var in feed_vars] Deserialize given data to a program.
main_program = main_program._prune_with_input(
feeded_var_names=feed_var_names, targets=fetch_vars)
main_program = main_program._inference_optimize(prune_read_op=True)
fetch_var_names = [var.name for var in fetch_vars]
prepend_feed_ops(main_program, feed_var_names)
append_fetch_ops(main_program, fetch_var_names)
main_program.desc._set_version()
paddle.fluid.core.save_op_version_info(main_program.desc)
with open(model_path, "wb") as f:
f.write(main_program.desc.serialize_to_string())
main_program._copy_dist_param_info_from(origin_program)
# save params Args:
dirname = os.path.dirname(params_path) data(bytes): serialized program.
basename = os.path.basename(params_path) Returns:
save_persistables(executor, dirname, main_program, basename) Program: deserialized program.
"""
program = Program.parse_from_string(data)
if not core._is_program_version_supported(program._version()):
raise ValueError("Unsupported program version: %d\n" %
program._version())
return program
@static_only
def deserialize_persistables(program, data, executor):
"""
:api_attr: Static Graph
Deserialize given data to parameters according to given program and executor.
Args:
program(Program): program that contains parameter names (to deserialize).
data(bytes): serialized parameters.
executor(Executor): executor used to run load op.
Returns:
Program: deserialized program.
"""
if not isinstance(program, Program):
raise TypeError(
"program type must be `fluid.Program`, but received `%s`" %
type(program))
# load params to a tmp program
load_program = Program()
load_block = load_program.global_block()
vars_ = list(filter(is_persistable, program.list_vars()))
origin_shape_map = {}
load_var_map = {}
check_vars = []
sparse_vars = []
for var in vars_:
assert isinstance(var, Variable)
if var.type == core.VarDesc.VarType.RAW:
continue
if isinstance(var, Parameter):
origin_shape_map[var.name] = tuple(var.desc.get_shape())
if var.type == core.VarDesc.VarType.SELECTED_ROWS:
sparse_vars.append(var)
continue
var_copy = _clone_var_in_block(load_block, var)
check_vars.append(var)
load_var_map[var_copy.name] = var_copy
# append load_combine op to load parameters,
load_var_list = []
for name in sorted(load_var_map.keys()):
load_var_list.append(load_var_map[name])
load_block.append_op(
type='load_combine',
inputs={},
outputs={"Out": load_var_list},
# if load from memory, file_path is data
attrs={'file_path': data,
'model_from_memory': True})
executor.run(load_program)
# check var shape
for var in check_vars:
if not isinstance(var, Parameter):
continue
var_tmp = paddle.fluid.global_scope().find_var(var.name)
assert var_tmp != None, "can't not find var: " + var.name
new_shape = (np.array(var_tmp.get_tensor())).shape
assert var.name in origin_shape_map, var.name + " MUST in var list."
origin_shape = origin_shape_map.get(var.name)
if new_shape != origin_shape:
raise RuntimeError(
"Shape mismatch, program needs a parameter with shape ({}), "
"but the loaded parameter ('{}') has a shape of ({}).".format(
origin_shape, var.name, new_shape))
def load_from_file(path):
"""
Load file in binary mode.
Args:
path(str): Path of an existed file.
Returns:
bytes: Content of file.
"""
with open(path, 'rb') as f:
data = f.read()
return data
@static_only @static_only
...@@ -277,18 +645,13 @@ def load_inference_model(path_prefix, executor, **configs): ...@@ -277,18 +645,13 @@ def load_inference_model(path_prefix, executor, **configs):
if params_filename is None: if params_filename is None:
raise ValueError( raise ValueError(
"params_filename cannot be None when path_prefix is None.") "params_filename cannot be None when path_prefix is None.")
load_dirname = path_prefix load_dirname = ''
program_desc_str = model_filename program_bytes = model_filename
params_filename = params_filename params_filename = params_filename
# load from file # load from file
else: else:
# check and norm path_prefix # check and norm path_prefix
if not isinstance(path_prefix, six.string_types): path_prefix = _normalize_path_prefix(path_prefix)
raise ValueError("'path_prefix' should be a string.")
if path_prefix.endswith("/"):
raise ValueError("'path_prefix' should not be a directory")
path_prefix = os.path.normpath(path_prefix)
path_prefix = os.path.abspath(path_prefix)
# set model_path and params_path in new way, # set model_path and params_path in new way,
# path_prefix represents a file path without suffix in this case. # path_prefix represents a file path without suffix in this case.
...@@ -319,17 +682,17 @@ def load_inference_model(path_prefix, executor, **configs): ...@@ -319,17 +682,17 @@ def load_inference_model(path_prefix, executor, **configs):
_logger.warning("The old way to load inference model is deprecated." _logger.warning("The old way to load inference model is deprecated."
" model path: {}, params path: {}".format( " model path: {}, params path: {}".format(
model_path, params_path)) model_path, params_path))
with open(model_path, "rb") as f: program_bytes = load_from_file(model_path)
program_desc_str = f.read()
load_dirname = os.path.dirname(params_path) load_dirname = os.path.dirname(params_path)
params_filename = os.path.basename(params_path) params_filename = os.path.basename(params_path)
program = Program.parse_from_string(program_desc_str) # deserialize bytes to program
if not core._is_program_version_supported(program._version()): program = deserialize_program(program_bytes)
raise ValueError("Unsupported program version: %d\n" % # load params data
program._version()) params_path = os.path.join(load_dirname, params_filename)
# Binary data also need versioning. params_bytes = load_from_file(params_path)
load_persistables(executor, load_dirname, program, params_filename) # deserialize bytes to params
deserialize_persistables(program, params_bytes, executor)
feed_target_names = program.desc.get_feed_target_names() feed_target_names = program.desc.get_feed_target_names()
fetch_target_names = program.desc.get_fetch_target_names() fetch_target_names = program.desc.get_fetch_target_names()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册