Unverified commit f43e1d8c authored by WeiXin, committed by GitHub

Support storage of large parameters (#29988)

* Support storage of large parameters

* Reduce the complexity of the unittest

* Reduce the complexity of the unittest, commented out unittest for

* add unittest for static.save/load

* Increase the timeout threshold of 'test_static_save_load'

* Increase the timeout threshold of 'test_static_save_load'

* Increase the timeout threshold of 'test_static_save_load' and 'test_paddle_save_load'

* Increase the timeout threshold of 'test_static_save_load' and 'test_paddle_save_load'
Parent 666e6651
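
As a rough sketch of when the new slicing path applies (the constants and the (10, 2**26) weight shape are taken from the code added below; the motivation is hedged, presumably to keep each pickled object well below the size limits of pickle protocol 2): an ndarray is split only when it holds more than 2**22 elements, and each slice is stored under its own key.

import math

MAX_NUMBER_OF_ELEMENT = 2**22    # per-slice element budget used by _unpack_saved_dict
num_element = 10 * 2**26         # elements in the (10, 2**26) Linear weight from the new tests
num_slices = math.ceil(num_element / MAX_NUMBER_OF_ELEMENT)
print(num_slices)                # 160 slices, stored as "<param_name>@@.0" ... "<param_name>@@.159"
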
@@ -24,7 +24,7 @@ import contextlib
from functools import reduce
import numpy as np
import math
import paddle
from paddle.fluid import layers
from paddle.fluid.executor import Executor, global_scope
@@ -1710,6 +1710,52 @@ def _load_persistable_nodes(executor, dirname, graph):
    load_vars(executor=executor, dirname=dirname, vars=var_list)
def _unpack_saved_dict(saved_obj):
    temp_saved_obj = {}
    unpack_infor = {}
    for key, value in saved_obj.items():
        if isinstance(value, np.ndarray):
            MAX_NUMBER_OF_ELEMENT = 2**22
            num_element = np.prod(value.shape)
            if num_element > MAX_NUMBER_OF_ELEMENT:
                # Flatten oversized arrays and split them into slices of at most
                # MAX_NUMBER_OF_ELEMENT elements; record the original shape and
                # the slice names so the array can be rebuilt on load.
                unpack_infor[key] = {}
                unpack_infor[key]["OriginShape"] = value.shape
                unpack_infor[key]["slices"] = []
                value = value.flatten()
                for i in range(
                        int(
                            math.ceil(num_element * 1.0 /
                                      MAX_NUMBER_OF_ELEMENT))):
                    part_name = key + "@@." + str(i)
                    unpack_infor[key]["slices"].append(part_name)
                    temp_saved_obj[part_name] = value[
                        i * MAX_NUMBER_OF_ELEMENT:MAX_NUMBER_OF_ELEMENT * (i + 1)]

    if unpack_infor:
        for key, value in unpack_infor.items():
            if key in saved_obj:
                # Replace the original entry with its slices.
                saved_obj.pop(key)
                for part in value['slices']:
                    saved_obj[part] = temp_saved_obj[part]
        saved_obj['UnpackBigParamInfor@@'] = unpack_infor
    return saved_obj


def _pack_loaded_dict(load_obj):
    unpack_info = 'UnpackBigParamInfor@@'
    if unpack_info in load_obj:
        removes = []
        for key, value in load_obj[unpack_info].items():
            # Concatenate the slices back into one array with the original shape.
            slices = [load_obj[part] for part in value["slices"]]
            load_obj[key] = np.concatenate(slices).reshape(value["OriginShape"])
            removes += value["slices"]
        for key in removes:
            load_obj.pop(key)
        load_obj.pop(unpack_info)
    return load_obj
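
A minimal round-trip sketch of the two helpers above (the key name 'w' and the array size are chosen for illustration; the import path matches the one used by paddle/framework/io.py later in this diff):

import numpy as np
from paddle.fluid.io import _unpack_saved_dict, _pack_loaded_dict

w = np.random.rand(2**23).astype('float32')    # 2**23 > 2**22, so it will be sliced
packed = _unpack_saved_dict({'w': w})
# 'w' is replaced by 'w@@.0' and 'w@@.1' plus the bookkeeping entry
assert 'w' not in packed and 'UnpackBigParamInfor@@' in packed
restored = _pack_loaded_dict(packed)
assert np.array_equal(restored['w'], w)
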
@static_only
def save(program, model_path):
"""
@@ -1762,6 +1808,7 @@ def save(program, model_path):
    parameter_list = list(filter(is_parameter, program.list_vars()))
    param_dict = {p.name: get_tensor(p) for p in parameter_list}
    param_dict = _unpack_saved_dict(param_dict)

    with open(model_path + ".pdparams", 'wb') as f:
        pickle.dump(param_dict, f, protocol=2)
@@ -1935,6 +1982,7 @@ def load(program, model_path, executor=None, var_list=None):
    with open(parameter_file_name, 'rb') as f:
        load_dict = pickle.load(f) if six.PY2 else pickle.load(
            f, encoding='latin1')
        load_dict = _pack_loaded_dict(load_dict)

    for v in parameter_list:
        assert v.name in load_dict, \
            "Can not find [{}] in model file [{}]".format(
@@ -697,7 +697,8 @@ set_tests_properties(test_nearest_interp_v2_op PROPERTIES TIMEOUT 120)
set_tests_properties(test_trilinear_interp_op PROPERTIES TIMEOUT 120)
set_tests_properties(test_bicubic_interp_v2_op PROPERTIES TIMEOUT 120)
set_tests_properties(test_gather_op PROPERTIES TIMEOUT 120)
set_tests_properties(test_static_save_load PROPERTIES TIMEOUT 120)
set_tests_properties(test_static_save_load PROPERTIES TIMEOUT 200)
set_tests_properties(test_paddle_save_load PROPERTIES TIMEOUT 150)
set_tests_properties(test_imperative_selected_rows_to_lod_tensor PROPERTIES TIMEOUT 120)
set_tests_properties(test_index_select_op PROPERTIES TIMEOUT 120)
set_tests_properties(test_parallel_ssa_graph_inference_feed_partial_data PROPERTIES TIMEOUT 120)
@@ -28,6 +28,8 @@ SEED = 10
IMAGE_SIZE = 784
CLASS_NUM = 10
LARGE_PARAM = 2**26
def random_batch_reader():
    def _get_random_inputs_and_labels():
@@ -57,6 +59,16 @@ class LinearNet(nn.Layer):
        return self._linear(x)

class LayerWithLargeParameters(paddle.nn.Layer):
    def __init__(self):
        super(LayerWithLargeParameters, self).__init__()
        # weight shape is (10, LARGE_PARAM) = (10, 2**26), far above the
        # 2**22-element slicing threshold in _unpack_saved_dict
        self._l = paddle.nn.Linear(10, LARGE_PARAM)

    def forward(self, x):
        y = self._l(x)
        return y

def train(layer, loader, loss_fn, opt):
    for epoch_id in range(EPOCH_NUM):
        for batch_id, (image, label) in enumerate(loader()):
@@ -67,6 +79,26 @@ def train(layer, loader, loss_fn, opt):
            opt.clear_grad()

class TestSaveLoadLargeParameters(unittest.TestCase):
    def setUp(self):
        pass

    def test_large_parameters_paddle_save(self):
        # enable dygraph mode
        paddle.disable_static()
        # create network
        layer = LayerWithLargeParameters()
        save_dict = layer.state_dict()
        path = "test_paddle_save_load_large_param_save/layer" + ".pdparams"
        paddle.save(layer.state_dict(), path)
        dict_load = paddle.load(path)
        # compare results before and after saving
        for key, value in save_dict.items():
            self.assertTrue(
                np.sum(np.abs(dict_load[key] - value.numpy())) < 1e-15)

class TestSaveLoad(unittest.TestCase):
    def setUp(self):
        # enable dygraph mode
@@ -1199,6 +1199,39 @@ class TestProgramStateOldSave(unittest.TestCase):
                self.assertTrue(np.array_equal(new_t, base_t))

class TestStaticSaveLoadLargeParameters(unittest.TestCase):
    def test_large_parameters_static_save(self):
        # enable static mode
        paddle.enable_static()
        LARGE_PARAM = 2**26
        with new_program_scope():
            # create network
            x = paddle.static.data(
                name="static_save_load_large_x",
                shape=[None, 10],
                dtype='float32')
            z = paddle.static.nn.fc(x, LARGE_PARAM)
            place = paddle.CPUPlace()
            exe = paddle.static.Executor(place)
            exe.run(paddle.static.default_startup_program())
            prog = paddle.static.default_main_program()

            inputs = np.random.randn(1, 10).astype("float32")
            result_z = exe.run(program=prog,
                               feed={"static_save_load_large_x": inputs},
                               fetch_list=[z.name])
            path = "test_static_save_load_large_param/static_save"
            paddle.fluid.save(prog, path)
            paddle.fluid.load(prog, path)
            result_load = exe.run(program=prog,
                                  feed={"static_save_load_large_x": inputs},
                                  fetch_list=[z.name])
            # compare results before and after saving
            self.assertTrue(
                np.sum(np.abs(result_z[0] - result_load[0])) < 1e-15)

class TestProgramStateOldSaveSingleModel(unittest.TestCase):
    def test_ptb_rnn_cpu_float32(self):
        seed = 90
@@ -25,6 +25,7 @@ import paddle
# deprecated module import
from paddle import fluid
from paddle.fluid import core
from paddle.fluid.io import _unpack_saved_dict, _pack_loaded_dict
from paddle.fluid.framework import Variable, _varbase_creator, _dygraph_tracer
from paddle.fluid.dygraph.jit import _SaveLoadConfig
from paddle.fluid.dygraph.io import _construct_program_holders, _construct_params_and_buffers
@@ -259,6 +260,7 @@ def save(obj, path):
    # TODO(chenweihang): supports save other object
    saved_obj = _build_saved_state_dict(obj)
    saved_obj = _unpack_saved_dict(saved_obj)

    with open(path, 'wb') as f:
        pickle.dump(saved_obj, f, protocol=2)
@@ -338,7 +340,7 @@ def load(path, **configs):
        with open(path, 'rb') as f:
            load_result = pickle.load(f) if six.PY2 else pickle.load(
                f, encoding='latin1')
            load_result = _pack_loaded_dict(load_result)
            if not config.keep_name_table and "StructuredToParameterName@@" in load_result:
                del load_result["StructuredToParameterName@@"]
    else:
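
For reference, a .pdparams file written by this version of paddle.save can be inspected directly with pickle; the file name below is hypothetical, and sliced entries only appear when some parameter exceeds 2**22 elements:

import pickle

with open("layer.pdparams", "rb") as f:
    raw = pickle.load(f, encoding="latin1")

# Sliced parameters show up as "<name>@@.<i>" keys plus one bookkeeping entry;
# paddle.load stitches them back together via _pack_loaded_dict.
print([k for k in raw if "@@." in k][:3])
print("UnpackBigParamInfor@@" in raw)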