未验证 提交 ee1801c1 编写于 作者: W WeiXin 提交者: GitHub

Save load/save pickle protocol (#31044)

* add default argument  for paddle.save/static.save

* edit documentation of

* Add comments for special processing for protocol=2 and protocol=3.

* Update python/paddle/fluid/io.py
Co-authored-by: Nlanxianghit <47554610+lanxianghit@users.noreply.github.com>
Co-authored-by: Nlanxianghit <47554610+lanxianghit@users.noreply.github.com>
上级 cced930b
...@@ -1711,12 +1711,16 @@ def _load_persistable_nodes(executor, dirname, graph): ...@@ -1711,12 +1711,16 @@ def _load_persistable_nodes(executor, dirname, graph):
load_vars(executor=executor, dirname=dirname, vars=var_list) load_vars(executor=executor, dirname=dirname, vars=var_list)
def _unpack_saved_dict(saved_obj): def _unpack_saved_dict(saved_obj, protocol):
temp_saved_obj = {} temp_saved_obj = {}
unpack_infor = {} unpack_infor = {}
# When pickle protocol=2 or protocol=3 the serialized object cannot be larger than 4G.
if 1 < protocol < 4:
if isinstance(saved_obj, dict):
for key, value in saved_obj.items(): for key, value in saved_obj.items():
if isinstance(value, np.ndarray): if isinstance(value, np.ndarray):
MAX_NUMBER_OF_ELEMENT = int((2**30 - 1) / value.dtype.itemsize) MAX_NUMBER_OF_ELEMENT = int(
(2**30 - 1) / value.dtype.itemsize)
num_element = np.prod(value.shape) num_element = np.prod(value.shape)
if num_element > MAX_NUMBER_OF_ELEMENT: if num_element > MAX_NUMBER_OF_ELEMENT:
unpack_infor[key] = {} unpack_infor[key] = {}
...@@ -1730,8 +1734,8 @@ def _unpack_saved_dict(saved_obj): ...@@ -1730,8 +1734,8 @@ def _unpack_saved_dict(saved_obj):
part_name = key + "@@." + str(i) part_name = key + "@@." + str(i)
unpack_infor[key]["slices"].append(part_name) unpack_infor[key]["slices"].append(part_name)
temp_saved_obj[part_name] = value[ temp_saved_obj[part_name] = value[
i * MAX_NUMBER_OF_ELEMENT:MAX_NUMBER_OF_ELEMENT * (i + 1 i * MAX_NUMBER_OF_ELEMENT:MAX_NUMBER_OF_ELEMENT
)] * (i + 1)]
if unpack_infor: if unpack_infor:
for key, value in unpack_infor.items(): for key, value in unpack_infor.items():
...@@ -1744,21 +1748,24 @@ def _unpack_saved_dict(saved_obj): ...@@ -1744,21 +1748,24 @@ def _unpack_saved_dict(saved_obj):
def _pack_loaded_dict(load_obj): def _pack_loaded_dict(load_obj):
if isinstance(load_obj, dict):
unpack_info = 'UnpackBigParamInfor@@' unpack_info = 'UnpackBigParamInfor@@'
if unpack_info in load_obj: if unpack_info in load_obj:
removes = [] removes = []
for key, value in load_obj[unpack_info].items(): for key, value in load_obj[unpack_info].items():
slices = [load_obj[part] for part in value["slices"]] slices = [load_obj[part] for part in value["slices"]]
load_obj[key] = np.concatenate(slices).reshape(value["OriginShape"]) load_obj[key] = np.concatenate(slices).reshape(value[
"OriginShape"])
removes += value["slices"] removes += value["slices"]
for key in removes: for key in removes:
load_obj.pop(key) load_obj.pop(key)
load_obj.pop(unpack_info) load_obj.pop(unpack_info)
return load_obj return load_obj
@static_only @static_only
def save(program, model_path): def save(program, model_path, pickle_protocol=2):
""" """
:api_attr: Static Graph :api_attr: Static Graph
...@@ -1771,6 +1778,8 @@ def save(program, model_path): ...@@ -1771,6 +1778,8 @@ def save(program, model_path):
Args: Args:
program(Program) : The program to saved. program(Program) : The program to saved.
model_path(str): the file prefix to save the program. The format is "dirname/file_prefix". If file_prefix is empty str. A exception will be raised model_path(str): the file prefix to save the program. The format is "dirname/file_prefix". If file_prefix is empty str. A exception will be raised
pickle_protocol(int, optional): The protocol version of pickle module must be greater than 1 and less than 5.
Default: 2
Returns: Returns:
None None
...@@ -1799,6 +1808,14 @@ def save(program, model_path): ...@@ -1799,6 +1808,14 @@ def save(program, model_path):
assert base_name != "", \ assert base_name != "", \
"The input model_path MUST be format of dirname/filename [dirname\\filename in Windows system], but received model_path is empty string." "The input model_path MUST be format of dirname/filename [dirname\\filename in Windows system], but received model_path is empty string."
if not isinstance(pickle_protocol, int):
raise ValueError("The 'protocol' MUST be `int`, but received {}".format(
type(pickle_protocol)))
if pickle_protocol < 2 or pickle_protocol > 4:
raise ValueError("Expected 1<'protocol'<5, but received protocol={}".
format(pickle_protocol))
dir_name = os.path.dirname(model_path) dir_name = os.path.dirname(model_path)
if dir_name and not os.path.exists(dir_name): if dir_name and not os.path.exists(dir_name):
os.makedirs(dir_name) os.makedirs(dir_name)
...@@ -1809,26 +1826,27 @@ def save(program, model_path): ...@@ -1809,26 +1826,27 @@ def save(program, model_path):
parameter_list = list(filter(is_parameter, program.list_vars())) parameter_list = list(filter(is_parameter, program.list_vars()))
param_dict = {p.name: get_tensor(p) for p in parameter_list} param_dict = {p.name: get_tensor(p) for p in parameter_list}
param_dict = _unpack_saved_dict(param_dict)
param_dict = _unpack_saved_dict(param_dict, pickle_protocol)
# When value of dict is lager than 4GB ,there is a Bug on 'MAC python3.5/6' # When value of dict is lager than 4GB ,there is a Bug on 'MAC python3.5/6'
if sys.platform == 'darwin' and sys.version_info.major == 3 and ( if sys.platform == 'darwin' and sys.version_info.major == 3 and (
sys.version_info.minor == 5 or sys.version_info.minor == 6): sys.version_info.minor == 5 or sys.version_info.minor == 6):
pickle_bytes = pickle.dumps(param_dict, protocol=2) pickle_bytes = pickle.dumps(param_dict, protocol=pickle_protocol)
with open(model_path + ".pdparams", 'wb') as f: with open(model_path + ".pdparams", 'wb') as f:
max_bytes = 2**30 max_bytes = 2**30
for i in range(0, len(pickle_bytes), max_bytes): for i in range(0, len(pickle_bytes), max_bytes):
f.write(pickle_bytes[i:i + max_bytes]) f.write(pickle_bytes[i:i + max_bytes])
else: else:
with open(model_path + ".pdparams", 'wb') as f: with open(model_path + ".pdparams", 'wb') as f:
pickle.dump(param_dict, f, protocol=2) pickle.dump(param_dict, f, protocol=pickle_protocol)
optimizer_var_list = list( optimizer_var_list = list(
filter(is_belong_to_optimizer, program.list_vars())) filter(is_belong_to_optimizer, program.list_vars()))
opt_dict = {p.name: get_tensor(p) for p in optimizer_var_list} opt_dict = {p.name: get_tensor(p) for p in optimizer_var_list}
with open(model_path + ".pdopt", 'wb') as f: with open(model_path + ".pdopt", 'wb') as f:
pickle.dump(opt_dict, f, protocol=2) pickle.dump(opt_dict, f, protocol=pickle_protocol)
main_program = program.clone() main_program = program.clone()
program.desc.flush() program.desc.flush()
......
...@@ -17,6 +17,8 @@ from __future__ import print_function ...@@ -17,6 +17,8 @@ from __future__ import print_function
import unittest import unittest
import numpy as np import numpy as np
import os import os
import sys
import paddle import paddle
import paddle.nn as nn import paddle.nn as nn
import paddle.optimizer as opt import paddle.optimizer as opt
...@@ -100,6 +102,35 @@ class TestSaveLoadLargeParameters(unittest.TestCase): ...@@ -100,6 +102,35 @@ class TestSaveLoadLargeParameters(unittest.TestCase):
self.assertTrue(np.array_equal(dict_load[key], value.numpy())) self.assertTrue(np.array_equal(dict_load[key], value.numpy()))
class TestSaveLoadPickle(unittest.TestCase):
def test_pickle_protocol(self):
# create network
layer = LinearNet()
save_dict = layer.state_dict()
path = os.path.join("test_paddle_save_load_pickle_protocol",
"layer.pdparams")
with self.assertRaises(ValueError):
paddle.save(save_dict, path, 2.0)
with self.assertRaises(ValueError):
paddle.save(save_dict, path, 1)
with self.assertRaises(ValueError):
paddle.save(save_dict, path, 5)
protocols = [2, ]
if sys.version_info.major >= 3 and sys.version_info.minor >= 4:
protocols += [3, 4]
for protocol in protocols:
paddle.save(save_dict, path, protocol)
dict_load = paddle.load(path)
# compare results before and after saving
for key, value in save_dict.items():
self.assertTrue(np.array_equal(dict_load[key], value.numpy()))
class TestSaveLoad(unittest.TestCase): class TestSaveLoad(unittest.TestCase):
def setUp(self): def setUp(self):
# enable dygraph mode # enable dygraph mode
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
from __future__ import print_function from __future__ import print_function
import sys
import unittest import unittest
import paddle import paddle
...@@ -1452,6 +1453,70 @@ class TestProgramStateOldSaveSingleModel(unittest.TestCase): ...@@ -1452,6 +1453,70 @@ class TestProgramStateOldSaveSingleModel(unittest.TestCase):
]) ])
class TestStaticSaveLoadPickle(unittest.TestCase):
def test_pickle_protocol(self):
# enable static mode
paddle.enable_static()
with new_program_scope():
# create network
x = paddle.static.data(
name="static_save_load_large_x",
shape=[None, 10],
dtype='float32')
z = paddle.static.nn.fc(x, 10, bias_attr=False)
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
exe.run(paddle.static.default_startup_program())
prog = paddle.static.default_main_program()
base_map = {}
for var in prog.list_vars():
if isinstance(var, framework.Parameter) or var.persistable:
t = np.array(fluid.global_scope().find_var(var.name)
.get_tensor())
# make sure all the paramerter or optimizer var have been update
self.assertTrue(np.sum(np.abs(t)) != 0)
base_map[var.name] = t
path = os.path.join("test_static_save_load_pickle",
"pickle_protocol")
with self.assertRaises(ValueError):
paddle.fluid.save(prog, path, 2.0)
with self.assertRaises(ValueError):
paddle.fluid.save(prog, path, 1)
with self.assertRaises(ValueError):
paddle.fluid.save(prog, path, 5)
protocols = [2, ]
if sys.version_info.major >= 3 and sys.version_info.minor >= 4:
protocols += [3, 4]
for protocol in protocols:
paddle.fluid.save(prog, path, protocol)
# set var to zero
for var in prog.list_vars():
if isinstance(var, framework.Parameter) or var.persistable:
ten = fluid.global_scope().find_var(
var.name).get_tensor()
ten.set(np.zeros_like(np.array(ten)), place)
new_t = np.array(fluid.global_scope().find_var(var.name)
.get_tensor())
self.assertTrue(np.sum(np.abs(new_t)) == 0)
paddle.fluid.load(prog, path)
for var in prog.list_vars():
if isinstance(var, framework.Parameter) or var.persistable:
new_t = np.array(fluid.global_scope().find_var(var.name)
.get_tensor())
base_t = base_map[var.name]
self.assertTrue(np.array_equal(new_t, base_t))
if __name__ == '__main__': if __name__ == '__main__':
paddle.enable_static() paddle.enable_static()
unittest.main() unittest.main()
...@@ -20,6 +20,7 @@ import pickle ...@@ -20,6 +20,7 @@ import pickle
import six import six
import warnings import warnings
import sys import sys
import numpy as np
import paddle import paddle
...@@ -198,7 +199,7 @@ def _parse_load_config(configs): ...@@ -198,7 +199,7 @@ def _parse_load_config(configs):
return inner_config return inner_config
def save(obj, path): def save(obj, path, pickle_protocol=2):
''' '''
Save an object to the specified path. Save an object to the specified path.
...@@ -218,6 +219,8 @@ def save(obj, path): ...@@ -218,6 +219,8 @@ def save(obj, path):
obj(Object) : The object to be saved. obj(Object) : The object to be saved.
path(str) : The path of the object to be saved. path(str) : The path of the object to be saved.
If saved in the current directory, the input path string will be used as the file name. If saved in the current directory, the input path string will be used as the file name.
pickle_protocol(int, optional): The protocol version of pickle module must be greater than 1 and less than 5.
Default: 2
Returns: Returns:
None None
...@@ -254,26 +257,36 @@ def save(obj, path): ...@@ -254,26 +257,36 @@ def save(obj, path):
"[dirname\\filename in Windows system], but received " "[dirname\\filename in Windows system], but received "
"filename is empty string.") "filename is empty string.")
if not isinstance(pickle_protocol, int):
raise ValueError("The 'protocol' MUST be `int`, but received {}".format(
type(pickle_protocol)))
if pickle_protocol < 2 or pickle_protocol > 4:
raise ValueError("Expected 1<'protocol'<5, but received protocol={}".
format(pickle_protocol))
# 2. save object # 2. save object
dirname = os.path.dirname(path) dirname = os.path.dirname(path)
if dirname and not os.path.exists(dirname): if dirname and not os.path.exists(dirname):
os.makedirs(dirname) os.makedirs(dirname)
# TODO(chenweihang): supports save other object # TODO(chenweihang): supports save other object
if isinstance(obj, dict):
saved_obj = _build_saved_state_dict(obj) saved_obj = _build_saved_state_dict(obj)
saved_obj = _unpack_saved_dict(saved_obj)
saved_obj = _unpack_saved_dict(saved_obj, pickle_protocol)
# When value of dict is lager than 4GB ,there is a Bug on 'MAC python3.5/6' # When value of dict is lager than 4GB ,there is a Bug on 'MAC python3.5/6'
if sys.platform == 'darwin' and sys.version_info.major == 3 and ( if sys.platform == 'darwin' and sys.version_info.major == 3 and (
sys.version_info.minor == 5 or sys.version_info.minor == 6): sys.version_info.minor == 5 or sys.version_info.minor == 6):
pickle_bytes = pickle.dumps(saved_obj, protocol=2) pickle_bytes = pickle.dumps(saved_obj, protocol=pickle_protocol)
with open(path, 'wb') as f: with open(path, 'wb') as f:
max_bytes = 2**30 max_bytes = 2**30
for i in range(0, len(pickle_bytes), max_bytes): for i in range(0, len(pickle_bytes), max_bytes):
f.write(pickle_bytes[i:i + max_bytes]) f.write(pickle_bytes[i:i + max_bytes])
else: else:
with open(path, 'wb') as f: with open(path, 'wb') as f:
pickle.dump(saved_obj, f, protocol=2) pickle.dump(saved_obj, f, protocol=pickle_protocol)
def load(path, **configs): def load(path, **configs):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册