From db2e6cee620113590ef6f29eb65c495b7bab2d19 Mon Sep 17 00:00:00 2001 From: Shibo Tao <62922815+T8T9@users.noreply.github.com> Date: Wed, 18 Nov 2020 14:05:10 +0800 Subject: [PATCH] add two paddle-2.0 apis: paddle.static.io.save_inference_model and paddle.static.io.load_inference_model (#28606) * add two apis: paddle.static.io.save_inference_model and paddle.static.io.load_inference_mode, which are campatible with paddle.fluid.io.save_inference_model and paddle.fluid.io.load_inference_model respectively. * add unittest for new save_inference_model and load_inference_model. test=develop * enhance doc. test=develop * add paddle.enable_static() to test_inference_model_io.py. test=develop --- python/paddle/fluid/io.py | 61 +++- .../tests/unittests/rnn/test_rnn_nets.py | 8 +- .../unittests/test_inference_model_io.py | 116 +++++- python/paddle/static/__init__.py | 4 +- python/paddle/static/io.py | 335 ++++++++++++++++++ 5 files changed, 507 insertions(+), 17 deletions(-) create mode 100644 python/paddle/static/io.py diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py index bb55aeb70d..29a6dcb135 100644 --- a/python/paddle/fluid/io.py +++ b/python/paddle/fluid/io.py @@ -43,6 +43,8 @@ from . import dataloader from .dataloader import * from . import core from .. import compat as cpt +from paddle.utils import deprecated +from paddle.fluid.framework import static_only batch = paddle.batch @@ -82,7 +84,10 @@ def is_parameter(var): Examples: .. code-block:: python + import paddle import paddle.fluid as fluid + + paddle.enable_static() param = fluid.default_main_program().global_block().var('fc.w') res = fluid.io.is_parameter(param) """ @@ -103,7 +108,10 @@ def is_persistable(var): Examples: .. code-block:: python + import paddle import paddle.fluid as fluid + + paddle.enable_static() param = fluid.default_main_program().global_block().var('fc.b') res = fluid.io.is_persistable(param) """ @@ -137,7 +145,10 @@ def get_program_parameter(program): Examples: .. code-block:: python + import paddle import paddle.fluid as fluid + + paddle.enable_static() data = fluid.data(name="img", shape=[64, 784]) w = fluid.layers.create_parameter(shape=[784, 200], dtype='float32', name='fc_w') b = fluid.layers.create_parameter(shape=[200], dtype='float32', name='fc_b') @@ -162,7 +173,10 @@ def get_program_persistable_vars(program): Examples: .. code-block:: python + import paddle import paddle.fluid as fluid + + paddle.enable_static() data = fluid.data(name="img", shape=[64, 784]) w = fluid.layers.create_parameter(shape=[784, 200], dtype='float32', name='fc_w') b = fluid.layers.create_parameter(shape=[200], dtype='float32', name='fc_b') @@ -202,7 +216,7 @@ def _load_program_scope(main=None, startup=None, scope=None): yield -def _get_valid_program(main_program): +def _get_valid_program(main_program=None): if main_program is None: main_program = default_main_program() elif isinstance(main_program, CompiledProgram): @@ -268,8 +282,10 @@ def save_vars(executor, Examples: .. code-block:: python + import paddle import paddle.fluid as fluid + paddle.enable_static() main_prog = fluid.Program() startup_prog = fluid.Program() with fluid.program_guard(main_prog, startup_prog): @@ -417,8 +433,11 @@ def save_params(executor, dirname, main_program=None, filename=None): Examples: .. code-block:: python + import paddle import paddle.fluid as fluid + + paddle.enable_static() params_path = "./my_paddle_model" image = fluid.data(name='img', shape=[None, 28, 28], dtype='float32') label = fluid.data(name='label', shape=[None, 1], dtype='int64') @@ -465,7 +484,10 @@ def _save_distributed_persistables(executor, dirname, main_program): Examples: .. code-block:: python + import paddle import paddle.fluid as fluid + + paddle.enable_static() exe = fluid.Executor(fluid.CPUPlace()) param_path = "./my_paddle_model" t = distribute_transpiler.DistributeTranspiler() @@ -634,8 +656,10 @@ def save_persistables(executor, dirname, main_program=None, filename=None): Examples: .. code-block:: python + import paddle import paddle.fluid as fluid + paddle.enable_static() dir_path = "./my_paddle_model" file_name = "persistables" image = fluid.data(name='img', shape=[None, 28, 28], dtype='float32') @@ -711,8 +735,10 @@ def load_vars(executor, Examples: .. code-block:: python + import paddle import paddle.fluid as fluid + paddle.enable_static() main_prog = fluid.Program() startup_prog = fluid.Program() with fluid.program_guard(main_prog, startup_prog): @@ -946,8 +972,10 @@ def load_params(executor, dirname, main_program=None, filename=None): Examples: .. code-block:: python + import paddle import paddle.fluid as fluid + paddle.enable_static() exe = fluid.Executor(fluid.CPUPlace()) param_path = "./my_paddle_model" prog = fluid.default_main_program() @@ -995,8 +1023,10 @@ def load_persistables(executor, dirname, main_program=None, filename=None): Examples: .. code-block:: python + import paddle import paddle.fluid as fluid + paddle.enable_static() exe = fluid.Executor(fluid.CPUPlace()) param_path = "./my_paddle_model" prog = fluid.default_main_program() @@ -1034,7 +1064,10 @@ def _load_distributed_persistables(executor, dirname, main_program=None): Examples: .. code-block:: python + import paddle import paddle.fluid as fluid + + paddle.enable_static() exe = fluid.Executor(fluid.CPUPlace()) param_path = "./my_paddle_model" t = distribute_transpiler.DistributeTranspiler() @@ -1160,7 +1193,8 @@ def append_fetch_ops(inference_program, attrs={'col': i}) -@dygraph_not_support +@static_only +@deprecated(since="2.0.0", update_to="paddle.static.save_inference_model") def save_inference_model(dirname, feeded_var_names, target_vars, @@ -1226,8 +1260,10 @@ def save_inference_model(dirname, Examples: .. code-block:: python + import paddle import paddle.fluid as fluid + paddle.enable_static() path = "./infer_model" # User defined network, here a softmax regession example @@ -1370,7 +1406,8 @@ def save_inference_model(dirname, return target_var_name_list -@dygraph_not_support +@static_only +@deprecated(since="2.0.0", update_to="paddle.static.load_inference_model") def load_inference_model(dirname, executor, model_filename=None, @@ -1422,9 +1459,11 @@ def load_inference_model(dirname, Examples: .. code-block:: python + import paddle import paddle.fluid as fluid import numpy as np + paddle.enable_static() # Build the model main_prog = fluid.Program() startup_prog = fluid.Program() @@ -1540,7 +1579,10 @@ def get_parameter_value(para, executor): Examples: .. code-block:: python + import paddle import paddle.fluid as fluid + + paddle.enable_static() exe = fluid.Executor(fluid.CPUPlace()) param = fluid.default_main_program().global_block().var('fc.w') p = fluid.io.get_parameter_value(param, exe) @@ -1578,7 +1620,10 @@ def get_parameter_value_by_name(name, executor, program=None): Examples: .. code-block:: python + import paddle import paddle.fluid as fluid + + paddle.enable_static() exe = fluid.Executor(fluid.CPUPlace()) p = fluid.io.get_parameter_value('fc.w', exe) """ @@ -1686,8 +1731,10 @@ def save(program, model_path): Examples: .. code-block:: python + import paddle import paddle.fluid as fluid + paddle.enable_static() prog = fluid.default_main_program() fluid.save( prog, "./temp") @@ -1753,8 +1800,10 @@ def load(program, model_path, executor=None, var_list=None): Examples: .. code-block:: python + import paddle import paddle.fluid as fluid + paddle.enable_static() prog = fluid.default_main_program() fluid.save( prog, "./temp") @@ -1914,7 +1963,10 @@ def load_program_state(model_path, var_list=None): Examples: .. code-block:: python + import paddle import paddle.fluid as fluid + + paddle.enable_static() x = fluid.data( name="x", shape=[10, 10], dtype='float32') y = fluid.layers.fc( x, 10) z = fluid.layers.fc( y, 10) @@ -2047,7 +2099,10 @@ def set_program_state(program, state_dict): Examples: .. code-block:: python + import paddle import paddle.fluid as fluid + + paddle.enable_static() x = fluid.data( name="x", shape=[10, 10], dtype='float32') y = fluid.layers.fc( x, 10) z = fluid.layers.fc( y, 10) diff --git a/python/paddle/fluid/tests/unittests/rnn/test_rnn_nets.py b/python/paddle/fluid/tests/unittests/rnn/test_rnn_nets.py index 87bdee8a91..639605a64e 100644 --- a/python/paddle/fluid/tests/unittests/rnn/test_rnn_nets.py +++ b/python/paddle/fluid/tests/unittests/rnn/test_rnn_nets.py @@ -323,10 +323,7 @@ def predict_test_util(place, mode): exe = paddle.static.Executor(place) [inference_program, feed_target_names, fetch_targets] = paddle.static.load_inference_model( - dirname="./inference", - executor=exe, - model_filename="%s_infer.pdmodel" % mode, - params_filename="%s_infer.pdiparams" % mode) + "./inference/%s_infer" % mode, exe) results = exe.run(inference_program, feed={feed_target_names[0]: x.numpy()}, fetch_list=fetch_targets) @@ -345,3 +342,6 @@ def load_tests(loader, tests, pattern): for test_class in [TestSimpleRNN, TestLSTM, TestGRU]: suite.addTest(test_class(time_major, direction, device)) return suite + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_inference_model_io.py b/python/paddle/fluid/tests/unittests/test_inference_model_io.py index aa408aedf6..a82bc3f0f6 100644 --- a/python/paddle/fluid/tests/unittests/test_inference_model_io.py +++ b/python/paddle/fluid/tests/unittests/test_inference_model_io.py @@ -23,6 +23,7 @@ import paddle.fluid.core as core import paddle.fluid as fluid import warnings +import paddle import paddle.fluid.executor as executor import paddle.fluid.layers as layers import paddle.fluid.optimizer as optimizer @@ -30,15 +31,17 @@ from paddle.fluid.compiler import CompiledProgram from paddle.fluid.framework import Program, program_guard from paddle.fluid.io import save_inference_model, load_inference_model, save_persistables from paddle.fluid.transpiler import memory_optimize +paddle.enable_static() -class TestBook(unittest.TestCase): - class InferModel(object): - def __init__(self, list): - self.program = list[0] - self.feed_var_names = list[1] - self.fetch_vars = list[2] +class InferModel(object): + def __init__(self, list): + self.program = list[0] + self.feed_var_names = list[1] + self.fetch_vars = list[2] + +class TestBook(unittest.TestCase): def test_fit_line_inference_model(self): MODEL_DIR = "./tmp/inference_model" UNI_MODEL_DIR = "./tmp/inference_model1" @@ -88,10 +91,10 @@ class TestBook(unittest.TestCase): six.moves.reload_module(executor) # reload to build a new scope - model_0 = self.InferModel(load_inference_model(MODEL_DIR, exe)) + model_0 = InferModel(load_inference_model(MODEL_DIR, exe)) with open(os.path.join(UNI_MODEL_DIR, 'model'), "rb") as f: model_str = f.read() - model_1 = self.InferModel( + model_1 = InferModel( load_inference_model(None, exe, model_str, params_str)) for model in [model_0, model_1]: @@ -192,6 +195,103 @@ class TestInstance(unittest.TestCase): [MODEL_DIR, ["x", "y"], [avg_cost], [], cp_prog]) +class TestSaveInferenceModelNew(unittest.TestCase): + def test_save_and_load_inference_model(self): + MODEL_DIR = "./tmp/inference_model5" + init_program = fluid.default_startup_program() + program = fluid.default_main_program() + + # fake program without feed/fetch + with program_guard(program, init_program): + x = layers.data(name='x', shape=[2], dtype='float32') + y = layers.data(name='y', shape=[1], dtype='float32') + + y_predict = layers.fc(input=x, size=1, act=None) + + cost = layers.square_error_cost(input=y_predict, label=y) + avg_cost = layers.mean(cost) + + sgd_optimizer = optimizer.SGDOptimizer(learning_rate=0.001) + sgd_optimizer.minimize(avg_cost, init_program) + + place = core.CPUPlace() + exe = executor.Executor(place) + exe.run(init_program, feed={}, fetch_list=[]) + + tensor_x = np.array([[1, 1], [1, 2], [5, 2]]).astype("float32") + tensor_y = np.array([[-2], [-3], [-7]]).astype("float32") + for i in six.moves.xrange(3): + exe.run(program, + feed={'x': tensor_x, + 'y': tensor_y}, + fetch_list=[avg_cost]) + + self.assertRaises(ValueError, paddle.static.save_inference_model, + None, ['x', 'y'], [avg_cost], exe) + self.assertRaises(ValueError, paddle.static.save_inference_model, + MODEL_DIR + "/", [x, y], [avg_cost], exe) + self.assertRaises(ValueError, paddle.static.save_inference_model, + MODEL_DIR, ['x', 'y'], [avg_cost], exe) + self.assertRaises(ValueError, paddle.static.save_inference_model, + MODEL_DIR, 'x', [avg_cost], exe) + self.assertRaises(ValueError, paddle.static.save_inference_model, + MODEL_DIR, [x, y], ['avg_cost'], exe) + self.assertRaises(ValueError, paddle.static.save_inference_model, + MODEL_DIR, [x, y], 'avg_cost', exe) + + model_path = MODEL_DIR + "_isdir.pdmodel" + os.makedirs(model_path) + self.assertRaises(ValueError, paddle.static.save_inference_model, + MODEL_DIR + "_isdir", [x, y], [avg_cost], exe) + os.rmdir(model_path) + + params_path = MODEL_DIR + "_isdir.pdmodel" + os.makedirs(params_path) + self.assertRaises(ValueError, paddle.static.save_inference_model, + MODEL_DIR + "_isdir", [x, y], [avg_cost], exe) + os.rmdir(params_path) + + paddle.static.io.save_inference_model(MODEL_DIR, [x, y], [avg_cost], exe) + + self.assertTrue(os.path.exists(MODEL_DIR + ".pdmodel")) + self.assertTrue(os.path.exists(MODEL_DIR + ".pdiparams")) + + expected = exe.run(program, + feed={'x': tensor_x, + 'y': tensor_y}, + fetch_list=[avg_cost])[0] + + six.moves.reload_module(executor) # reload to build a new scope + + self.assertRaises(ValueError, paddle.static.load_inference_model, + None, exe) + self.assertRaises(ValueError, paddle.static.load_inference_model, + MODEL_DIR + "/", exe) + self.assertRaises(ValueError, paddle.static.load_inference_model, + [MODEL_DIR], exe) + self.assertRaises(ValueError, paddle.static.load_inference_model, + MODEL_DIR, exe, pserver_endpoints=None) + self.assertRaises(ValueError, paddle.static.load_inference_model, + MODEL_DIR, exe, unsupported_param=None) + self.assertRaises((TypeError, ValueError), paddle.static.load_inference_model, + None, exe, model_filename="illegal", params_filename="illegal") + + model = InferModel(paddle.static.io.load_inference_model(MODEL_DIR, exe)) + + outs = exe.run(model.program, + feed={ + model.feed_var_names[0]: tensor_x, + model.feed_var_names[1]: tensor_y + }, + fetch_list=model.fetch_vars) + actual = outs[0] + + self.assertEqual(model.feed_var_names, ["x", "y"]) + self.assertEqual(len(model.fetch_vars), 1) + self.assertEqual(expected, actual) + + + class TestLoadInferenceModelError(unittest.TestCase): def test_load_model_not_exist(self): place = core.CPUPlace() diff --git a/python/paddle/static/__init__.py b/python/paddle/static/__init__.py index a6ce437982..bca045852f 100644 --- a/python/paddle/static/__init__.py +++ b/python/paddle/static/__init__.py @@ -23,6 +23,8 @@ __all__ = [ ] from . import nn +from .io import save_inference_model +from .io import load_inference_model from ..fluid import Scope #DEFINE_ALIAS from .input import data #DEFINE_ALIAS from .input import InputSpec #DEFINE_ALIAS @@ -48,8 +50,6 @@ from ..fluid.parallel_executor import ParallelExecutor #DEFINE_ALIAS from ..fluid.param_attr import WeightNormParamAttr #DEFINE_ALIAS from ..fluid.io import save #DEFINE_ALIAS from ..fluid.io import load #DEFINE_ALIAS -from ..fluid.io import save_inference_model #DEFINE_ALIAS -from ..fluid.io import load_inference_model #DEFINE_ALIAS from ..fluid.io import load_program_state #DEFINE_ALIAS from ..fluid.io import set_program_state #DEFINE_ALIAS from ..fluid.layers import create_parameter #DEFINE_ALIAS diff --git a/python/paddle/static/io.py b/python/paddle/static/io.py new file mode 100644 index 0000000000..b30dfa8429 --- /dev/null +++ b/python/paddle/static/io.py @@ -0,0 +1,335 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + + +import errno +import inspect +import logging +import os +import six + +import paddle +from paddle.fluid import core, Variable, CompiledProgram, program_guard, default_main_program, Program +from paddle.fluid.framework import static_only +from paddle.fluid import layers + +from paddle.fluid.io import _get_valid_program, save_vars, _save_distributed_persistables +from paddle.fluid.io import prepend_feed_ops, append_fetch_ops, save_persistables +from paddle.fluid.io import load_persistables, _endpoints_replacement +from paddle.fluid.log_helper import get_logger + + +__all__ = [ + 'save_inference_model', + 'load_inference_model', +] + +_logger = get_logger( + __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s') + + +def _check_args(caller, args, supported_args=[], deprecated_args=[]): + for arg in args: + if arg in deprecated_args: + raise ValueError("argument '{}' in function '{}' is deprecated, only {} are supported.".format(arg, caller, supported_args)) + elif arg not in supported_args: + raise ValueError( + "function '{}' doesn't support argument '{}',\n only {} are supported.".format(caller, arg, supported_args)) + + +@static_only +def save_inference_model(path_prefix, feed_vars, fetch_vars, executor): + """ + :api_attr: Static Graph + + Save current model and its parameters to given path. i.e. + Given path_prefix = "/path/to/modelname", after invoking + save_inference_model(path_prefix, feed_vars, fetch_vars, executor), + you will find two files named modelname.pdmodel and modelname.pdiparams + under "/path/to", which represent your model and parameters respectively. + + Args: + path_prefix(str): Directory path to save model + model name without suffix. + feed_vars(Variable | list[Variable]): Variables needed by inference. + fetch_vars(Variable | list[Variable]): Variables returned by inference. + executor(Executor): The executor that saves the inference model. You can refer + to :ref:`api_guide_executor_en` for more details. + Returns: + None + + Raises: + ValueError: If `feed_vars` is not a Variable or a list of Variable, an exception is thrown. + ValueError: If `fetch_vars` is not a Variable or a list of Variable, an exception is thrown. + + Examples: + .. code-block:: python + + import paddle + import paddle.fluid as fluid + + paddle.enable_static() + + path_prefix = "./infer_model" + + # User defined network, here a softmax regession example + image = fluid.data(name='img', shape=[None, 28, 28], dtype='float32') + label = fluid.data(name='label', shape=[None, 1], dtype='int64') + feeder = fluid.DataFeeder(feed_list=[image, label], place=fluid.CPUPlace()) + predict = fluid.layers.fc(input=image, size=10, act='softmax') + + loss = fluid.layers.cross_entropy(input=predict, label=label) + avg_loss = fluid.layers.mean(loss) + + exe = fluid.Executor(fluid.CPUPlace()) + exe.run(fluid.default_startup_program()) + + # Feed data and train process + + # Save inference model. Note we don't save label and loss in this example + paddle.static.io.save_inference_model(path_prefix, [image], [predict], exe) + + # In this example, the save_inference_mode inference will prune the default + # main program according to the network's input node (img) and output node(predict). + # The pruned inference program is going to be saved in file "./infer_model.pdmodel" + # and parameters are going to be saved in file "./infer_model.pdiparams". + + """ + # check path_prefix, set model_path and params_path + if not isinstance(path_prefix, six.string_types): + raise ValueError("'path_prefix' should be a string.") + if path_prefix.endswith("/"): + raise ValueError("'path_prefix' should not be a directory") + path_prefix = os.path.normpath(path_prefix) + path_prefix = os.path.abspath(path_prefix) + try: + # mkdir may conflict if pserver and trainer are running on the same machine + dirname = os.path.dirname(path_prefix) + os.makedirs(dirname) + except OSError as e: + if e.errno != errno.EEXIST: + raise + model_path = path_prefix + ".pdmodel" + params_path = path_prefix + ".pdiparams" + if os.path.isdir(model_path): + raise ValueError("'{}' is an existing directory.".format(model_path)) + if os.path.isdir(params_path): + raise ValueError("'{}' is an existing directory.".format(params_path)) + + # verify feed_vars + if not isinstance(feed_vars, list): + feed_vars = [feed_vars] + if not feed_vars or not all([isinstance(var, Variable) for var in feed_vars]): + raise ValueError("'feed_vars' should be a Variable or a list of Variable.") + + # verify fetch_vars + if not isinstance(fetch_vars, list): + fetch_vars = [fetch_vars] + if not fetch_vars or not all([isinstance(var, Variable) for var in fetch_vars]): + raise ValueError("'fetch_vars' should be a Variable or a list of Variable.") + + main_program = _get_valid_program() + # remind users to set auc_states to 0 if auc op were found. + for op in main_program.global_block().ops: + # clear device of Op + device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName() + op._set_attr(device_attr_name, "") + if op.type == 'auc': + warnings.warn("Be sure that you have set auc states to 0 before saving inference model.") + break + + # fix the bug that the activation op's output as target will be pruned. + # will affect the inference performance. + # TODO(Superjomn) add an IR pass to remove 1-scale op. + with program_guard(main_program): + uniq_fetch_vars = [] + for i, var in enumerate(fetch_vars): + var = layers.scale(var, 1., name="save_infer_model/scale_{}".format(i)) + uniq_fetch_vars.append(var) + fetch_vars = uniq_fetch_vars + + # save model + origin_program = main_program.clone() + main_program = main_program.clone() + global_block = main_program.global_block() + remove_op_idx = [] + for i, op in enumerate(global_block.ops): + op.desc.set_is_target(False) + if op.type == "feed" or op.type == "fetch": + remove_op_idx.append(i) + for idx in remove_op_idx[::-1]: + global_block._remove_op(idx) + main_program.desc.flush() + + feed_var_names = [var.name for var in feed_vars] + main_program = main_program._prune_with_input( + feeded_var_names=feed_var_names, targets=fetch_vars) + main_program = main_program._inference_optimize(prune_read_op=True) + fetch_var_names = [var.name for var in fetch_vars] + prepend_feed_ops(main_program, feed_var_names) + append_fetch_ops(main_program, fetch_var_names) + main_program.desc._set_version() + paddle.fluid.core.save_op_version_info(main_program.desc) + with open(model_path, "wb") as f: + f.write(main_program.desc.serialize_to_string()) + main_program._copy_dist_param_info_from(origin_program) + + # save params + dirname = os.path.dirname(params_path) + basename = os.path.basename(params_path) + save_persistables(executor, dirname, main_program, basename) + + +@static_only +def load_inference_model(path_prefix, executor, **configs): + """ + :api_attr: Static Graph + + Load inference model from a given path. By this API, you can get the model + structure(Inference Program) and model parameters. + + Args: + path_prefix(str | None): One of the following: + - Directory path to save model + model name without suffix. + - Set to None when reading the model from memory. + executor(Executor): The executor to run for loading inference model. + See :ref:`api_guide_executor_en` for more details about it. + + Returns: + list: The return of this API is a list with three elements: + (program, feed_target_names, fetch_targets). The `program` is a + ``Program`` (refer to :ref:`api_guide_Program_en`), which is used for inference. + The `feed_target_names` is a list of ``str``, which contains names of variables + that need to feed data in the inference program. The `fetch_targets` is a list of + ``Variable`` (refer to :ref:`api_guide_Program_en`). It contains variables from which + we can get inference results. + + Raises: + ValueError: If `path_prefix.pdmodel` or `path_prefix.pdiparams` doesn't exist. + + Examples: + .. code-block:: python + + import paddle + import paddle.fluid as fluid + import numpy as np + + paddle.enable_static() + + # Build the model + startup_prog = fluid.default_startup_program() + main_prog = fluid.default_main_program() + with fluid.program_guard(main_prog, startup_prog): + image = fluid.layers.data(name="img", shape=[64, 784], append_batch_size=False) + w = fluid.layers.create_parameter(shape=[784, 200], dtype='float32') + b = fluid.layers.create_parameter(shape=[200], dtype='float32') + hidden_w = fluid.layers.matmul(x=image, y=w) + hidden_b = fluid.layers.elementwise_add(hidden_w, b) + place = fluid.CPUPlace() + exe = fluid.Executor(place) + exe.run(startup_prog) + + # Save the inference model + path_prefix = "./infer_model" + paddle.static.io.save_inference_model(path_prefix, [image], [hidden_b], exe) + + [inference_program, feed_target_names, fetch_targets] = ( + paddle.static.io.load_inference_model(path_prefix, exe)) + tensor_img = np.array(np.random.random((1, 64, 784)), dtype=np.float32) + results = exe.run(inference_program, + feed={feed_target_names[0]: tensor_img}, + fetch_list=fetch_targets) + + # In this example, the inference program was saved in file + # "./infer_model.pdmodel" and parameters were saved in file + # " ./infer_model.pdiparams". + # By the inference program, feed_target_names and + # fetch_targets, we can use an executor to run the inference + # program to get the inference result. + """ + # check configs + supported_args = ('model_filename', 'params_filename') + deprecated_args = ('pserver_endpoints',) + caller = inspect.currentframe().f_code.co_name + _check_args(caller, configs, supported_args, deprecated_args) + + # load from memory + if path_prefix is None: + _logger.warning("Load inference model from memory is deprecated.") + model_filename = configs.get('model_filename', None) + params_filename = configs.get('params_filename', None) + if params_filename is None: + raise ValueError( + "params_filename cannot be None when path_prefix is None." + ) + load_dirname = path_prefix + program_desc_str = model_filename + params_filename = params_filename + # load from file + else: + # check and norm path_prefix + if not isinstance(path_prefix, six.string_types): + raise ValueError("'path_prefix' should be a string.") + if path_prefix.endswith("/"): + raise ValueError("'path_prefix' should not be a directory") + path_prefix = os.path.normpath(path_prefix) + path_prefix = os.path.abspath(path_prefix) + + # set model_path and params_path in new way, + # path_prefix represents a file path without suffix in this case. + if not configs: + model_path = path_prefix + ".pdmodel" + params_path = path_prefix + ".pdiparams" + # set model_path and params_path in old way for compatible, + # path_prefix represents a directory path. + else: + model_filename = configs.get('model_filename', None) + params_filename = configs.get('params_filename', None) + # set model_path + if model_filename is None: + model_path = os.path.join(path_prefix, "__model__") + else: + model_path = os.path.join(path_prefix, model_filename + ".pdmodel") + if not os.path.exists(model_path): + model_path = os.path.join(path_prefix, model_filename) + # set params_path + if params_filename is None: + params_path = os.path.join(path_prefix, "") + else: + params_path = os.path.join(path_prefix, params_filename + ".pdiparams") + if not os.path.exists(params_path): + params_path = os.path.join(path_prefix, params_filename) + _logger.warning("The old way to load inference model is deprecated." + " model path: {}, params path: {}".format(model_path, params_path)) + with open(model_path, "rb") as f: + program_desc_str = f.read() + load_dirname = os.path.dirname(params_path) + params_filename = os.path.basename(params_path) + + program = Program.parse_from_string(program_desc_str) + if not core._is_program_version_supported(program._version()): + raise ValueError("Unsupported program version: %d\n" % + program._version()) + # Binary data also need versioning. + load_persistables(executor, load_dirname, program, params_filename) + + feed_target_names = program.desc.get_feed_target_names() + fetch_target_names = program.desc.get_fetch_target_names() + fetch_targets = [ + program.global_block().var(name) for name in fetch_target_names + ] + + return [program, feed_target_names, fetch_targets] + -- GitLab