未验证 提交 832032c2 编写于 作者: W WeiXin 提交者: GitHub

[cherry pick]修复save/load相关的两个bug (#30543)

原始PR:#30485,#30507
上级 dbbfbccd
......@@ -22,6 +22,7 @@ import logging
import pickle
import contextlib
from functools import reduce
import sys
import numpy as np
import math
......@@ -1715,7 +1716,7 @@ def _unpack_saved_dict(saved_obj):
unpack_infor = {}
for key, value in saved_obj.items():
if isinstance(value, np.ndarray):
MAX_NUMBER_OF_ELEMENT = 2**22
MAX_NUMBER_OF_ELEMENT = int((2**30 - 1) / value.dtype.itemsize)
num_element = np.prod(value.shape)
if num_element > MAX_NUMBER_OF_ELEMENT:
unpack_infor[key] = {}
......@@ -1809,8 +1810,18 @@ def save(program, model_path):
parameter_list = list(filter(is_parameter, program.list_vars()))
param_dict = {p.name: get_tensor(p) for p in parameter_list}
param_dict = _unpack_saved_dict(param_dict)
with open(model_path + ".pdparams", 'wb') as f:
pickle.dump(param_dict, f, protocol=2)
    # When the value of the dict is larger than 4GB, there is a bug on macOS with Python 3.5/3.6
if sys.platform == 'darwin' and sys.version_info.major == 3 and (
sys.version_info.minor == 5 or sys.version_info.minor == 6):
pickle_bytes = pickle.dumps(param_dict, protocol=2)
with open(model_path + ".pdparams", 'wb') as f:
max_bytes = 2**30
for i in range(0, len(pickle_bytes), max_bytes):
f.write(pickle_bytes[i:i + max_bytes])
else:
with open(model_path + ".pdparams", 'wb') as f:
pickle.dump(param_dict, f, protocol=2)
optimizer_var_list = list(
filter(is_belong_to_optimizer, program.list_vars()))
......@@ -2169,6 +2180,7 @@ def load_program_state(model_path, var_list=None):
with open(parameter_file_name, 'rb') as f:
para_dict = pickle.load(f) if six.PY2 else pickle.load(
f, encoding='latin1')
para_dict = _pack_loaded_dict(para_dict)
opt_file_name = model_prefix + ".pdopt"
if os.path.exists(opt_file_name):
......@@ -2220,6 +2232,7 @@ def set_program_state(program, state_dict):
static.set_program_state(prog, program_state)
"""
state_dict = _pack_loaded_dict(state_dict)
parameter_list = list(filter(is_persistable, program.list_vars()))
used_para_list = {}
......
......@@ -16,6 +16,7 @@ from __future__ import print_function
import unittest
import numpy as np
import os
import paddle
import paddle.nn as nn
import paddle.optimizer as opt
......@@ -90,13 +91,13 @@ class TestSaveLoadLargeParameters(unittest.TestCase):
layer = LayerWithLargeParameters()
save_dict = layer.state_dict()
path = "test_paddle_save_load_large_param_save/layer" + ".pdparams"
path = os.path.join("test_paddle_save_load_large_param_save",
"layer.pdparams")
paddle.save(layer.state_dict(), path)
dict_load = paddle.load(path)
# compare results before and after saving
for key, value in save_dict.items():
self.assertTrue(
np.sum(np.abs(dict_load[key] - value.numpy())) < 1e-15)
self.assertTrue(np.array_equal(dict_load[key], value.numpy()))
class TestSaveLoad(unittest.TestCase):
......
......@@ -1324,7 +1324,7 @@ class TestStaticSaveLoadLargeParameters(unittest.TestCase):
name="static_save_load_large_x",
shape=[None, 10],
dtype='float32')
z = paddle.static.nn.fc(x, LARGE_PARAM)
z = paddle.static.nn.fc(x, LARGE_PARAM, bias_attr=False)
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
exe.run(paddle.static.default_startup_program())
......@@ -1334,16 +1334,55 @@ class TestStaticSaveLoadLargeParameters(unittest.TestCase):
result_z = exe.run(program=prog,
feed={"static_save_load_large_x": inputs},
fetch_list=[z.name])
path = "test_static_save_load_large_param/static_save"
base_map = {}
for var in prog.list_vars():
if isinstance(var, framework.Parameter) or var.persistable:
t = np.array(fluid.global_scope().find_var(var.name)
.get_tensor())
                # Make sure every parameter or optimizer variable has been updated
self.assertTrue(np.sum(np.abs(t)) != 0)
base_map[var.name] = t
path = os.path.join("test_static_save_load_large_param",
"static_save")
paddle.fluid.save(prog, path)
# set var to zero
for var in prog.list_vars():
if isinstance(var, framework.Parameter) or var.persistable:
ten = fluid.global_scope().find_var(var.name).get_tensor()
ten.set(np.zeros_like(np.array(ten)), place)
new_t = np.array(fluid.global_scope().find_var(var.name)
.get_tensor())
self.assertTrue(np.sum(np.abs(new_t)) == 0)
paddle.fluid.load(prog, path)
result_load = exe.run(program=prog,
feed={"static_save_load_large_x": inputs},
fetch_list=[z.name])
# compare results before and after saving
self.assertTrue(
np.sum(np.abs(result_z[0] - result_load[0])) < 1e-15)
for var in prog.list_vars():
if isinstance(var, framework.Parameter) or var.persistable:
new_t = np.array(fluid.global_scope().find_var(var.name)
.get_tensor())
base_t = base_map[var.name]
self.assertTrue(np.array_equal(new_t, base_t))
# set var to zero
for var in prog.list_vars():
if isinstance(var, framework.Parameter) or var.persistable:
ten = fluid.global_scope().find_var(var.name).get_tensor()
ten.set(np.zeros_like(np.array(ten)), place)
new_t = np.array(fluid.global_scope().find_var(var.name)
.get_tensor())
self.assertTrue(np.sum(np.abs(new_t)) == 0)
program_state = fluid.load_program_state(path)
fluid.set_program_state(prog, program_state)
for var in prog.list_vars():
if isinstance(var, framework.Parameter) or var.persistable:
new_t = np.array(fluid.global_scope().find_var(var.name)
.get_tensor())
base_t = base_map[var.name]
self.assertTrue(np.array_equal(new_t, base_t))
class TestProgramStateOldSaveSingleModel(unittest.TestCase):
......
......@@ -19,6 +19,7 @@ import collections
import pickle
import six
import warnings
import sys
import paddle
......@@ -262,8 +263,17 @@ def save(obj, path):
saved_obj = _build_saved_state_dict(obj)
saved_obj = _unpack_saved_dict(saved_obj)
with open(path, 'wb') as f:
pickle.dump(saved_obj, f, protocol=2)
    # When the value of the dict is larger than 4GB, there is a bug on macOS with Python 3.5/3.6
if sys.platform == 'darwin' and sys.version_info.major == 3 and (
sys.version_info.minor == 5 or sys.version_info.minor == 6):
pickle_bytes = pickle.dumps(saved_obj, protocol=2)
with open(path, 'wb') as f:
max_bytes = 2**30
for i in range(0, len(pickle_bytes), max_bytes):
f.write(pickle_bytes[i:i + max_bytes])
else:
with open(path, 'wb') as f:
pickle.dump(saved_obj, f, protocol=2)
def load(path, **configs):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册