未验证 提交 5d045b95 编写于 作者: W WeiXin 提交者: GitHub

[cherry-pick] paddle.save/paddle.static.save 升级pickle的版本。 (#31044) (#31140)

paddle.save/static.save指定pickle版本参数protocol(2<= pickle_protocol <=4),默认pickle_protocol=2。
原始PR:#31044
上级 30a2e7f0
......@@ -1711,27 +1711,31 @@ def _load_persistable_nodes(executor, dirname, graph):
load_vars(executor=executor, dirname=dirname, vars=var_list)
def _unpack_saved_dict(saved_obj):
def _unpack_saved_dict(saved_obj, protocol):
temp_saved_obj = {}
unpack_infor = {}
for key, value in saved_obj.items():
if isinstance(value, np.ndarray):
MAX_NUMBER_OF_ELEMENT = int((2**30 - 1) / value.dtype.itemsize)
num_element = np.prod(value.shape)
if num_element > MAX_NUMBER_OF_ELEMENT:
unpack_infor[key] = {}
unpack_infor[key]["OriginShape"] = value.shape
unpack_infor[key]["slices"] = []
value = value.flatten()
for i in range(
int(
math.ceil(num_element * 1.0 /
MAX_NUMBER_OF_ELEMENT))):
part_name = key + "@@." + str(i)
unpack_infor[key]["slices"].append(part_name)
temp_saved_obj[part_name] = value[
i * MAX_NUMBER_OF_ELEMENT:MAX_NUMBER_OF_ELEMENT * (i + 1
)]
# When pickle protocol=2 or protocol=3 the serialized object cannot be larger than 4G.
if 1 < protocol < 4:
if isinstance(saved_obj, dict):
for key, value in saved_obj.items():
if isinstance(value, np.ndarray):
MAX_NUMBER_OF_ELEMENT = int(
(2**30 - 1) / value.dtype.itemsize)
num_element = np.prod(value.shape)
if num_element > MAX_NUMBER_OF_ELEMENT:
unpack_infor[key] = {}
unpack_infor[key]["OriginShape"] = value.shape
unpack_infor[key]["slices"] = []
value = value.flatten()
for i in range(
int(
math.ceil(num_element * 1.0 /
MAX_NUMBER_OF_ELEMENT))):
part_name = key + "@@." + str(i)
unpack_infor[key]["slices"].append(part_name)
temp_saved_obj[part_name] = value[
i * MAX_NUMBER_OF_ELEMENT:MAX_NUMBER_OF_ELEMENT
* (i + 1)]
if unpack_infor:
for key, value in unpack_infor.items():
......@@ -1744,21 +1748,24 @@ def _unpack_saved_dict(saved_obj):
def _pack_loaded_dict(load_obj):
unpack_info = 'UnpackBigParamInfor@@'
if unpack_info in load_obj:
removes = []
for key, value in load_obj[unpack_info].items():
slices = [load_obj[part] for part in value["slices"]]
load_obj[key] = np.concatenate(slices).reshape(value["OriginShape"])
removes += value["slices"]
for key in removes:
load_obj.pop(key)
load_obj.pop(unpack_info)
if isinstance(load_obj, dict):
unpack_info = 'UnpackBigParamInfor@@'
if unpack_info in load_obj:
removes = []
for key, value in load_obj[unpack_info].items():
slices = [load_obj[part] for part in value["slices"]]
load_obj[key] = np.concatenate(slices).reshape(value[
"OriginShape"])
removes += value["slices"]
for key in removes:
load_obj.pop(key)
load_obj.pop(unpack_info)
return load_obj
@static_only
def save(program, model_path):
def save(program, model_path, pickle_protocol=2):
"""
:api_attr: Static Graph
......@@ -1771,6 +1778,8 @@ def save(program, model_path):
Args:
program(Program) : The program to saved.
model_path(str): the file prefix to save the program. The format is "dirname/file_prefix". If file_prefix is empty str. A exception will be raised
pickle_protocol(int, optional): The protocol version of pickle module must be greater than 1 and less than 5.
Default: 2
Returns:
None
......@@ -1799,6 +1808,14 @@ def save(program, model_path):
assert base_name != "", \
"The input model_path MUST be format of dirname/filename [dirname\\filename in Windows system], but received model_path is empty string."
if not isinstance(pickle_protocol, int):
raise ValueError("The 'protocol' MUST be `int`, but received {}".format(
type(pickle_protocol)))
if pickle_protocol < 2 or pickle_protocol > 4:
raise ValueError("Expected 1<'protocol'<5, but received protocol={}".
format(pickle_protocol))
dir_name = os.path.dirname(model_path)
if dir_name and not os.path.exists(dir_name):
os.makedirs(dir_name)
......@@ -1809,26 +1826,27 @@ def save(program, model_path):
parameter_list = list(filter(is_parameter, program.list_vars()))
param_dict = {p.name: get_tensor(p) for p in parameter_list}
param_dict = _unpack_saved_dict(param_dict)
param_dict = _unpack_saved_dict(param_dict, pickle_protocol)
# When value of dict is lager than 4GB ,there is a Bug on 'MAC python3.5/6'
if sys.platform == 'darwin' and sys.version_info.major == 3 and (
sys.version_info.minor == 5 or sys.version_info.minor == 6):
pickle_bytes = pickle.dumps(param_dict, protocol=2)
pickle_bytes = pickle.dumps(param_dict, protocol=pickle_protocol)
with open(model_path + ".pdparams", 'wb') as f:
max_bytes = 2**30
for i in range(0, len(pickle_bytes), max_bytes):
f.write(pickle_bytes[i:i + max_bytes])
else:
with open(model_path + ".pdparams", 'wb') as f:
pickle.dump(param_dict, f, protocol=2)
pickle.dump(param_dict, f, protocol=pickle_protocol)
optimizer_var_list = list(
filter(is_belong_to_optimizer, program.list_vars()))
opt_dict = {p.name: get_tensor(p) for p in optimizer_var_list}
with open(model_path + ".pdopt", 'wb') as f:
pickle.dump(opt_dict, f, protocol=2)
pickle.dump(opt_dict, f, protocol=pickle_protocol)
main_program = program.clone()
program.desc.flush()
......
......@@ -17,6 +17,8 @@ from __future__ import print_function
import unittest
import numpy as np
import os
import sys
import paddle
import paddle.nn as nn
import paddle.optimizer as opt
......@@ -100,6 +102,35 @@ class TestSaveLoadLargeParameters(unittest.TestCase):
self.assertTrue(np.array_equal(dict_load[key], value.numpy()))
class TestSaveLoadPickle(unittest.TestCase):
def test_pickle_protocol(self):
# create network
layer = LinearNet()
save_dict = layer.state_dict()
path = os.path.join("test_paddle_save_load_pickle_protocol",
"layer.pdparams")
with self.assertRaises(ValueError):
paddle.save(save_dict, path, 2.0)
with self.assertRaises(ValueError):
paddle.save(save_dict, path, 1)
with self.assertRaises(ValueError):
paddle.save(save_dict, path, 5)
protocols = [2, ]
if sys.version_info.major >= 3 and sys.version_info.minor >= 4:
protocols += [3, 4]
for protocol in protocols:
paddle.save(save_dict, path, protocol)
dict_load = paddle.load(path)
# compare results before and after saving
for key, value in save_dict.items():
self.assertTrue(np.array_equal(dict_load[key], value.numpy()))
class TestSaveLoad(unittest.TestCase):
def setUp(self):
# enable dygraph mode
......
......@@ -13,6 +13,7 @@
# limitations under the License.
from __future__ import print_function
import sys
import unittest
import paddle
......@@ -1444,6 +1445,70 @@ class TestProgramStateOldSaveSingleModel(unittest.TestCase):
])
class TestStaticSaveLoadPickle(unittest.TestCase):
def test_pickle_protocol(self):
# enable static mode
paddle.enable_static()
with new_program_scope():
# create network
x = paddle.static.data(
name="static_save_load_large_x",
shape=[None, 10],
dtype='float32')
z = paddle.static.nn.fc(x, 10, bias_attr=False)
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
exe.run(paddle.static.default_startup_program())
prog = paddle.static.default_main_program()
base_map = {}
for var in prog.list_vars():
if isinstance(var, framework.Parameter) or var.persistable:
t = np.array(fluid.global_scope().find_var(var.name)
.get_tensor())
# make sure all the paramerter or optimizer var have been update
self.assertTrue(np.sum(np.abs(t)) != 0)
base_map[var.name] = t
path = os.path.join("test_static_save_load_pickle",
"pickle_protocol")
with self.assertRaises(ValueError):
paddle.fluid.save(prog, path, 2.0)
with self.assertRaises(ValueError):
paddle.fluid.save(prog, path, 1)
with self.assertRaises(ValueError):
paddle.fluid.save(prog, path, 5)
protocols = [2, ]
if sys.version_info.major >= 3 and sys.version_info.minor >= 4:
protocols += [3, 4]
for protocol in protocols:
paddle.fluid.save(prog, path, protocol)
# set var to zero
for var in prog.list_vars():
if isinstance(var, framework.Parameter) or var.persistable:
ten = fluid.global_scope().find_var(
var.name).get_tensor()
ten.set(np.zeros_like(np.array(ten)), place)
new_t = np.array(fluid.global_scope().find_var(var.name)
.get_tensor())
self.assertTrue(np.sum(np.abs(new_t)) == 0)
paddle.fluid.load(prog, path)
for var in prog.list_vars():
if isinstance(var, framework.Parameter) or var.persistable:
new_t = np.array(fluid.global_scope().find_var(var.name)
.get_tensor())
base_t = base_map[var.name]
self.assertTrue(np.array_equal(new_t, base_t))
if __name__ == '__main__':
paddle.enable_static()
unittest.main()
......@@ -20,6 +20,7 @@ import pickle
import six
import warnings
import sys
import numpy as np
import paddle
......@@ -198,7 +199,7 @@ def _parse_load_config(configs):
return inner_config
def save(obj, path):
def save(obj, path, pickle_protocol=2):
'''
Save an object to the specified path.
......@@ -218,6 +219,8 @@ def save(obj, path):
obj(Object) : The object to be saved.
path(str) : The path of the object to be saved.
If saved in the current directory, the input path string will be used as the file name.
pickle_protocol(int, optional): The protocol version of pickle module must be greater than 1 and less than 5.
Default: 2
Returns:
None
......@@ -254,26 +257,36 @@ def save(obj, path):
"[dirname\\filename in Windows system], but received "
"filename is empty string.")
if not isinstance(pickle_protocol, int):
raise ValueError("The 'protocol' MUST be `int`, but received {}".format(
type(pickle_protocol)))
if pickle_protocol < 2 or pickle_protocol > 4:
raise ValueError("Expected 1<'protocol'<5, but received protocol={}".
format(pickle_protocol))
# 2. save object
dirname = os.path.dirname(path)
if dirname and not os.path.exists(dirname):
os.makedirs(dirname)
# TODO(chenweihang): supports save other object
saved_obj = _build_saved_state_dict(obj)
saved_obj = _unpack_saved_dict(saved_obj)
if isinstance(obj, dict):
saved_obj = _build_saved_state_dict(obj)
saved_obj = _unpack_saved_dict(saved_obj, pickle_protocol)
# When value of dict is lager than 4GB ,there is a Bug on 'MAC python3.5/6'
if sys.platform == 'darwin' and sys.version_info.major == 3 and (
sys.version_info.minor == 5 or sys.version_info.minor == 6):
pickle_bytes = pickle.dumps(saved_obj, protocol=2)
pickle_bytes = pickle.dumps(saved_obj, protocol=pickle_protocol)
with open(path, 'wb') as f:
max_bytes = 2**30
for i in range(0, len(pickle_bytes), max_bytes):
f.write(pickle_bytes[i:i + max_bytes])
else:
with open(path, 'wb') as f:
pickle.dump(saved_obj, f, protocol=2)
pickle.dump(saved_obj, f, protocol=pickle_protocol)
def load(path, **configs):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册