From 5d045b95f34536b5ffa4997a7abf486428a78870 Mon Sep 17 00:00:00 2001 From: WeiXin <2279280558@qq.com> Date: Tue, 23 Feb 2021 19:43:43 +0800 Subject: [PATCH] =?UTF-8?q?=20[cherry-pick]=20paddle.save/paddle.static.sa?= =?UTF-8?q?ve=20=E5=8D=87=E7=BA=A7pickle=E7=9A=84=E7=89=88=E6=9C=AC?= =?UTF-8?q?=E3=80=82=20(#31044)=20(#31140)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit paddle.save/static.save指定pickle版本参数protocol(2<= pickle_protocol <=4),默认pickle_protocol=2。 原始PR:#31044 --- python/paddle/fluid/io.py | 86 +++++++++++-------- .../tests/unittests/test_paddle_save_load.py | 31 +++++++ .../tests/unittests/test_static_save_load.py | 65 ++++++++++++++ python/paddle/framework/io.py | 23 +++-- 4 files changed, 166 insertions(+), 39 deletions(-) diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py index 313855b6c55..9cca3e16de5 100644 --- a/python/paddle/fluid/io.py +++ b/python/paddle/fluid/io.py @@ -1711,27 +1711,31 @@ def _load_persistable_nodes(executor, dirname, graph): load_vars(executor=executor, dirname=dirname, vars=var_list) -def _unpack_saved_dict(saved_obj): +def _unpack_saved_dict(saved_obj, protocol): temp_saved_obj = {} unpack_infor = {} - for key, value in saved_obj.items(): - if isinstance(value, np.ndarray): - MAX_NUMBER_OF_ELEMENT = int((2**30 - 1) / value.dtype.itemsize) - num_element = np.prod(value.shape) - if num_element > MAX_NUMBER_OF_ELEMENT: - unpack_infor[key] = {} - unpack_infor[key]["OriginShape"] = value.shape - unpack_infor[key]["slices"] = [] - value = value.flatten() - for i in range( - int( - math.ceil(num_element * 1.0 / - MAX_NUMBER_OF_ELEMENT))): - part_name = key + "@@." + str(i) - unpack_infor[key]["slices"].append(part_name) - temp_saved_obj[part_name] = value[ - i * MAX_NUMBER_OF_ELEMENT:MAX_NUMBER_OF_ELEMENT * (i + 1 - )] + # When pickle protocol=2 or protocol=3 the serialized object cannot be larger than 4G. + if 1 < protocol < 4: + if isinstance(saved_obj, dict): + for key, value in saved_obj.items(): + if isinstance(value, np.ndarray): + MAX_NUMBER_OF_ELEMENT = int( + (2**30 - 1) / value.dtype.itemsize) + num_element = np.prod(value.shape) + if num_element > MAX_NUMBER_OF_ELEMENT: + unpack_infor[key] = {} + unpack_infor[key]["OriginShape"] = value.shape + unpack_infor[key]["slices"] = [] + value = value.flatten() + for i in range( + int( + math.ceil(num_element * 1.0 / + MAX_NUMBER_OF_ELEMENT))): + part_name = key + "@@." + str(i) + unpack_infor[key]["slices"].append(part_name) + temp_saved_obj[part_name] = value[ + i * MAX_NUMBER_OF_ELEMENT:MAX_NUMBER_OF_ELEMENT + * (i + 1)] if unpack_infor: for key, value in unpack_infor.items(): @@ -1744,21 +1748,24 @@ def _unpack_saved_dict(saved_obj): def _pack_loaded_dict(load_obj): - unpack_info = 'UnpackBigParamInfor@@' - if unpack_info in load_obj: - removes = [] - for key, value in load_obj[unpack_info].items(): - slices = [load_obj[part] for part in value["slices"]] - load_obj[key] = np.concatenate(slices).reshape(value["OriginShape"]) - removes += value["slices"] - for key in removes: - load_obj.pop(key) - load_obj.pop(unpack_info) + if isinstance(load_obj, dict): + unpack_info = 'UnpackBigParamInfor@@' + if unpack_info in load_obj: + removes = [] + for key, value in load_obj[unpack_info].items(): + slices = [load_obj[part] for part in value["slices"]] + load_obj[key] = np.concatenate(slices).reshape(value[ + "OriginShape"]) + removes += value["slices"] + for key in removes: + load_obj.pop(key) + load_obj.pop(unpack_info) + return load_obj @static_only -def save(program, model_path): +def save(program, model_path, pickle_protocol=2): """ :api_attr: Static Graph @@ -1771,6 +1778,8 @@ def save(program, model_path): Args: program(Program) : The program to saved. model_path(str): the file prefix to save the program. The format is "dirname/file_prefix". If file_prefix is empty str. A exception will be raised + pickle_protocol(int, optional): The protocol version of pickle module must be greater than 1 and less than 5. + Default: 2 Returns: None @@ -1799,6 +1808,14 @@ def save(program, model_path): assert base_name != "", \ "The input model_path MUST be format of dirname/filename [dirname\\filename in Windows system], but received model_path is empty string." + if not isinstance(pickle_protocol, int): + raise ValueError("The 'protocol' MUST be `int`, but received {}".format( + type(pickle_protocol))) + + if pickle_protocol < 2 or pickle_protocol > 4: + raise ValueError("Expected 1<'protocol'<5, but received protocol={}". + format(pickle_protocol)) + dir_name = os.path.dirname(model_path) if dir_name and not os.path.exists(dir_name): os.makedirs(dir_name) @@ -1809,26 +1826,27 @@ def save(program, model_path): parameter_list = list(filter(is_parameter, program.list_vars())) param_dict = {p.name: get_tensor(p) for p in parameter_list} - param_dict = _unpack_saved_dict(param_dict) + + param_dict = _unpack_saved_dict(param_dict, pickle_protocol) # When value of dict is lager than 4GB ,there is a Bug on 'MAC python3.5/6' if sys.platform == 'darwin' and sys.version_info.major == 3 and ( sys.version_info.minor == 5 or sys.version_info.minor == 6): - pickle_bytes = pickle.dumps(param_dict, protocol=2) + pickle_bytes = pickle.dumps(param_dict, protocol=pickle_protocol) with open(model_path + ".pdparams", 'wb') as f: max_bytes = 2**30 for i in range(0, len(pickle_bytes), max_bytes): f.write(pickle_bytes[i:i + max_bytes]) else: with open(model_path + ".pdparams", 'wb') as f: - pickle.dump(param_dict, f, protocol=2) + pickle.dump(param_dict, f, protocol=pickle_protocol) optimizer_var_list = list( filter(is_belong_to_optimizer, program.list_vars())) opt_dict = {p.name: get_tensor(p) for p in optimizer_var_list} with open(model_path + ".pdopt", 'wb') as f: - pickle.dump(opt_dict, f, protocol=2) + pickle.dump(opt_dict, f, protocol=pickle_protocol) main_program = program.clone() program.desc.flush() diff --git a/python/paddle/fluid/tests/unittests/test_paddle_save_load.py b/python/paddle/fluid/tests/unittests/test_paddle_save_load.py index 3a8531db6f8..06f63d1416b 100644 --- a/python/paddle/fluid/tests/unittests/test_paddle_save_load.py +++ b/python/paddle/fluid/tests/unittests/test_paddle_save_load.py @@ -17,6 +17,8 @@ from __future__ import print_function import unittest import numpy as np import os +import sys + import paddle import paddle.nn as nn import paddle.optimizer as opt @@ -100,6 +102,35 @@ class TestSaveLoadLargeParameters(unittest.TestCase): self.assertTrue(np.array_equal(dict_load[key], value.numpy())) +class TestSaveLoadPickle(unittest.TestCase): + def test_pickle_protocol(self): + # create network + layer = LinearNet() + save_dict = layer.state_dict() + + path = os.path.join("test_paddle_save_load_pickle_protocol", + "layer.pdparams") + + with self.assertRaises(ValueError): + paddle.save(save_dict, path, 2.0) + + with self.assertRaises(ValueError): + paddle.save(save_dict, path, 1) + + with self.assertRaises(ValueError): + paddle.save(save_dict, path, 5) + + protocols = [2, ] + if sys.version_info.major >= 3 and sys.version_info.minor >= 4: + protocols += [3, 4] + for protocol in protocols: + paddle.save(save_dict, path, protocol) + dict_load = paddle.load(path) + # compare results before and after saving + for key, value in save_dict.items(): + self.assertTrue(np.array_equal(dict_load[key], value.numpy())) + + class TestSaveLoad(unittest.TestCase): def setUp(self): # enable dygraph mode diff --git a/python/paddle/fluid/tests/unittests/test_static_save_load.py b/python/paddle/fluid/tests/unittests/test_static_save_load.py index ca66aa47266..064c8277a90 100644 --- a/python/paddle/fluid/tests/unittests/test_static_save_load.py +++ b/python/paddle/fluid/tests/unittests/test_static_save_load.py @@ -13,6 +13,7 @@ # limitations under the License. from __future__ import print_function +import sys import unittest import paddle @@ -1444,6 +1445,70 @@ class TestProgramStateOldSaveSingleModel(unittest.TestCase): ]) +class TestStaticSaveLoadPickle(unittest.TestCase): + def test_pickle_protocol(self): + # enable static mode + paddle.enable_static() + + with new_program_scope(): + # create network + x = paddle.static.data( + name="static_save_load_large_x", + shape=[None, 10], + dtype='float32') + z = paddle.static.nn.fc(x, 10, bias_attr=False) + place = paddle.CPUPlace() + exe = paddle.static.Executor(place) + exe.run(paddle.static.default_startup_program()) + prog = paddle.static.default_main_program() + + base_map = {} + for var in prog.list_vars(): + if isinstance(var, framework.Parameter) or var.persistable: + t = np.array(fluid.global_scope().find_var(var.name) + .get_tensor()) + # make sure all the paramerter or optimizer var have been update + self.assertTrue(np.sum(np.abs(t)) != 0) + base_map[var.name] = t + + path = os.path.join("test_static_save_load_pickle", + "pickle_protocol") + + with self.assertRaises(ValueError): + paddle.fluid.save(prog, path, 2.0) + + with self.assertRaises(ValueError): + paddle.fluid.save(prog, path, 1) + + with self.assertRaises(ValueError): + paddle.fluid.save(prog, path, 5) + + protocols = [2, ] + if sys.version_info.major >= 3 and sys.version_info.minor >= 4: + protocols += [3, 4] + for protocol in protocols: + paddle.fluid.save(prog, path, protocol) + # set var to zero + for var in prog.list_vars(): + if isinstance(var, framework.Parameter) or var.persistable: + ten = fluid.global_scope().find_var( + var.name).get_tensor() + ten.set(np.zeros_like(np.array(ten)), place) + + new_t = np.array(fluid.global_scope().find_var(var.name) + .get_tensor()) + self.assertTrue(np.sum(np.abs(new_t)) == 0) + + paddle.fluid.load(prog, path) + + for var in prog.list_vars(): + if isinstance(var, framework.Parameter) or var.persistable: + new_t = np.array(fluid.global_scope().find_var(var.name) + .get_tensor()) + base_t = base_map[var.name] + self.assertTrue(np.array_equal(new_t, base_t)) + + if __name__ == '__main__': paddle.enable_static() unittest.main() diff --git a/python/paddle/framework/io.py b/python/paddle/framework/io.py index 2dfad8dc10c..3d93bed32ec 100644 --- a/python/paddle/framework/io.py +++ b/python/paddle/framework/io.py @@ -20,6 +20,7 @@ import pickle import six import warnings import sys +import numpy as np import paddle @@ -198,7 +199,7 @@ def _parse_load_config(configs): return inner_config -def save(obj, path): +def save(obj, path, pickle_protocol=2): ''' Save an object to the specified path. @@ -218,6 +219,8 @@ def save(obj, path): obj(Object) : The object to be saved. path(str) : The path of the object to be saved. If saved in the current directory, the input path string will be used as the file name. + pickle_protocol(int, optional): The protocol version of pickle module must be greater than 1 and less than 5. + Default: 2 Returns: None @@ -254,26 +257,36 @@ def save(obj, path): "[dirname\\filename in Windows system], but received " "filename is empty string.") + if not isinstance(pickle_protocol, int): + raise ValueError("The 'protocol' MUST be `int`, but received {}".format( + type(pickle_protocol))) + + if pickle_protocol < 2 or pickle_protocol > 4: + raise ValueError("Expected 1<'protocol'<5, but received protocol={}". + format(pickle_protocol)) + # 2. save object dirname = os.path.dirname(path) if dirname and not os.path.exists(dirname): os.makedirs(dirname) # TODO(chenweihang): supports save other object - saved_obj = _build_saved_state_dict(obj) - saved_obj = _unpack_saved_dict(saved_obj) + if isinstance(obj, dict): + saved_obj = _build_saved_state_dict(obj) + + saved_obj = _unpack_saved_dict(saved_obj, pickle_protocol) # When value of dict is lager than 4GB ,there is a Bug on 'MAC python3.5/6' if sys.platform == 'darwin' and sys.version_info.major == 3 and ( sys.version_info.minor == 5 or sys.version_info.minor == 6): - pickle_bytes = pickle.dumps(saved_obj, protocol=2) + pickle_bytes = pickle.dumps(saved_obj, protocol=pickle_protocol) with open(path, 'wb') as f: max_bytes = 2**30 for i in range(0, len(pickle_bytes), max_bytes): f.write(pickle_bytes[i:i + max_bytes]) else: with open(path, 'wb') as f: - pickle.dump(saved_obj, f, protocol=2) + pickle.dump(saved_obj, f, protocol=pickle_protocol) def load(path, **configs): -- GitLab