未验证 提交 aa42bc25 编写于 作者: A Aurelius84 提交者: GitHub

[BugFix]Fix save/load_inference_model API BUG while program contains no param (#45038)

上级 b1e33bea
...@@ -1626,6 +1626,35 @@ class TestStaticSaveLoadPickle(unittest.TestCase): ...@@ -1626,6 +1626,35 @@ class TestStaticSaveLoadPickle(unittest.TestCase):
np.testing.assert_array_equal(new_t, base_t) np.testing.assert_array_equal(new_t, base_t)
class TestSaveLoadInferenceModel(unittest.TestCase):
def setUp(self):
self.temp_dir = tempfile.TemporaryDirectory()
self.model_path = os.path.join(self.temp_dir.name, 'no_params')
def tearDown(self):
self.temp_dir.cleanup()
def test_no_params(self):
main_program = framework.Program()
with framework.program_guard(main_program):
x = paddle.static.data(name="x", shape=[10, 10], dtype='float32')
y = x + x
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
paddle.static.save_inference_model(self.model_path, [x], [y], exe)
[inference_program, feed_target_names, fetch_targets
] = (paddle.static.load_inference_model(self.model_path, exe))
self.assertEqual(feed_target_names, ['x'])
self.assertEqual(fetch_targets[0].shape, (10, 10))
ops = [op.type for op in inference_program.block(0).ops]
self.assertEqual(ops, ['feed', 'elementwise_add', 'scale', 'fetch'])
if __name__ == '__main__': if __name__ == '__main__':
paddle.enable_static() paddle.enable_static()
unittest.main() unittest.main()
...@@ -542,7 +542,9 @@ def save_inference_model(path_prefix, feed_vars, fetch_vars, executor, ...@@ -542,7 +542,9 @@ def save_inference_model(path_prefix, feed_vars, fetch_vars, executor,
save_to_file(model_path, program_bytes) save_to_file(model_path, program_bytes)
# serialize and save params # serialize and save params
params_bytes = _serialize_persistables(program, executor) params_bytes = _serialize_persistables(program, executor)
save_to_file(params_path, params_bytes) # program may not contain any parameter and just compute operation
if params_bytes is not None:
save_to_file(params_path, params_bytes)
@static_only @static_only
...@@ -660,6 +662,12 @@ def deserialize_persistables(program, data, executor): ...@@ -660,6 +662,12 @@ def deserialize_persistables(program, data, executor):
check_vars.append(var) check_vars.append(var)
load_var_map[var_copy.name] = var_copy load_var_map[var_copy.name] = var_copy
if data is None:
assert len(
origin_shape_map
) == 0, "Required 'data' shall be not None if program contains parameter, but received 'data' is None."
return
# append load_combine op to load parameters, # append load_combine op to load parameters,
load_var_list = [] load_var_list = []
for name in sorted(load_var_map.keys()): for name in sorted(load_var_map.keys()):
...@@ -849,7 +857,9 @@ def load_inference_model(path_prefix, executor, **kwargs): ...@@ -849,7 +857,9 @@ def load_inference_model(path_prefix, executor, **kwargs):
params_filename = os.path.basename(params_path) params_filename = os.path.basename(params_path)
# load params data # load params data
params_path = os.path.join(load_dirname, params_filename) params_path = os.path.join(load_dirname, params_filename)
params_bytes = load_from_file(params_path) params_bytes = None
if os.path.exists(params_path):
params_bytes = load_from_file(params_path)
# deserialize bytes to program # deserialize bytes to program
program = deserialize_program(program_bytes) program = deserialize_program(program_bytes)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册