diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py index fdd236a58f0cfe3f947483f1bb3f214d723cbb62..1a7da4add31c4797a2116194e730c9e241125bc3 100644 --- a/python/paddle/fluid/io.py +++ b/python/paddle/fluid/io.py @@ -24,7 +24,7 @@ import contextlib from functools import reduce import numpy as np - +import math import paddle from paddle.fluid import layers from paddle.fluid.executor import Executor, global_scope @@ -1710,6 +1710,52 @@ def _load_persistable_nodes(executor, dirname, graph): load_vars(executor=executor, dirname=dirname, vars=var_list) +def _unpack_saved_dict(saved_obj): + temp_saved_obj = {} + unpack_infor = {} + for key, value in saved_obj.items(): + if isinstance(value, np.ndarray): + MAX_NUMBER_OF_ELEMENT = 2**22 + num_element = np.prod(value.shape) + if num_element > MAX_NUMBER_OF_ELEMENT: + unpack_infor[key] = {} + unpack_infor[key]["OriginShape"] = value.shape + unpack_infor[key]["slices"] = [] + value = value.flatten() + for i in range( + int( + math.ceil(num_element * 1.0 / + MAX_NUMBER_OF_ELEMENT))): + part_name = key + "@@." + str(i) + unpack_infor[key]["slices"].append(part_name) + temp_saved_obj[part_name] = value[ + i * MAX_NUMBER_OF_ELEMENT:MAX_NUMBER_OF_ELEMENT * (i + 1 + )] + + if unpack_infor: + for key, value in unpack_infor.items(): + if key in saved_obj: + saved_obj.pop(key) + for part in value['slices']: + saved_obj[part] = temp_saved_obj[part] + saved_obj['UnpackBigParamInfor@@'] = unpack_infor + return saved_obj + + +def _pack_loaded_dict(load_obj): + unpack_info = 'UnpackBigParamInfor@@' + if unpack_info in load_obj: + removes = [] + for key, value in load_obj[unpack_info].items(): + slices = [load_obj[part] for part in value["slices"]] + load_obj[key] = np.concatenate(slices).reshape(value["OriginShape"]) + removes += value["slices"] + for key in removes: + load_obj.pop(key) + load_obj.pop(unpack_info) + return load_obj + + @static_only def save(program, model_path): """ @@ -1762,6 +1808,7 @@ def save(program, model_path): parameter_list = list(filter(is_parameter, program.list_vars())) param_dict = {p.name: get_tensor(p) for p in parameter_list} + param_dict = _unpack_saved_dict(param_dict) with open(model_path + ".pdparams", 'wb') as f: pickle.dump(param_dict, f, protocol=2) @@ -1935,6 +1982,7 @@ def load(program, model_path, executor=None, var_list=None): with open(parameter_file_name, 'rb') as f: load_dict = pickle.load(f) if six.PY2 else pickle.load( f, encoding='latin1') + load_dict = _pack_loaded_dict(load_dict) for v in parameter_list: assert v.name in load_dict, \ "Can not find [{}] in model file [{}]".format( diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 8803a02f4ccd10dd159f03757798b0ae7c8b518e..0081a53a8268460a9249978d2ab7dfedca0f10f8 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -701,7 +701,13 @@ set_tests_properties(test_nearest_interp_v2_op PROPERTIES TIMEOUT 120) set_tests_properties(test_trilinear_interp_op PROPERTIES TIMEOUT 120) set_tests_properties(test_bicubic_interp_v2_op PROPERTIES TIMEOUT 120) set_tests_properties(test_gather_op PROPERTIES TIMEOUT 120) -set_tests_properties(test_static_save_load PROPERTIES TIMEOUT 120) +if (WIN32) + set_tests_properties(test_static_save_load PROPERTIES TIMEOUT 300) + set_tests_properties(test_paddle_save_load PROPERTIES TIMEOUT 250) +else() + set_tests_properties(test_static_save_load PROPERTIES TIMEOUT 200) + set_tests_properties(test_paddle_save_load PROPERTIES TIMEOUT 150) +endif() set_tests_properties(test_imperative_selected_rows_to_lod_tensor PROPERTIES TIMEOUT 120) set_tests_properties(test_index_select_op PROPERTIES TIMEOUT 120) set_tests_properties(test_parallel_ssa_graph_inference_feed_partial_data PROPERTIES TIMEOUT 120) diff --git a/python/paddle/fluid/tests/unittests/test_paddle_save_load.py b/python/paddle/fluid/tests/unittests/test_paddle_save_load.py index e211a38e7ec4cb0a8ebb79b173e8cb24e7d4ad03..3d5c8dfb480475dea24ebc011ed4940d70a254d2 100644 --- a/python/paddle/fluid/tests/unittests/test_paddle_save_load.py +++ b/python/paddle/fluid/tests/unittests/test_paddle_save_load.py @@ -28,6 +28,8 @@ SEED = 10 IMAGE_SIZE = 784 CLASS_NUM = 10 +LARGE_PARAM = 2**26 + def random_batch_reader(): def _get_random_inputs_and_labels(): @@ -57,6 +59,16 @@ class LinearNet(nn.Layer): return self._linear(x) +class LayerWithLargeParameters(paddle.nn.Layer): + def __init__(self): + super(LayerWithLargeParameters, self).__init__() + self._l = paddle.nn.Linear(10, LARGE_PARAM) + + def forward(self, x): + y = self._l(x) + return y + + def train(layer, loader, loss_fn, opt): for epoch_id in range(EPOCH_NUM): for batch_id, (image, label) in enumerate(loader()): @@ -67,6 +79,26 @@ def train(layer, loader, loss_fn, opt): opt.clear_grad() +class TestSaveLoadLargeParameters(unittest.TestCase): + def setUp(self): + pass + + def test_large_parameters_paddle_save(self): + # enable dygraph mode + paddle.disable_static() + # create network + layer = LayerWithLargeParameters() + save_dict = layer.state_dict() + + path = "test_paddle_save_load_large_param_save/layer" + ".pdparams" + paddle.save(layer.state_dict(), path) + dict_load = paddle.load(path) + # compare results before and after saving + for key, value in save_dict.items(): + self.assertTrue( + np.sum(np.abs(dict_load[key] - value.numpy())) < 1e-15) + + class TestSaveLoad(unittest.TestCase): def setUp(self): # enable dygraph mode diff --git a/python/paddle/fluid/tests/unittests/test_static_save_load.py b/python/paddle/fluid/tests/unittests/test_static_save_load.py index d7618add293f67206a61c564db6406fa091ca980..e275cb525bc871681ef8d345308d45c5f0572b10 100644 --- a/python/paddle/fluid/tests/unittests/test_static_save_load.py +++ b/python/paddle/fluid/tests/unittests/test_static_save_load.py @@ -1199,6 +1199,39 @@ class TestProgramStateOldSave(unittest.TestCase): self.assertTrue(np.array_equal(new_t, base_t)) +class TestStaticSaveLoadLargeParameters(unittest.TestCase): + def test_large_parameters_static_save(self): + # enable static mode + paddle.enable_static() + LARGE_PARAM = 2**26 + with new_program_scope(): + # create network + x = paddle.static.data( + name="static_save_load_large_x", + shape=[None, 10], + dtype='float32') + z = paddle.static.nn.fc(x, LARGE_PARAM) + place = paddle.CPUPlace() + exe = paddle.static.Executor(place) + exe.run(paddle.static.default_startup_program()) + prog = paddle.static.default_main_program() + + inputs = np.random.randn(1, 10).astype("float32") + result_z = exe.run(program=prog, + feed={"static_save_load_large_x": inputs}, + fetch_list=[z.name]) + path = "test_static_save_load_large_param/static_save" + paddle.fluid.save(prog, path) + + paddle.fluid.load(prog, path) + result_load = exe.run(program=prog, + feed={"static_save_load_large_x": inputs}, + fetch_list=[z.name]) + # compare results before and after saving + self.assertTrue( + np.sum(np.abs(result_z[0] - result_load[0])) < 1e-15) + + class TestProgramStateOldSaveSingleModel(unittest.TestCase): def test_ptb_rnn_cpu_float32(self): seed = 90 diff --git a/python/paddle/framework/io.py b/python/paddle/framework/io.py index d794fce5e378dd71088636d7042d02f776a1fa93..66f843dc05ba082bcb046773273037ff4a3988a9 100644 --- a/python/paddle/framework/io.py +++ b/python/paddle/framework/io.py @@ -25,6 +25,7 @@ import paddle # deprecated module import from paddle import fluid from paddle.fluid import core +from paddle.fluid.io import _unpack_saved_dict, _pack_loaded_dict from paddle.fluid.framework import Variable, _varbase_creator, _dygraph_tracer from paddle.fluid.dygraph.jit import _SaveLoadConfig from paddle.fluid.dygraph.io import _construct_program_holders, _construct_params_and_buffers @@ -259,6 +260,7 @@ def save(obj, path): # TODO(chenweihang): supports save other object saved_obj = _build_saved_state_dict(obj) + saved_obj = _unpack_saved_dict(saved_obj) with open(path, 'wb') as f: pickle.dump(saved_obj, f, protocol=2) @@ -338,7 +340,7 @@ def load(path, **configs): with open(path, 'rb') as f: load_result = pickle.load(f) if six.PY2 else pickle.load( f, encoding='latin1') - + load_result = _pack_loaded_dict(load_result) if not config.keep_name_table and "StructuredToParameterName@@" in load_result: del load_result["StructuredToParameterName@@"] else: