Unverified · Commit bfb6f613 authored by WeiXin, committed by GitHub

[cherry pick] Fix the bug in paddle.save/load and paddle.static.save/load when saving large files (#30170)

* Support storage of large parameters (#29988)

* Support storage of large parameters

* Reduce the complexity of the unittest

* Reduce the complexity of the unittest; commented out unittest for

* add unittest for static.save/load

* Increase the timeout threshold of 'test_static_save_load'

* Increase the timeout threshold of 'test_static_save_load'

* Increase the timeout threshold of 'test_static_save_load' and 'test_paddle_save_load'

* Increase the timeout threshold of 'test_static_save_load' and 'test_paddle_save_load'

* Extend the timeout for the (#30151)
Parent 9f02c284
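For context: pickle protocol 2 stores object sizes in 4-byte length fields, so it cannot serialize a single object larger than about 4 GiB, and saving a very large parameter as one ndarray fails. The patch therefore slices any array with more than 2**22 elements into named parts before pickling and records how to reassemble them on load. The sketch below only illustrates that chunk-and-reassemble idea; `split_array`, `join_slices`, and the tiny slice size are hypothetical names chosen for illustration, not Paddle APIs.

```python
import numpy as np

# Hypothetical helpers and a deliberately tiny slice size, for illustration only;
# the patch itself uses 2**22 elements per slice.
MAX_ELEMENTS = 4


def split_array(name, arr):
    """Flatten `arr` and cut it into named slices of at most MAX_ELEMENTS."""
    flat = arr.flatten()
    n_slices = int(np.ceil(flat.size / float(MAX_ELEMENTS)))
    return {
        "{}@@.{}".format(name, i): flat[i * MAX_ELEMENTS:(i + 1) * MAX_ELEMENTS]
        for i in range(n_slices)
    }, arr.shape


def join_slices(slices, shape):
    """Concatenate the slices in index order and restore the original shape."""
    ordered = sorted(slices, key=lambda k: int(k.rsplit(".", 1)[1]))
    return np.concatenate([slices[k] for k in ordered]).reshape(shape)


original = np.arange(10, dtype="float32").reshape(2, 5)
slices, shape = split_array("weight", original)  # 3 slices of at most 4 elements
assert np.array_equal(join_slices(slices, shape), original)
```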
@@ -24,7 +24,7 @@ import contextlib
from functools import reduce
import numpy as np
import math
import paddle
from paddle.fluid import layers
from paddle.fluid.executor import Executor, global_scope
@@ -1710,6 +1710,52 @@ def _load_persistable_nodes(executor, dirname, graph):
    load_vars(executor=executor, dirname=dirname, vars=var_list)

def _unpack_saved_dict(saved_obj):
    temp_saved_obj = {}
    unpack_infor = {}
    for key, value in saved_obj.items():
        if isinstance(value, np.ndarray):
            MAX_NUMBER_OF_ELEMENT = 2**22
            num_element = np.prod(value.shape)
            if num_element > MAX_NUMBER_OF_ELEMENT:
                unpack_infor[key] = {}
                unpack_infor[key]["OriginShape"] = value.shape
                unpack_infor[key]["slices"] = []
                value = value.flatten()
                for i in range(
                        int(
                            math.ceil(num_element * 1.0 /
                                      MAX_NUMBER_OF_ELEMENT))):
                    part_name = key + "@@." + str(i)
                    unpack_infor[key]["slices"].append(part_name)
                    temp_saved_obj[part_name] = value[
                        i * MAX_NUMBER_OF_ELEMENT:MAX_NUMBER_OF_ELEMENT * (i + 1)]

    if unpack_infor:
        for key, value in unpack_infor.items():
            if key in saved_obj:
                saved_obj.pop(key)
                for part in value['slices']:
                    saved_obj[part] = temp_saved_obj[part]
        saved_obj['UnpackBigParamInfor@@'] = unpack_infor
    return saved_obj


def _pack_loaded_dict(load_obj):
    unpack_info = 'UnpackBigParamInfor@@'
    if unpack_info in load_obj:
        removes = []
        for key, value in load_obj[unpack_info].items():
            slices = [load_obj[part] for part in value["slices"]]
            load_obj[key] = np.concatenate(slices).reshape(value["OriginShape"])
            removes += value["slices"]
        for key in removes:
            load_obj.pop(key)
        load_obj.pop(unpack_info)
    return load_obj
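Taken together, the two helpers form a lossless round trip around pickle. Below is a minimal sketch of that round trip, assuming a build that contains this patch; both helpers are private to `paddle.fluid.io`, so this is for illustration rather than a supported API.

```python
import pickle

import numpy as np
from paddle.fluid.io import _pack_loaded_dict, _unpack_saved_dict

# An entry with more than 2**22 elements triggers the slicing path.
state = {"fc.w_0": np.random.rand(2**23).astype("float32")}

packed = _unpack_saved_dict(dict(state))    # big array replaced by "fc.w_0@@.i" slices
payload = pickle.dumps(packed, protocol=2)  # protocol 2 now only sees small objects

restored = _pack_loaded_dict(pickle.loads(payload))  # slices merged, shape restored
assert np.array_equal(restored["fc.w_0"], state["fc.w_0"])
```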
@static_only
def save(program, model_path):
    """
@@ -1762,6 +1808,7 @@ def save(program, model_path):
    parameter_list = list(filter(is_parameter, program.list_vars()))
    param_dict = {p.name: get_tensor(p) for p in parameter_list}
    param_dict = _unpack_saved_dict(param_dict)
    with open(model_path + ".pdparams", 'wb') as f:
        pickle.dump(param_dict, f, protocol=2)
@@ -1935,6 +1982,7 @@ def load(program, model_path, executor=None, var_list=None):
    with open(parameter_file_name, 'rb') as f:
        load_dict = pickle.load(f) if six.PY2 else pickle.load(
            f, encoding='latin1')
        load_dict = _pack_loaded_dict(load_dict)
    for v in parameter_list:
        assert v.name in load_dict, \
            "Can not find [{}] in model file [{}]".format(
......
@@ -701,7 +701,13 @@ set_tests_properties(test_nearest_interp_v2_op PROPERTIES TIMEOUT 120)
set_tests_properties(test_trilinear_interp_op PROPERTIES TIMEOUT 120)
set_tests_properties(test_bicubic_interp_v2_op PROPERTIES TIMEOUT 120)
set_tests_properties(test_gather_op PROPERTIES TIMEOUT 120)
set_tests_properties(test_static_save_load PROPERTIES TIMEOUT 120)
if (WIN32)
    set_tests_properties(test_static_save_load PROPERTIES TIMEOUT 300)
    set_tests_properties(test_paddle_save_load PROPERTIES TIMEOUT 250)
else()
    set_tests_properties(test_static_save_load PROPERTIES TIMEOUT 200)
    set_tests_properties(test_paddle_save_load PROPERTIES TIMEOUT 150)
endif()
set_tests_properties(test_imperative_selected_rows_to_lod_tensor PROPERTIES TIMEOUT 120)
set_tests_properties(test_index_select_op PROPERTIES TIMEOUT 120)
set_tests_properties(test_parallel_ssa_graph_inference_feed_partial_data PROPERTIES TIMEOUT 120)
......
@@ -28,6 +28,8 @@ SEED = 10
IMAGE_SIZE = 784
CLASS_NUM = 10

LARGE_PARAM = 2**26


def random_batch_reader():
    def _get_random_inputs_and_labels():
@@ -57,6 +59,16 @@ class LinearNet(nn.Layer):
        return self._linear(x)


class LayerWithLargeParameters(paddle.nn.Layer):
    def __init__(self):
        super(LayerWithLargeParameters, self).__init__()
        self._l = paddle.nn.Linear(10, LARGE_PARAM)

    def forward(self, x):
        y = self._l(x)
        return y


def train(layer, loader, loss_fn, opt):
    for epoch_id in range(EPOCH_NUM):
        for batch_id, (image, label) in enumerate(loader()):
@@ -67,6 +79,26 @@ def train(layer, loader, loss_fn, opt):
            opt.clear_grad()


class TestSaveLoadLargeParameters(unittest.TestCase):
    def setUp(self):
        pass

    def test_large_parameters_paddle_save(self):
        # enable dygraph mode
        paddle.disable_static()
        # create network
        layer = LayerWithLargeParameters()
        save_dict = layer.state_dict()
        path = "test_paddle_save_load_large_param_save/layer" + ".pdparams"
        paddle.save(layer.state_dict(), path)
        dict_load = paddle.load(path)
        # compare results before and after saving
        for key, value in save_dict.items():
            self.assertTrue(
                np.sum(np.abs(dict_load[key] - value.numpy())) < 1e-15)


class TestSaveLoad(unittest.TestCase):
    def setUp(self):
        # enable dygraph mode
......
@@ -1199,6 +1199,39 @@ class TestProgramStateOldSave(unittest.TestCase):
        self.assertTrue(np.array_equal(new_t, base_t))


class TestStaticSaveLoadLargeParameters(unittest.TestCase):
    def test_large_parameters_static_save(self):
        # enable static mode
        paddle.enable_static()
        LARGE_PARAM = 2**26
        with new_program_scope():
            # create network
            x = paddle.static.data(
                name="static_save_load_large_x",
                shape=[None, 10],
                dtype='float32')
            z = paddle.static.nn.fc(x, LARGE_PARAM)
            place = paddle.CPUPlace()
            exe = paddle.static.Executor(place)
            exe.run(paddle.static.default_startup_program())
            prog = paddle.static.default_main_program()

            inputs = np.random.randn(1, 10).astype("float32")
            result_z = exe.run(program=prog,
                               feed={"static_save_load_large_x": inputs},
                               fetch_list=[z.name])
            path = "test_static_save_load_large_param/static_save"
            paddle.fluid.save(prog, path)
            paddle.fluid.load(prog, path)
            result_load = exe.run(program=prog,
                                  feed={"static_save_load_large_x": inputs},
                                  fetch_list=[z.name])
            # compare results before and after saving
            self.assertTrue(
                np.sum(np.abs(result_z[0] - result_load[0])) < 1e-15)


class TestProgramStateOldSaveSingleModel(unittest.TestCase):
    def test_ptb_rnn_cpu_float32(self):
        seed = 90
......
@@ -25,6 +25,7 @@ import paddle
# deprecated module import
from paddle import fluid
from paddle.fluid import core
from paddle.fluid.io import _unpack_saved_dict, _pack_loaded_dict
from paddle.fluid.framework import Variable, _varbase_creator, _dygraph_tracer
from paddle.fluid.dygraph.jit import _SaveLoadConfig
from paddle.fluid.dygraph.io import _construct_program_holders, _construct_params_and_buffers
@@ -259,6 +260,7 @@ def save(obj, path):
    # TODO(chenweihang): supports save other object
    saved_obj = _build_saved_state_dict(obj)
    saved_obj = _unpack_saved_dict(saved_obj)
    with open(path, 'wb') as f:
        pickle.dump(saved_obj, f, protocol=2)
@@ -338,7 +340,7 @@ def load(path, **configs):
        with open(path, 'rb') as f:
            load_result = pickle.load(f) if six.PY2 else pickle.load(
                f, encoding='latin1')
            load_result = _pack_loaded_dict(load_result)
            if not config.keep_name_table and "StructuredToParameterName@@" in load_result:
                del load_result["StructuredToParameterName@@"]
    else:
......
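From the user's side the `paddle.save`/`paddle.load` API is unchanged: oversized entries are split before `pickle.dump` and merged again after `pickle.load`. A small dygraph round trip might look like the sketch below; as in the new unit test above, it assumes that on this branch `paddle.load` returns numpy arrays for a state dict saved this way, and the file path is an arbitrary example.

```python
import numpy as np
import paddle

paddle.disable_static()
layer = paddle.nn.Linear(4, 8)

# Save and reload the state dict; large entries would be sliced transparently.
paddle.save(layer.state_dict(), "example_save/linear.pdparams")
loaded = paddle.load("example_save/linear.pdparams")

for key, value in layer.state_dict().items():
    assert np.sum(np.abs(loaded[key] - value.numpy())) < 1e-15
```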