From 4424aac608ef32cb7cb3611e6c049b0fa8473288 Mon Sep 17 00:00:00 2001 From: Shibo Tao <62922815+T8T9@users.noreply.github.com> Date: Sat, 20 Feb 2021 15:55:40 +0800 Subject: [PATCH] export paddle.static.normalize_program method. (#31072) * export paddle.static.normalize_program method. test=develop * fix ut coverage.test=develop --- .../unittests/test_inference_model_io.py | 42 ++++++++++++ python/paddle/static/__init__.py | 1 + python/paddle/static/io.py | 65 +++++++++++++++++-- 3 files changed, 103 insertions(+), 5 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_inference_model_io.py b/python/paddle/fluid/tests/unittests/test_inference_model_io.py index 9a5d0b3e9b1..9abcf2a7676 100644 --- a/python/paddle/fluid/tests/unittests/test_inference_model_io.py +++ b/python/paddle/fluid/tests/unittests/test_inference_model_io.py @@ -356,6 +356,48 @@ class TestSaveInferenceModelNew(unittest.TestCase): self.assertRaises(TypeError, paddle.static.io.deserialize_persistables, None, None, None) + def test_normalize_program(self): + init_program = fluid.default_startup_program() + program = fluid.default_main_program() + + # fake program without feed/fetch + with program_guard(program, init_program): + x = layers.data(name='x', shape=[2], dtype='float32') + y = layers.data(name='y', shape=[1], dtype='float32') + + y_predict = layers.fc(input=x, size=1, act=None) + + cost = layers.square_error_cost(input=y_predict, label=y) + avg_cost = layers.mean(cost) + + sgd_optimizer = optimizer.SGDOptimizer(learning_rate=0.001) + sgd_optimizer.minimize(avg_cost, init_program) + + place = core.CPUPlace() + exe = executor.Executor(place) + exe.run(init_program, feed={}, fetch_list=[]) + + tensor_x = np.array([[1, 1], [1, 2], [5, 2]]).astype("float32") + tensor_y = np.array([[-2], [-3], [-7]]).astype("float32") + for i in six.moves.xrange(3): + exe.run(program, + feed={'x': tensor_x, + 'y': tensor_y}, + fetch_list=[avg_cost]) + + # test if return type of serialize_program is bytes + res = paddle.static.normalize_program(program, [x, y], [avg_cost]) + self.assertTrue(isinstance(res, Program)) + # test program type + self.assertRaises(TypeError, paddle.static.normalize_program, None, + [x, y], [avg_cost]) + # test feed_vars type + self.assertRaises(TypeError, paddle.static.normalize_program, program, + 'x', [avg_cost]) + # test fetch_vars type + self.assertRaises(TypeError, paddle.static.normalize_program, program, + [x, y], 'avg_cost') + class TestLoadInferenceModelError(unittest.TestCase): def test_load_model_not_exist(self): diff --git a/python/paddle/static/__init__.py b/python/paddle/static/__init__.py index 0ac5dbee5f8..91b4a29cefc 100644 --- a/python/paddle/static/__init__.py +++ b/python/paddle/static/__init__.py @@ -59,6 +59,7 @@ from .io import deserialize_program #DEFINE_ALIAS from .io import serialize_program #DEFINE_ALIAS from .io import load_from_file #DEFINE_ALIAS from .io import save_to_file #DEFINE_ALIAS +from .io import normalize_program #DEFINE_ALIAS from ..fluid import Scope #DEFINE_ALIAS from .input import data #DEFINE_ALIAS from .input import InputSpec #DEFINE_ALIAS diff --git a/python/paddle/static/io.py b/python/paddle/static/io.py index 88740186178..6bbab6ed672 100644 --- a/python/paddle/static/io.py +++ b/python/paddle/static/io.py @@ -46,6 +46,7 @@ __all__ = [ 'deserialize_program', 'deserialize_persistables', 'load_from_file', + 'normalize_program', ] _logger = get_logger( @@ -127,10 +128,64 @@ def _clone_var_in_block(block, var): persistable=True) -def _normalize_program(program, feed_vars, fetch_vars): +def normalize_program(program, feed_vars, fetch_vars): """ - optimize program according feed_vars and fetch_vars. + :api_attr: Static Graph + + Normalize/Optimize a program according to feed_vars and fetch_vars. + + Args: + program(Program): Specify a program you want to optimize. + feed_vars(Variable | list[Variable]): Variables needed by inference. + fetch_vars(Variable | list[Variable]): Variables returned by inference. + + Returns: + Program: Normalized/Optimized program. + + Raises: + TypeError: If `program` is not a Program, an exception is thrown. + TypeError: If `feed_vars` is not a Variable or a list of Variable, an exception is thrown. + TypeError: If `fetch_vars` is not a Variable or a list of Variable, an exception is thrown. + + Examples: + .. code-block:: python + + import paddle + + paddle.enable_static() + + path_prefix = "./infer_model" + + # User defined network, here a softmax regession example + image = paddle.static.data(name='img', shape=[None, 28, 28], dtype='float32') + label = paddle.static.data(name='label', shape=[None, 1], dtype='int64') + predict = paddle.static.nn.fc(image, 10, activation='softmax') + + loss = paddle.nn.functional.cross_entropy(predict, label) + + exe = paddle.static.Executor(paddle.CPUPlace()) + exe.run(paddle.static.default_startup_program()) + + # normalize main program. + program = default_main_program() + normalized_program = paddle.static.normalize_program(program, [image], [predict]) + """ + if not isinstance(program, Program): + raise TypeError( + "program type must be `fluid.Program`, but received `%s`" % + type(program)) + if not isinstance(feed_vars, list): + feed_vars = [feed_vars] + if not all(isinstance(v, Variable) for v in feed_vars): + raise TypeError( + "feed_vars type must be a Variable or a list of Variable.") + if not isinstance(fetch_vars, list): + fetch_vars = [fetch_vars] + if not all(isinstance(v, Variable) for v in fetch_vars): + raise TypeError( + "fetch_vars type must be a Variable or a list of Variable.") + # remind users to set auc_states to 0 if auc op were found. for op in program.global_block().ops: # clear device of Op @@ -255,7 +310,7 @@ def serialize_program(feed_vars, fetch_vars, **kwargs): _check_vars('fetch_vars', fetch_vars) program = _get_valid_program(kwargs.get('program', None)) - program = _normalize_program(program, feed_vars, fetch_vars) + program = normalize_program(program, feed_vars, fetch_vars) return _serialize_program(program) @@ -319,7 +374,7 @@ def serialize_persistables(feed_vars, fetch_vars, executor, **kwargs): _check_vars('fetch_vars', fetch_vars) program = _get_valid_program(kwargs.get('program', None)) - program = _normalize_program(program, feed_vars, fetch_vars) + program = normalize_program(program, feed_vars, fetch_vars) return _serialize_persistables(program, executor) @@ -463,7 +518,7 @@ def save_inference_model(path_prefix, feed_vars, fetch_vars, executor, _check_vars('fetch_vars', fetch_vars) program = _get_valid_program(kwargs.get('program', None)) - program = _normalize_program(program, feed_vars, fetch_vars) + program = normalize_program(program, feed_vars, fetch_vars) # serialize and save program program_bytes = _serialize_program(program) save_to_file(model_path, program_bytes) -- GitLab