From db412585013cde98de1968f102625da46528254d Mon Sep 17 00:00:00 2001
From: Shibo Tao <62922815+T8T9@users.noreply.github.com>
Date: Thu, 26 Nov 2020 14:08:10 +0800
Subject: [PATCH] add API serialize_program, serialize_persistables,
 save_to_file, deserialize_program, deserialize_persistables, load_from_file.
 (#29034)

---
 paddle/fluid/framework/lod_tensor.cc           |   6 +-
 .../unittests/test_inference_model_io.py       | 109 +++-
 python/paddle/static/io.py                     | 547 +++++++++++++++---
 3 files changed, 546 insertions(+), 116 deletions(-)

diff --git a/paddle/fluid/framework/lod_tensor.cc b/paddle/fluid/framework/lod_tensor.cc
index a044812dd3..a82be2acb3 100644
--- a/paddle/fluid/framework/lod_tensor.cc
+++ b/paddle/fluid/framework/lod_tensor.cc
@@ -281,7 +281,8 @@ void DeserializeFromStream(std::istream &is, LoDTensor *tensor,
     PADDLE_ENFORCE_EQ(
         version, 0U,
         platform::errors::InvalidArgument(
-            "Tensor version %u is not supported, only version 0 is supported.",
+            "Deserialize to tensor failed, maybe the loaded file is "
+            "not a paddle model(expected file format: 0, but %u found).",
             version));
   }
   {
@@ -307,7 +308,8 @@ void DeserializeFromStream(std::istream &is, LoDTensor *tensor,
     PADDLE_ENFORCE_EQ(
         version, 0U,
         platform::errors::InvalidArgument(
-            "Tensor version %u is not supported, only version 0 is supported.",
+            "Deserialize to tensor failed, maybe the loaded file is "
+            "not a paddle model(expected file format: 0, but %u found).",
             version));
   }
   {
diff --git a/python/paddle/fluid/tests/unittests/test_inference_model_io.py b/python/paddle/fluid/tests/unittests/test_inference_model_io.py
index a82bc3f0f6..9a5d0b3e9b 100644
--- a/python/paddle/fluid/tests/unittests/test_inference_model_io.py
+++ b/python/paddle/fluid/tests/unittests/test_inference_model_io.py
@@ -226,32 +226,33 @@ class TestSaveInferenceModelNew(unittest.TestCase):
                         'y': tensor_y},
                     fetch_list=[avg_cost])
 
+        self.assertRaises(ValueError, paddle.static.save_inference_model, None,
+                          ['x', 'y'], [avg_cost], exe)
         self.assertRaises(ValueError, paddle.static.save_inference_model,
-                          None, ['x', 'y'], [avg_cost], exe)
+                          MODEL_DIR + "/", [x, y], [avg_cost], exe)
         self.assertRaises(ValueError, paddle.static.save_inference_model,
-                          MODEL_DIR + "/", [x, y], [avg_cost], exe)
+                          MODEL_DIR, ['x', 'y'], [avg_cost], exe)
         self.assertRaises(ValueError, paddle.static.save_inference_model,
-                          MODEL_DIR, ['x', 'y'], [avg_cost], exe)
+                          MODEL_DIR, 'x', [avg_cost], exe)
         self.assertRaises(ValueError, paddle.static.save_inference_model,
-                          MODEL_DIR, 'x', [avg_cost], exe)
+                          MODEL_DIR, [x, y], ['avg_cost'], exe)
         self.assertRaises(ValueError, paddle.static.save_inference_model,
-                          MODEL_DIR, [x, y], ['avg_cost'], exe)
-        self.assertRaises(ValueError, paddle.static.save_inference_model,
-                          MODEL_DIR, [x, y], 'avg_cost', exe)
+                          MODEL_DIR, [x, y], 'avg_cost', exe)
 
         model_path = MODEL_DIR + "_isdir.pdmodel"
         os.makedirs(model_path)
         self.assertRaises(ValueError, paddle.static.save_inference_model,
-                          MODEL_DIR + "_isdir", [x, y], [avg_cost], exe)
+                          MODEL_DIR + "_isdir", [x, y], [avg_cost], exe)
         os.rmdir(model_path)
 
         params_path = MODEL_DIR + "_isdir.pdmodel"
         os.makedirs(params_path)
         self.assertRaises(ValueError, paddle.static.save_inference_model,
-                          MODEL_DIR + "_isdir", [x, y], [avg_cost], exe)
+                          MODEL_DIR + "_isdir", [x, y], [avg_cost], exe)
         os.rmdir(params_path)
 
-        paddle.static.io.save_inference_model(MODEL_DIR, [x, y], [avg_cost], exe)
+        paddle.static.io.save_inference_model(MODEL_DIR, [x, y], [avg_cost],
+                                              exe)
 
         self.assertTrue(os.path.exists(MODEL_DIR + ".pdmodel"))
         self.assertTrue(os.path.exists(MODEL_DIR + ".pdiparams"))
@@ -263,20 +264,34 @@ class TestSaveInferenceModelNew(unittest.TestCase):
 
         six.moves.reload_module(executor)  # reload to build a new scope
 
+        self.assertRaises(ValueError, paddle.static.load_inference_model, None,
+                          exe)
         self.assertRaises(ValueError, paddle.static.load_inference_model,
-                          None, exe)
-        self.assertRaises(ValueError, paddle.static.load_inference_model,
-                          MODEL_DIR + "/", exe)
-        self.assertRaises(ValueError, paddle.static.load_inference_model,
-                          [MODEL_DIR], exe)
-        self.assertRaises(ValueError, paddle.static.load_inference_model,
-                          MODEL_DIR, exe, pserver_endpoints=None)
+                          MODEL_DIR + "/", exe)
         self.assertRaises(ValueError, paddle.static.load_inference_model,
-                          MODEL_DIR, exe, unsupported_param=None)
+                          [MODEL_DIR], exe)
+        self.assertRaises(
+            ValueError,
+            paddle.static.load_inference_model,
+            MODEL_DIR,
+            exe,
+            pserver_endpoints=None)
+        self.assertRaises(
+            ValueError,
+            paddle.static.load_inference_model,
+            MODEL_DIR,
+            exe,
+            unsupported_param=None)
+        self.assertRaises(
+            (TypeError, ValueError),
+            paddle.static.load_inference_model,
+            None,
+            exe,
+            model_filename="illegal",
+            params_filename="illegal")
+
+        model = InferModel(
+            paddle.static.io.load_inference_model(MODEL_DIR, exe))
 
         outs = exe.run(model.program,
                        feed={
@@ -289,7 +304,57 @@ class TestSaveInferenceModelNew(unittest.TestCase):
         self.assertEqual(model.feed_var_names, ["x", "y"])
         self.assertEqual(len(model.fetch_vars), 1)
         self.assertEqual(expected, actual)
+        # test that save_to_file rejects content that is not bytes
+        self.assertRaises(ValueError, paddle.static.io.save_to_file, '', 123)
+        # test _get_valid_program
+        self.assertRaises(TypeError, paddle.static.io._get_valid_program, 0)
+        p = Program()
+        cp = CompiledProgram(p)
+        paddle.static.io._get_valid_program(cp)
+        self.assertTrue(paddle.static.io._get_valid_program(cp) is p)
+        cp._program = None
+        self.assertRaises(TypeError, paddle.static.io._get_valid_program, cp)
+
+    def test_serialize_program_and_persistables(self):
+        init_program = fluid.default_startup_program()
+        program = fluid.default_main_program()
+
+        # fake program without feed/fetch
+        with program_guard(program, init_program):
+            x = layers.data(name='x', shape=[2], dtype='float32')
+            y = layers.data(name='y', shape=[1], dtype='float32')
+
+            y_predict = layers.fc(input=x, size=1, act=None)
+
+            cost = layers.square_error_cost(input=y_predict, label=y)
+            avg_cost = layers.mean(cost)
+
+            sgd_optimizer = optimizer.SGDOptimizer(learning_rate=0.001)
+            sgd_optimizer.minimize(avg_cost, init_program)
+
+        place = core.CPUPlace()
+        exe = executor.Executor(place)
+        exe.run(init_program, feed={}, fetch_list=[])
+
+        tensor_x = np.array([[1, 1], [1, 2], [5, 2]]).astype("float32")
+        tensor_y = np.array([[-2], [-3], [-7]]).astype("float32")
+        for i in six.moves.xrange(3):
+            exe.run(program,
+                    feed={'x': tensor_x,
+                          'y': tensor_y},
+                    fetch_list=[avg_cost])
+
+        # test that the return type of serialize_program is bytes
+        res1 = paddle.static.io.serialize_program([x, y], [avg_cost])
+        self.assertTrue(isinstance(res1, bytes))
+        # test that the return type of serialize_persistables is bytes
+        res2 = paddle.static.io.serialize_persistables([x, y], [avg_cost], exe)
+        self.assertTrue(isinstance(res2, bytes))
+        # test the case where the program contains no variables
+        res = paddle.static.io._serialize_persistables(Program(), None)
+        self.assertEqual(res, None)
+        self.assertRaises(TypeError, paddle.static.io.deserialize_persistables,
+                          None, None, None)
 
 
 class TestLoadInferenceModelError(unittest.TestCase):
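The tests above exercise the serialize/deserialize APIs introduced in python/paddle/static/io.py below. For orientation, a minimal sketch of the full round trip they cover, assembled from this patch's own docstring examples (the file paths are illustrative placeholders):

    import paddle

    paddle.enable_static()

    image = paddle.static.data(name='img', shape=[None, 28, 28], dtype='float32')
    predict = paddle.static.nn.fc(image, 10, activation='softmax')

    exe = paddle.static.Executor(paddle.CPUPlace())
    exe.run(paddle.static.default_startup_program())

    # program topology and parameters are serialized to bytes separately
    program_bytes = paddle.static.serialize_program([image], [predict])
    params_bytes = paddle.static.serialize_persistables([image], [predict], exe)

    # persist both blobs ...
    paddle.static.io.save_to_file("./infer_model.pdmodel", program_bytes)
    paddle.static.io.save_to_file("./infer_model.pdiparams", params_bytes)

    # ... and restore them
    program = paddle.static.deserialize_program(
        paddle.static.io.load_from_file("./infer_model.pdmodel"))
    paddle.static.deserialize_persistables(
        program, paddle.static.io.load_from_file("./infer_model.pdiparams"), exe)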
+ """ + if program is None: + program = default_main_program() + elif isinstance(program, CompiledProgram): + program = program._program + if program is None: + raise TypeError( + "The type of input program is invalid, expected tyep is Program, but received None" + ) + warnings.warn( + "The input is a CompiledProgram, this is not recommended.") + if not isinstance(program, Program): + raise TypeError( + "The type of input program is invalid, expected type is fluid.Program, but received %s" + % type(program)) + return program + + +def _clone_var_in_block(block, var): + assert isinstance(var, Variable) + if var.desc.type() == core.VarDesc.VarType.LOD_TENSOR: + return block.create_var( + name=var.name, + shape=var.shape, + dtype=var.dtype, + type=var.type, + lod_level=var.lod_level, + persistable=True) + else: + return block.create_var( + name=var.name, + shape=var.shape, + dtype=var.dtype, + type=var.type, + persistable=True) + + +def _normalize_program(program, feed_vars, fetch_vars): + """ + optimize program according feed_vars and fetch_vars. + """ + # remind users to set auc_states to 0 if auc op were found. + for op in program.global_block().ops: + # clear device of Op + device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName() + op._set_attr(device_attr_name, "") + if op.type == 'auc': + warnings.warn("Be sure that you have set auc states to 0 " + "before saving inference model.") + break + + # fix the bug that the activation op's output as target will be pruned. + # will affect the inference performance. + # TODO(Superjomn) add an IR pass to remove 1-scale op. + with program_guard(program): + uniq_fetch_vars = [] + for i, var in enumerate(fetch_vars): + var = layers.scale( + var, 1., name="save_infer_model/scale_{}".format(i)) + uniq_fetch_vars.append(var) + fetch_vars = uniq_fetch_vars + + # serialize program + copy_program = program.clone() + global_block = copy_program.global_block() + remove_op_idx = [] + for i, op in enumerate(global_block.ops): + op.desc.set_is_target(False) + if op.type == "feed" or op.type == "fetch": + remove_op_idx.append(i) + for idx in remove_op_idx[::-1]: + global_block._remove_op(idx) + copy_program.desc.flush() + + feed_var_names = [var.name for var in feed_vars] + copy_program = copy_program._prune_with_input( + feeded_var_names=feed_var_names, targets=fetch_vars) + copy_program = copy_program._inference_optimize(prune_read_op=True) + fetch_var_names = [var.name for var in fetch_vars] + prepend_feed_ops(copy_program, feed_var_names) + append_fetch_ops(copy_program, fetch_var_names) + copy_program.desc._set_version() + return copy_program + + +def is_persistable(var): + """ + Check whether the given variable is persistable. + + Args: + var(Variable): The variable to be checked. + + Returns: + bool: True if the given `var` is persistable + False if not. + + Examples: + .. code-block:: python + + import paddle + import paddle.fluid as fluid + + paddle.enable_static() + param = fluid.default_main_program().global_block().var('fc.b') + res = fluid.io.is_persistable(param) + """ + if var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \ + var.desc.type() == core.VarDesc.VarType.FETCH_LIST or \ + var.desc.type() == core.VarDesc.VarType.READER: + return False + return var.persistable + + +@static_only +def serialize_program(feed_vars, fetch_vars): + """ + :api_attr: Static Graph + + Serialize default main program according to feed_vars and fetch_vars. + + Args: + feed_vars(Variable | list[Variable]): Variables needed by inference. 
+        fetch_vars(Variable | list[Variable]): Variables returned by inference.
+    Returns:
+        bytes: serialized program.
+
+    Raises:
+        ValueError: If `feed_vars` is not a Variable or a list of Variable, an exception is thrown.
+        ValueError: If `fetch_vars` is not a Variable or a list of Variable, an exception is thrown.
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+
+            paddle.enable_static()
+
+            path_prefix = "./infer_model"
+
+            # User-defined network, here a softmax regression example
+            image = paddle.static.data(name='img', shape=[None, 28, 28], dtype='float32')
+            label = paddle.static.data(name='label', shape=[None, 1], dtype='int64')
+            predict = paddle.static.nn.fc(image, 10, activation='softmax')
+
+            loss = paddle.nn.functional.cross_entropy(predict, label)
+            avg_loss = paddle.tensor.stat.mean(loss)
+
+            exe = paddle.static.Executor(paddle.CPUPlace())
+            exe.run(paddle.static.default_startup_program())
+
+            # serialize the default main program to bytes.
+            serialized_program = paddle.static.serialize_program([image], [predict])
+
+            # deserialize bytes to program
+            deserialized_program = paddle.static.deserialize_program(serialized_program)
+
+    """
+    # verify feed_vars
+    _check_vars('feed_vars', feed_vars)
+    # verify fetch_vars
+    _check_vars('fetch_vars', fetch_vars)
+
+    program = _get_valid_program()
+    program = _normalize_program(program, feed_vars, fetch_vars)
+    return _serialize_program(program)
+
+
+def _serialize_program(program):
+    """
+    serialize the given program to bytes.
+    """
+    return program.desc.serialize_to_string()
+
+
+@static_only
+def serialize_persistables(feed_vars, fetch_vars, executor):
+    """
+    :api_attr: Static Graph
+
+    Serialize parameters using the given executor and the default main program according to feed_vars and fetch_vars.
+
+    Args:
+        feed_vars(Variable | list[Variable]): Variables needed by inference.
+        fetch_vars(Variable | list[Variable]): Variables returned by inference.
+        executor(Executor): The executor that runs the save op.
+    Returns:
+        bytes: serialized parameters.
+
+    Raises:
+        ValueError: If `feed_vars` is not a Variable or a list of Variable, an exception is thrown.
+        ValueError: If `fetch_vars` is not a Variable or a list of Variable, an exception is thrown.
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+
+            paddle.enable_static()
+
+            path_prefix = "./infer_model"
+
+            # User-defined network, here a softmax regression example
+            image = paddle.static.data(name='img', shape=[None, 28, 28], dtype='float32')
+            label = paddle.static.data(name='label', shape=[None, 1], dtype='int64')
+            predict = paddle.static.nn.fc(image, 10, activation='softmax')
+
+            loss = paddle.nn.functional.cross_entropy(predict, label)
+            avg_loss = paddle.tensor.stat.mean(loss)
+
+            exe = paddle.static.Executor(paddle.CPUPlace())
+            exe.run(paddle.static.default_startup_program())
+
+            # serialize parameters to bytes.
+            serialized_params = paddle.static.serialize_persistables([image], [predict], exe)
+
+            # deserialize bytes to parameters.
+            main_program = paddle.static.default_main_program()
+            deserialized_params = paddle.static.deserialize_persistables(main_program, serialized_params, exe)
+
+    """
+    # verify feed_vars
+    _check_vars('feed_vars', feed_vars)
+    # verify fetch_vars
+    _check_vars('fetch_vars', fetch_vars)
+
+    program = _get_valid_program()
+    program = _normalize_program(program, feed_vars, fetch_vars)
+    return _serialize_persistables(program, executor)
+
+
+def _serialize_persistables(program, executor):
+    """
+    Serialize parameters using the given program and executor.
+ """ + vars_ = list(filter(is_persistable, program.list_vars())) + # warn if no variable found in model + if len(vars_) == 0: + warnings.warn("no variable in your model, please ensure there are any " + "variables in your model to save") + return None + # create a new program and clone persitable vars to it + save_program = Program() + save_block = save_program.global_block() + save_var_map = {} + for var in vars_: + if var.type != core.VarDesc.VarType.RAW: + var_copy = _clone_var_in_block(save_block, var) + save_var_map[var_copy.name] = var + + # create in_vars and out_var, then append a save_combine op to save_program + in_vars = [] + for name in sorted(save_var_map.keys()): + in_vars.append(save_var_map[name]) + + out_var_name = unique_name.generate("out_var") + out_var = save_block.create_var( + type=core.VarDesc.VarType.RAW, name=out_var_name) + out_var.desc.set_persistable(True) + save_block.append_op( + type='save_combine', + inputs={'X': in_vars}, + outputs={'Y': out_var}, + attrs={'file_path': '', + 'save_to_memory': True}) + # run save_program to save vars + # NOTE(zhiqiu): save op will add variable kLookupTablePath to save_program.desc, + # which leads to diff between save_program and its desc. Call _sync_with_cpp + # to keep consistency. + save_program._sync_with_cpp() + executor.run(save_program) + # return serialized bytes in out_var + return global_scope().find_var(out_var_name).get_bytes() + + +def save_to_file(path, content): + """ + Save content to given path. + Args: + path(str): Path to write content to. + content(bytes): Content to write. + Returns: + None + """ + + if not isinstance(content, bytes): + raise ValueError("'content' type should be bytes.") + with open(path, "wb") as f: + f.write(content) + + @static_only def save_inference_model(path_prefix, feed_vars, fetch_vars, executor): """ @@ -106,13 +434,9 @@ def save_inference_model(path_prefix, feed_vars, fetch_vars, executor): # and parameters are going to be saved in file "./infer_model.pdiparams". """ + # check path_prefix, set model_path and params_path - if not isinstance(path_prefix, six.string_types): - raise ValueError("'path_prefix' should be a string.") - if path_prefix.endswith("/"): - raise ValueError("'path_prefix' should not be a directory") - path_prefix = os.path.normpath(path_prefix) - path_prefix = os.path.abspath(path_prefix) + path_prefix = _normalize_path_prefix(path_prefix) try: # mkdir may conflict if pserver and trainer are running on the same machine dirname = os.path.dirname(path_prefix) @@ -128,74 +452,118 @@ def save_inference_model(path_prefix, feed_vars, fetch_vars, executor): raise ValueError("'{}' is an existing directory.".format(params_path)) # verify feed_vars - if not isinstance(feed_vars, list): - feed_vars = [feed_vars] - if not feed_vars or not all( - [isinstance(var, Variable) for var in feed_vars]): - raise ValueError( - "'feed_vars' should be a Variable or a list of Variable.") - + _check_vars('feed_vars', feed_vars) # verify fetch_vars - if not isinstance(fetch_vars, list): - fetch_vars = [fetch_vars] - if not fetch_vars or not all( - [isinstance(var, Variable) for var in fetch_vars]): - raise ValueError( - "'fetch_vars' should be a Variable or a list of Variable.") + _check_vars('fetch_vars', fetch_vars) - main_program = _get_valid_program() - # remind users to set auc_states to 0 if auc op were found. 
@@ -106,13 +434,9 @@ def save_inference_model(path_prefix, feed_vars, fetch_vars, executor):
             # and parameters are going to be saved in file "./infer_model.pdiparams".
 
     """
+
     # check path_prefix, set model_path and params_path
-    if not isinstance(path_prefix, six.string_types):
-        raise ValueError("'path_prefix' should be a string.")
-    if path_prefix.endswith("/"):
-        raise ValueError("'path_prefix' should not be a directory")
-    path_prefix = os.path.normpath(path_prefix)
-    path_prefix = os.path.abspath(path_prefix)
+    path_prefix = _normalize_path_prefix(path_prefix)
     try:
         # mkdir may conflict if pserver and trainer are running on the same machine
         dirname = os.path.dirname(path_prefix)
@@ -128,74 +452,118 @@ def save_inference_model(path_prefix, feed_vars, fetch_vars, executor):
         raise ValueError("'{}' is an existing directory.".format(params_path))
 
     # verify feed_vars
-    if not isinstance(feed_vars, list):
-        feed_vars = [feed_vars]
-    if not feed_vars or not all(
-            [isinstance(var, Variable) for var in feed_vars]):
-        raise ValueError(
-            "'feed_vars' should be a Variable or a list of Variable.")
-
+    _check_vars('feed_vars', feed_vars)
     # verify fetch_vars
-    if not isinstance(fetch_vars, list):
-        fetch_vars = [fetch_vars]
-    if not fetch_vars or not all(
-            [isinstance(var, Variable) for var in fetch_vars]):
-        raise ValueError(
-            "'fetch_vars' should be a Variable or a list of Variable.")
+    _check_vars('fetch_vars', fetch_vars)
 
-    main_program = _get_valid_program()
-    # remind users to set auc_states to 0 if auc op were found.
-    for op in main_program.global_block().ops:
-        # clear device of Op
-        device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName()
-        op._set_attr(device_attr_name, "")
-        if op.type == 'auc':
-            warnings.warn(
-                "Be sure that you have set auc states to 0 before saving inference model."
-            )
-            break
+    program = _get_valid_program()
+    program = _normalize_program(program, feed_vars, fetch_vars)
+    # serialize and save program
+    program_bytes = _serialize_program(program)
+    save_to_file(model_path, program_bytes)
+    # serialize and save params
+    params_bytes = _serialize_persistables(program, executor)
+    save_to_file(params_path, params_bytes)
 
-    # fix the bug that the activation op's output as target will be pruned.
-    # will affect the inference performance.
-    # TODO(Superjomn) add an IR pass to remove 1-scale op.
-    with program_guard(main_program):
-        uniq_fetch_vars = []
-        for i, var in enumerate(fetch_vars):
-            var = layers.scale(
-                var, 1., name="save_infer_model/scale_{}".format(i))
-            uniq_fetch_vars.append(var)
-        fetch_vars = uniq_fetch_vars
-
-    # save model
-    origin_program = main_program.clone()
-    main_program = main_program.clone()
-    global_block = main_program.global_block()
-    remove_op_idx = []
-    for i, op in enumerate(global_block.ops):
-        op.desc.set_is_target(False)
-        if op.type == "feed" or op.type == "fetch":
-            remove_op_idx.append(i)
-    for idx in remove_op_idx[::-1]:
-        global_block._remove_op(idx)
-    main_program.desc.flush()
+@static_only
+def deserialize_program(data):
+    """
+    :api_attr: Static Graph
 
-    feed_var_names = [var.name for var in feed_vars]
-    main_program = main_program._prune_with_input(
-        feeded_var_names=feed_var_names, targets=fetch_vars)
-    main_program = main_program._inference_optimize(prune_read_op=True)
-    fetch_var_names = [var.name for var in fetch_vars]
-    prepend_feed_ops(main_program, feed_var_names)
-    append_fetch_ops(main_program, fetch_var_names)
-    main_program.desc._set_version()
-    paddle.fluid.core.save_op_version_info(main_program.desc)
-    with open(model_path, "wb") as f:
-        f.write(main_program.desc.serialize_to_string())
-    main_program._copy_dist_param_info_from(origin_program)
+    Deserialize the given data to a program.
+
+    Args:
+        data(bytes): serialized program.
+    Returns:
+        Program: deserialized program.
+    """
+    program = Program.parse_from_string(data)
+    if not core._is_program_version_supported(program._version()):
+        raise ValueError("Unsupported program version: %d\n" %
+                         program._version())
+    return program
+
+
+@static_only
+def deserialize_persistables(program, data, executor):
+    """
+    :api_attr: Static Graph
+
+    Deserialize the given data to parameters according to the given program and executor.
 
-    # save params
-    dirname = os.path.dirname(params_path)
-    basename = os.path.basename(params_path)
-    save_persistables(executor, dirname, main_program, basename)
+    Args:
+        program(Program): program that contains parameter names (to deserialize).
+        data(bytes): serialized parameters.
+        executor(Executor): executor used to run the load op.
+    Returns:
+        None: the deserialized parameters are written into the executor's scope.
+ """ + if not isinstance(program, Program): + raise TypeError( + "program type must be `fluid.Program`, but received `%s`" % + type(program)) + # load params to a tmp program + load_program = Program() + load_block = load_program.global_block() + vars_ = list(filter(is_persistable, program.list_vars())) + + origin_shape_map = {} + load_var_map = {} + check_vars = [] + sparse_vars = [] + for var in vars_: + assert isinstance(var, Variable) + if var.type == core.VarDesc.VarType.RAW: + continue + if isinstance(var, Parameter): + origin_shape_map[var.name] = tuple(var.desc.get_shape()) + if var.type == core.VarDesc.VarType.SELECTED_ROWS: + sparse_vars.append(var) + continue + var_copy = _clone_var_in_block(load_block, var) + check_vars.append(var) + load_var_map[var_copy.name] = var_copy + + # append load_combine op to load parameters, + load_var_list = [] + for name in sorted(load_var_map.keys()): + load_var_list.append(load_var_map[name]) + load_block.append_op( + type='load_combine', + inputs={}, + outputs={"Out": load_var_list}, + # if load from memory, file_path is data + attrs={'file_path': data, + 'model_from_memory': True}) + executor.run(load_program) + # check var shape + for var in check_vars: + if not isinstance(var, Parameter): + continue + var_tmp = paddle.fluid.global_scope().find_var(var.name) + assert var_tmp != None, "can't not find var: " + var.name + new_shape = (np.array(var_tmp.get_tensor())).shape + assert var.name in origin_shape_map, var.name + " MUST in var list." + origin_shape = origin_shape_map.get(var.name) + if new_shape != origin_shape: + raise RuntimeError( + "Shape mismatch, program needs a parameter with shape ({}), " + "but the loaded parameter ('{}') has a shape of ({}).".format( + origin_shape, var.name, new_shape)) + + +def load_from_file(path): + """ + Load file in binary mode. + Args: + path(str): Path of an existed file. + Returns: + bytes: Content of file. + """ + with open(path, 'rb') as f: + data = f.read() + return data @static_only @@ -277,18 +645,13 @@ def load_inference_model(path_prefix, executor, **configs): if params_filename is None: raise ValueError( "params_filename cannot be None when path_prefix is None.") - load_dirname = path_prefix - program_desc_str = model_filename + load_dirname = '' + program_bytes = model_filename params_filename = params_filename # load from file else: # check and norm path_prefix - if not isinstance(path_prefix, six.string_types): - raise ValueError("'path_prefix' should be a string.") - if path_prefix.endswith("/"): - raise ValueError("'path_prefix' should not be a directory") - path_prefix = os.path.normpath(path_prefix) - path_prefix = os.path.abspath(path_prefix) + path_prefix = _normalize_path_prefix(path_prefix) # set model_path and params_path in new way, # path_prefix represents a file path without suffix in this case. @@ -319,17 +682,17 @@ def load_inference_model(path_prefix, executor, **configs): _logger.warning("The old way to load inference model is deprecated." " model path: {}, params path: {}".format( model_path, params_path)) - with open(model_path, "rb") as f: - program_desc_str = f.read() + program_bytes = load_from_file(model_path) load_dirname = os.path.dirname(params_path) params_filename = os.path.basename(params_path) - program = Program.parse_from_string(program_desc_str) - if not core._is_program_version_supported(program._version()): - raise ValueError("Unsupported program version: %d\n" % - program._version()) - # Binary data also need versioning. 
-    load_persistables(executor, load_dirname, program, params_filename)
+    # deserialize bytes to program
+    program = deserialize_program(program_bytes)
+    # load params data
+    params_path = os.path.join(load_dirname, params_filename)
+    params_bytes = load_from_file(params_path)
+    # deserialize bytes to params
+    deserialize_persistables(program, params_bytes, executor)
 
     feed_target_names = program.desc.get_feed_target_names()
     fetch_target_names = program.desc.get_fetch_target_names()
-- 
GitLab
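Taken together, the refactored entry points keep their one-call usage. A brief sketch of the file-based flow after this patch, mirroring the tests above (illustrative only; path_prefix gains the .pdmodel/.pdiparams suffixes automatically, and the three-element return of load_inference_model follows the usage in TestSaveInferenceModelNew):

    import numpy as np
    import paddle

    paddle.enable_static()

    image = paddle.static.data(name='img', shape=[None, 28, 28], dtype='float32')
    predict = paddle.static.nn.fc(image, 10, activation='softmax')
    exe = paddle.static.Executor(paddle.CPUPlace())
    exe.run(paddle.static.default_startup_program())

    # writes ./infer_model.pdmodel and ./infer_model.pdiparams
    paddle.static.save_inference_model("./infer_model", [image], [predict], exe)

    # returns [inference_program, feed_target_names, fetch_targets]
    program, feed_names, fetch_vars = paddle.static.load_inference_model(
        "./infer_model", exe)
    outs = exe.run(program,
                   feed={feed_names[0]: np.random.rand(1, 28, 28).astype('float32')},
                   fetch_list=fetch_vars)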