From aa42bc25d87da8b93df99941324afc9620e2f9e1 Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Wed, 10 Aug 2022 16:45:54 +0800 Subject: [PATCH] [BugFix]Fix save/load_inference_model API BUG while program contains no param (#45038) --- .../tests/unittests/test_static_save_load.py | 29 +++++++++++++++++++ python/paddle/static/io.py | 14 +++++++-- 2 files changed, 41 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_static_save_load.py b/python/paddle/fluid/tests/unittests/test_static_save_load.py index c47daba9684..2a30088a001 100644 --- a/python/paddle/fluid/tests/unittests/test_static_save_load.py +++ b/python/paddle/fluid/tests/unittests/test_static_save_load.py @@ -1626,6 +1626,35 @@ class TestStaticSaveLoadPickle(unittest.TestCase): np.testing.assert_array_equal(new_t, base_t) +class TestSaveLoadInferenceModel(unittest.TestCase): + + def setUp(self): + self.temp_dir = tempfile.TemporaryDirectory() + self.model_path = os.path.join(self.temp_dir.name, 'no_params') + + def tearDown(self): + self.temp_dir.cleanup() + + def test_no_params(self): + main_program = framework.Program() + with framework.program_guard(main_program): + x = paddle.static.data(name="x", shape=[10, 10], dtype='float32') + y = x + x + + place = paddle.CPUPlace() + exe = paddle.static.Executor(place) + + paddle.static.save_inference_model(self.model_path, [x], [y], exe) + + [inference_program, feed_target_names, fetch_targets + ] = (paddle.static.load_inference_model(self.model_path, exe)) + + self.assertEqual(feed_target_names, ['x']) + self.assertEqual(fetch_targets[0].shape, (10, 10)) + ops = [op.type for op in inference_program.block(0).ops] + self.assertEqual(ops, ['feed', 'elementwise_add', 'scale', 'fetch']) + + if __name__ == '__main__': paddle.enable_static() unittest.main() diff --git a/python/paddle/static/io.py b/python/paddle/static/io.py index f0019df696b..de9e48b3367 100644 --- a/python/paddle/static/io.py +++ b/python/paddle/static/io.py @@ -542,7 
+542,9 @@ def save_inference_model(path_prefix, feed_vars, fetch_vars, executor, save_to_file(model_path, program_bytes) # serialize and save params params_bytes = _serialize_persistables(program, executor) - save_to_file(params_path, params_bytes) + # the program may contain no parameters and consist only of compute operations + if params_bytes is not None: + save_to_file(params_path, params_bytes) @static_only @@ -660,6 +662,12 @@ def deserialize_persistables(program, data, executor): check_vars.append(var) load_var_map[var_copy.name] = var_copy + if data is None: + assert len( + origin_shape_map ) == 0, "Required 'data' shall be not None if program contains parameter, but received 'data' is None." + return + # append load_combine op to load parameters, load_var_list = [] for name in sorted(load_var_map.keys()): @@ -849,7 +857,9 @@ def load_inference_model(path_prefix, executor, **kwargs): params_filename = os.path.basename(params_path) # load params data params_path = os.path.join(load_dirname, params_filename) - params_bytes = load_from_file(params_path) + params_bytes = None + if os.path.exists(params_path): + params_bytes = load_from_file(params_path) # deserialize bytes to program program = deserialize_program(program_bytes) -- GitLab