未验证 提交 e0a52fd7 编写于 作者: W WeiXin 提交者: GitHub

save/load program (#32336)

上级 f6f59e50
...@@ -447,5 +447,29 @@ class TestSaveLoad(unittest.TestCase): ...@@ -447,5 +447,29 @@ class TestSaveLoad(unittest.TestCase):
paddle.load("test_paddle_save_load.linear") paddle.load("test_paddle_save_load.linear")
class TestSaveLoadProgram(unittest.TestCase):
def test_save_load_program(self):
paddle.enable_static()
with new_program_scope():
layer = LinearNet()
data = paddle.static.data(
name='x_static_save', shape=(None, IMAGE_SIZE), dtype='float32')
y_static = layer(data)
main_program = paddle.static.default_main_program()
startup_program = paddle.static.default_startup_program()
origin_main = main_program.desc.serialize_to_string()
origin_startup = startup_program.desc.serialize_to_string()
path1 = "test_paddle_save_load_program/main_program.pdmodel"
path2 = "test_paddle_save_load_program/startup_program.pdmodel"
paddle.save(main_program, path1)
paddle.save(startup_program, path2)
with new_program_scope():
load_main = paddle.load(path1).desc.serialize_to_string()
load_startup = paddle.load(path2).desc.serialize_to_string()
self.assertTrue(origin_main == load_main)
self.assertTrue(origin_startup == load_startup)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -33,7 +33,7 @@ from paddle.fluid import core ...@@ -33,7 +33,7 @@ from paddle.fluid import core
from paddle.fluid.io import _unpack_saved_dict, _pack_loaded_dict, _pickle_loads_mac from paddle.fluid.io import _unpack_saved_dict, _pack_loaded_dict, _pickle_loads_mac
from paddle.fluid.io import _legacy_save as _legacy_static_save from paddle.fluid.io import _legacy_save as _legacy_static_save
from paddle.fluid.framework import Variable, _varbase_creator, _dygraph_tracer, in_dygraph_mode, ParamBase, _current_expected_place from paddle.fluid.framework import Variable, _varbase_creator, _dygraph_tracer, in_dygraph_mode, ParamBase, _current_expected_place, Program
from paddle.fluid.dygraph.jit import _SaveLoadConfig from paddle.fluid.dygraph.jit import _SaveLoadConfig
from paddle.fluid.dygraph.io import _construct_program_holders, _construct_params_and_buffers from paddle.fluid.dygraph.io import _construct_program_holders, _construct_params_and_buffers
from paddle.fluid.dygraph.io import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX, INFER_PARAMS_INFO_SUFFIX from paddle.fluid.dygraph.io import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX, INFER_PARAMS_INFO_SUFFIX
...@@ -453,8 +453,11 @@ def save(obj, path, protocol=2, **configs): ...@@ -453,8 +453,11 @@ def save(obj, path, protocol=2, **configs):
warnings.warn( warnings.warn(
"'pickle_protocol' is a deprecated argument. Please use 'protocol' instead." "'pickle_protocol' is a deprecated argument. Please use 'protocol' instead."
) )
if isinstance(obj, Program):
if _use_legacy(obj): obj.desc.flush()
with open(path, "wb") as f:
f.write(obj.desc.serialize_to_string())
elif _use_legacy(obj):
if in_dygraph_mode(): if in_dygraph_mode():
_legacy_save(obj, path, protocol) _legacy_save(obj, path, protocol)
else: else:
...@@ -627,6 +630,11 @@ def load(path, **configs): ...@@ -627,6 +630,11 @@ def load(path, **configs):
if os.path.isfile(path): if os.path.isfile(path):
config = _parse_load_config(configs) config = _parse_load_config(configs)
if six.PY2:
exception_type = KeyError
else:
exception_type = pickle.UnpicklingError
try:
with open(path, 'rb') as f: with open(path, 'rb') as f:
# When value of dict is lager than 4GB ,there is a Bug on 'MAC python3' # When value of dict is lager than 4GB ,there is a Bug on 'MAC python3'
if sys.platform == 'darwin' and sys.version_info.major == 3: if sys.platform == 'darwin' and sys.version_info.major == 3:
...@@ -667,6 +675,11 @@ def load(path, **configs): ...@@ -667,6 +675,11 @@ def load(path, **configs):
raise NotImplementedError( raise NotImplementedError(
'Only support tensor and state_dict, but received {}.'. 'Only support tensor and state_dict, but received {}.'.
format(type(load_result))) format(type(load_result)))
except exception_type:
with open(path, "rb") as f:
program_desc_str = f.read()
program = Program.parse_from_string(program_desc_str)
return program
else: else:
load_result = _legacy_load(path, **configs) load_result = _legacy_load(path, **configs)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册