未验证 提交 832032c2 编写于 作者: W WeiXin 提交者: GitHub

[cherry pick]修复save/load相关的两个bug (#30543)

原始PR:#30485,#30507
上级 dbbfbccd
...@@ -22,6 +22,7 @@ import logging ...@@ -22,6 +22,7 @@ import logging
import pickle import pickle
import contextlib import contextlib
from functools import reduce from functools import reduce
import sys
import numpy as np import numpy as np
import math import math
...@@ -1715,7 +1716,7 @@ def _unpack_saved_dict(saved_obj): ...@@ -1715,7 +1716,7 @@ def _unpack_saved_dict(saved_obj):
unpack_infor = {} unpack_infor = {}
for key, value in saved_obj.items(): for key, value in saved_obj.items():
if isinstance(value, np.ndarray): if isinstance(value, np.ndarray):
MAX_NUMBER_OF_ELEMENT = 2**22 MAX_NUMBER_OF_ELEMENT = int((2**30 - 1) / value.dtype.itemsize)
num_element = np.prod(value.shape) num_element = np.prod(value.shape)
if num_element > MAX_NUMBER_OF_ELEMENT: if num_element > MAX_NUMBER_OF_ELEMENT:
unpack_infor[key] = {} unpack_infor[key] = {}
...@@ -1809,8 +1810,18 @@ def save(program, model_path): ...@@ -1809,8 +1810,18 @@ def save(program, model_path):
parameter_list = list(filter(is_parameter, program.list_vars())) parameter_list = list(filter(is_parameter, program.list_vars()))
param_dict = {p.name: get_tensor(p) for p in parameter_list} param_dict = {p.name: get_tensor(p) for p in parameter_list}
param_dict = _unpack_saved_dict(param_dict) param_dict = _unpack_saved_dict(param_dict)
with open(model_path + ".pdparams", 'wb') as f:
pickle.dump(param_dict, f, protocol=2) # When value of dict is lager than 4GB ,there is a Bug on 'MAC python3.5/6'
if sys.platform == 'darwin' and sys.version_info.major == 3 and (
sys.version_info.minor == 5 or sys.version_info.minor == 6):
pickle_bytes = pickle.dumps(param_dict, protocol=2)
with open(model_path + ".pdparams", 'wb') as f:
max_bytes = 2**30
for i in range(0, len(pickle_bytes), max_bytes):
f.write(pickle_bytes[i:i + max_bytes])
else:
with open(model_path + ".pdparams", 'wb') as f:
pickle.dump(param_dict, f, protocol=2)
optimizer_var_list = list( optimizer_var_list = list(
filter(is_belong_to_optimizer, program.list_vars())) filter(is_belong_to_optimizer, program.list_vars()))
...@@ -2169,6 +2180,7 @@ def load_program_state(model_path, var_list=None): ...@@ -2169,6 +2180,7 @@ def load_program_state(model_path, var_list=None):
with open(parameter_file_name, 'rb') as f: with open(parameter_file_name, 'rb') as f:
para_dict = pickle.load(f) if six.PY2 else pickle.load( para_dict = pickle.load(f) if six.PY2 else pickle.load(
f, encoding='latin1') f, encoding='latin1')
para_dict = _pack_loaded_dict(para_dict)
opt_file_name = model_prefix + ".pdopt" opt_file_name = model_prefix + ".pdopt"
if os.path.exists(opt_file_name): if os.path.exists(opt_file_name):
...@@ -2220,6 +2232,7 @@ def set_program_state(program, state_dict): ...@@ -2220,6 +2232,7 @@ def set_program_state(program, state_dict):
static.set_program_state(prog, program_state) static.set_program_state(prog, program_state)
""" """
state_dict = _pack_loaded_dict(state_dict)
parameter_list = list(filter(is_persistable, program.list_vars())) parameter_list = list(filter(is_persistable, program.list_vars()))
used_para_list = {} used_para_list = {}
......
...@@ -16,6 +16,7 @@ from __future__ import print_function ...@@ -16,6 +16,7 @@ from __future__ import print_function
import unittest import unittest
import numpy as np import numpy as np
import os
import paddle import paddle
import paddle.nn as nn import paddle.nn as nn
import paddle.optimizer as opt import paddle.optimizer as opt
...@@ -90,13 +91,13 @@ class TestSaveLoadLargeParameters(unittest.TestCase): ...@@ -90,13 +91,13 @@ class TestSaveLoadLargeParameters(unittest.TestCase):
layer = LayerWithLargeParameters() layer = LayerWithLargeParameters()
save_dict = layer.state_dict() save_dict = layer.state_dict()
path = "test_paddle_save_load_large_param_save/layer" + ".pdparams" path = os.path.join("test_paddle_save_load_large_param_save",
"layer.pdparams")
paddle.save(layer.state_dict(), path) paddle.save(layer.state_dict(), path)
dict_load = paddle.load(path) dict_load = paddle.load(path)
# compare results before and after saving # compare results before and after saving
for key, value in save_dict.items(): for key, value in save_dict.items():
self.assertTrue( self.assertTrue(np.array_equal(dict_load[key], value.numpy()))
np.sum(np.abs(dict_load[key] - value.numpy())) < 1e-15)
class TestSaveLoad(unittest.TestCase): class TestSaveLoad(unittest.TestCase):
......
...@@ -1324,7 +1324,7 @@ class TestStaticSaveLoadLargeParameters(unittest.TestCase): ...@@ -1324,7 +1324,7 @@ class TestStaticSaveLoadLargeParameters(unittest.TestCase):
name="static_save_load_large_x", name="static_save_load_large_x",
shape=[None, 10], shape=[None, 10],
dtype='float32') dtype='float32')
z = paddle.static.nn.fc(x, LARGE_PARAM) z = paddle.static.nn.fc(x, LARGE_PARAM, bias_attr=False)
place = paddle.CPUPlace() place = paddle.CPUPlace()
exe = paddle.static.Executor(place) exe = paddle.static.Executor(place)
exe.run(paddle.static.default_startup_program()) exe.run(paddle.static.default_startup_program())
...@@ -1334,16 +1334,55 @@ class TestStaticSaveLoadLargeParameters(unittest.TestCase): ...@@ -1334,16 +1334,55 @@ class TestStaticSaveLoadLargeParameters(unittest.TestCase):
result_z = exe.run(program=prog, result_z = exe.run(program=prog,
feed={"static_save_load_large_x": inputs}, feed={"static_save_load_large_x": inputs},
fetch_list=[z.name]) fetch_list=[z.name])
path = "test_static_save_load_large_param/static_save" base_map = {}
for var in prog.list_vars():
if isinstance(var, framework.Parameter) or var.persistable:
t = np.array(fluid.global_scope().find_var(var.name)
.get_tensor())
# make sure all the paramerter or optimizer var have been update
self.assertTrue(np.sum(np.abs(t)) != 0)
base_map[var.name] = t
path = os.path.join("test_static_save_load_large_param",
"static_save")
paddle.fluid.save(prog, path) paddle.fluid.save(prog, path)
# set var to zero
for var in prog.list_vars():
if isinstance(var, framework.Parameter) or var.persistable:
ten = fluid.global_scope().find_var(var.name).get_tensor()
ten.set(np.zeros_like(np.array(ten)), place)
new_t = np.array(fluid.global_scope().find_var(var.name)
.get_tensor())
self.assertTrue(np.sum(np.abs(new_t)) == 0)
paddle.fluid.load(prog, path) paddle.fluid.load(prog, path)
result_load = exe.run(program=prog,
feed={"static_save_load_large_x": inputs}, for var in prog.list_vars():
fetch_list=[z.name]) if isinstance(var, framework.Parameter) or var.persistable:
# compare results before and after saving new_t = np.array(fluid.global_scope().find_var(var.name)
self.assertTrue( .get_tensor())
np.sum(np.abs(result_z[0] - result_load[0])) < 1e-15) base_t = base_map[var.name]
self.assertTrue(np.array_equal(new_t, base_t))
# set var to zero
for var in prog.list_vars():
if isinstance(var, framework.Parameter) or var.persistable:
ten = fluid.global_scope().find_var(var.name).get_tensor()
ten.set(np.zeros_like(np.array(ten)), place)
new_t = np.array(fluid.global_scope().find_var(var.name)
.get_tensor())
self.assertTrue(np.sum(np.abs(new_t)) == 0)
program_state = fluid.load_program_state(path)
fluid.set_program_state(prog, program_state)
for var in prog.list_vars():
if isinstance(var, framework.Parameter) or var.persistable:
new_t = np.array(fluid.global_scope().find_var(var.name)
.get_tensor())
base_t = base_map[var.name]
self.assertTrue(np.array_equal(new_t, base_t))
class TestProgramStateOldSaveSingleModel(unittest.TestCase): class TestProgramStateOldSaveSingleModel(unittest.TestCase):
......
...@@ -19,6 +19,7 @@ import collections ...@@ -19,6 +19,7 @@ import collections
import pickle import pickle
import six import six
import warnings import warnings
import sys
import paddle import paddle
...@@ -262,8 +263,17 @@ def save(obj, path): ...@@ -262,8 +263,17 @@ def save(obj, path):
saved_obj = _build_saved_state_dict(obj) saved_obj = _build_saved_state_dict(obj)
saved_obj = _unpack_saved_dict(saved_obj) saved_obj = _unpack_saved_dict(saved_obj)
with open(path, 'wb') as f: # When value of dict is lager than 4GB ,there is a Bug on 'MAC python3.5/6'
pickle.dump(saved_obj, f, protocol=2) if sys.platform == 'darwin' and sys.version_info.major == 3 and (
sys.version_info.minor == 5 or sys.version_info.minor == 6):
pickle_bytes = pickle.dumps(saved_obj, protocol=2)
with open(path, 'wb') as f:
max_bytes = 2**30
for i in range(0, len(pickle_bytes), max_bytes):
f.write(pickle_bytes[i:i + max_bytes])
else:
with open(path, 'wb') as f:
pickle.dump(saved_obj, f, protocol=2)
def load(path, **configs): def load(path, **configs):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册