diff --git a/python/paddle/fluid/tests/unittests/test_static_save_load.py b/python/paddle/fluid/tests/unittests/test_static_save_load.py index c47daba96847ab681b37133d51826dec781b2ce3..2a30088a001ffdd23443dd2c8b1cd2d4130da33f 100644 --- a/python/paddle/fluid/tests/unittests/test_static_save_load.py +++ b/python/paddle/fluid/tests/unittests/test_static_save_load.py @@ -1626,6 +1626,35 @@ class TestStaticSaveLoadPickle(unittest.TestCase): np.testing.assert_array_equal(new_t, base_t) +class TestSaveLoadInferenceModel(unittest.TestCase): + + def setUp(self): + self.temp_dir = tempfile.TemporaryDirectory() + self.model_path = os.path.join(self.temp_dir.name, 'no_params') + + def tearDown(self): + self.temp_dir.cleanup() + + def test_no_params(self): + main_program = framework.Program() + with framework.program_guard(main_program): + x = paddle.static.data(name="x", shape=[10, 10], dtype='float32') + y = x + x + + place = paddle.CPUPlace() + exe = paddle.static.Executor(place) + + paddle.static.save_inference_model(self.model_path, [x], [y], exe) + + [inference_program, feed_target_names, fetch_targets + ] = (paddle.static.load_inference_model(self.model_path, exe)) + + self.assertEqual(feed_target_names, ['x']) + self.assertEqual(fetch_targets[0].shape, (10, 10)) + ops = [op.type for op in inference_program.block(0).ops] + self.assertEqual(ops, ['feed', 'elementwise_add', 'scale', 'fetch']) + + if __name__ == '__main__': paddle.enable_static() unittest.main() diff --git a/python/paddle/static/io.py b/python/paddle/static/io.py index f0019df696b9713908a1fde1156db88b0f24ad7e..de9e48b3367cc70ec6139e9dcdf6859d926db9e9 100644 --- a/python/paddle/static/io.py +++ b/python/paddle/static/io.py @@ -542,7 +542,9 @@ def save_inference_model(path_prefix, feed_vars, fetch_vars, executor, save_to_file(model_path, program_bytes) # serialize and save params params_bytes = _serialize_persistables(program, executor) - save_to_file(params_path, params_bytes) + # program may not contain any parameter and just compute operation + if params_bytes is not None: + save_to_file(params_path, params_bytes) @static_only @@ -660,6 +662,12 @@ def deserialize_persistables(program, data, executor): check_vars.append(var) load_var_map[var_copy.name] = var_copy + if data is None: + assert len( + origin_shape_map + ) == 0, "Required 'data' shall be not None if program contains parameter, but received 'data' is None." + return + # append load_combine op to load parameters, load_var_list = [] for name in sorted(load_var_map.keys()): @@ -849,7 +857,9 @@ def load_inference_model(path_prefix, executor, **kwargs): params_filename = os.path.basename(params_path) # load params data params_path = os.path.join(load_dirname, params_filename) - params_bytes = load_from_file(params_path) + params_bytes = None + if os.path.exists(params_path): + params_bytes = load_from_file(params_path) # deserialize bytes to program program = deserialize_program(program_bytes)