未验证 提交 43367e4b 编写于 作者: W WeiXin 提交者: GitHub

support save/load single tensor (#31756)

* support save/load single tensor

* compatibility modification according to unnittest

* Some python2.7 don't have 'copyreg' modules

* Handle a syntax error.

* Dealing with compatibility problems on Mac.

* Dealing with compatibility problems on Mac.

* edit unittest to improve coverage.

* Modify the code according to the review comments

* Reduce redundant code.

* support for static graph loading dygraph state_dict

* edit code according to CI

* edit unittest

* edit unnittest

* delete redundant file

* edit code according to Comments

* edit english doc

* edit english doc

* edit English DOC.

* get/set_tensor->get/set_value; return_numpy=False

* get/set_tensor->get/set_value; return_numpy=False

* edit unnittest

* edit unnittest

* polish code.
上级 bf10d563
...@@ -22,6 +22,8 @@ import copy ...@@ -22,6 +22,8 @@ import copy
import weakref import weakref
import warnings import warnings
from copy import deepcopy from copy import deepcopy
import inspect
import paddle import paddle
from . import parallel_helper from . import parallel_helper
...@@ -1294,10 +1296,12 @@ class Layer(core.Layer): ...@@ -1294,10 +1296,12 @@ class Layer(core.Layer):
if state is None: if state is None:
raise ValueError("{} is not found in the provided dict.".format( raise ValueError("{} is not found in the provided dict.".format(
key)) key))
if list(state.shape) != list(param.shape): state_shape = state.shape() if inspect.ismethod(
state.shape) else state.shape
if list(state_shape) != list(param.shape):
raise ValueError( raise ValueError(
"{} receives a shape {}, but the expected shape is {}.". "{} receives a shape {}, but the expected shape is {}.".
format(key, list(state.shape), list(param.shape))) format(key, list(state_shape), list(param.shape)))
return param, state return param, state
matched_param_state = [] matched_param_state = []
......
...@@ -24,6 +24,7 @@ import re ...@@ -24,6 +24,7 @@ import re
import traceback import traceback
import six import six
import copy import copy
from types import MethodType, FunctionType
import numpy as np import numpy as np
import subprocess import subprocess
...@@ -1183,37 +1184,6 @@ class Variable(object): ...@@ -1183,37 +1184,6 @@ class Variable(object):
""" """
pass pass
@fake_interface_only
def set_value(self, value):
"""
**Notes**:
**This API is ONLY available in Dygraph mode**
Set a new value for this Variable.
Args:
value (Variable|np.ndarray): the new value.
Examples:
.. code-block:: python
import paddle.fluid as fluid
from paddle.fluid.dygraph.base import to_variable
from paddle.fluid.dygraph import Linear
import numpy as np
data = np.ones([3, 1024], dtype='float32')
with fluid.dygraph.guard():
linear = fluid.dygraph.Linear(1024, 4)
t = to_variable(data)
linear(t) # call with default weight
custom_weight = np.random.randn(1024, 4).astype("float32")
linear.weight.set_value(custom_weight) # change existing weight
out = linear(t) # call with different weight
"""
pass
@fake_interface_only @fake_interface_only
def backward(self, retain_graph=False): def backward(self, retain_graph=False):
""" """
...@@ -2011,6 +1981,159 @@ class Variable(object): ...@@ -2011,6 +1981,159 @@ class Variable(object):
return self return self
def get_value(self, scope=None):
"""
Get the value of variable in given scope.
Args:
scope(Scope, optional) : If `scope` is None, it will be set to global scope
obtained through 'paddle.static.global_scope()'. Otherwise, use `scope`.
Default: None
Returns:
Tensor: the value in given scope.
Examples:
.. code-block:: python
import paddle
import paddle.static as static
import numpy as np
paddle.enable_static()
x = static.data(name="x", shape=[10, 10], dtype='float32')
y = static.nn.fc(x, 10, name='fc')
place = paddle.CPUPlace()
exe = static.Executor(place)
prog = paddle.static.default_main_program()
exe.run(static.default_startup_program())
inputs = np.ones((10, 10), dtype='float32')
exe.run(prog, feed={'x': inputs}, fetch_list=[y, ])
path = 'temp/tensor_'
for var in prog.list_vars():
if var.persistable:
t = var.get_value()
paddle.save(t, path+var.name+'.pdtensor')
for var in prog.list_vars():
if var.persistable:
t_load = paddle.load(path+var.name+'.pdtensor')
var.set_value(t_load)
"""
# The 'framework' is a low-level module, and 'executor'
# can not be imported at the begainning of this file.
# Therefore, the above two modules are dynamically imported.
from .executor import global_scope
if scope is not None and not isinstance(scope, core._Scope):
raise TypeError(
"`scope` should be None or `paddle.static.Scope` type, but received {}.".
format(type(scope)))
if scope is None:
scope = global_scope()
var_temp = scope.find_var(self.name)
if var_temp is None:
raise ValueError("Can not find Variable '{}' in the Scope.".format(
self.name))
t = var_temp.get_tensor()
return t
def set_value(self, value, scope=None):
'''
Set the value to the tensor in given scope.
Args:
value(Tensor/ndarray) : The value to be set.
scope(Scope, optional) : If `scope` is None, it will be set to global scope
obtained through 'paddle.static.global_scope()'. Otherwise, use `scope`.
Default: None
Returns:
None
Examples:
.. code-block:: python
import paddle
import paddle.static as static
import numpy as np
paddle.enable_static()
x = static.data(name="x", shape=[10, 10], dtype='float32')
y = static.nn.fc(x, 10, name='fc')
place = paddle.CPUPlace()
exe = static.Executor(place)
prog = paddle.static.default_main_program()
exe.run(static.default_startup_program())
inputs = np.ones((10, 10), dtype='float32')
exe.run(prog, feed={'x': inputs}, fetch_list=[y, ])
path = 'temp/tensor_'
for var in prog.list_vars():
if var.persistable:
t = var.get_value()
paddle.save(t, path+var.name+'.pdtensor')
for var in prog.list_vars():
if var.persistable:
t_load = paddle.load(path+var.name+'.pdtensor')
var.set_value(t_load)
'''
# The 'framework' is a low-level module, and 'executor'
# can not be imported at the begainning of this file.
# Therefore, the above two modules are dynamically imported.
from .executor import global_scope
if not (isinstance(value, np.ndarray) or hasattr(value, '__array__')):
raise TypeError(
"`value` should be `numpy.ndarray` or `LoDTensor`, but received {}.".
format(type(value)))
if scope is not None and not isinstance(scope, core._Scope):
raise TypeError(
"`scope` should be None or `paddle.static.Scope` type, but received {}.".
format(type(scope)))
if scope is None:
scope = global_scope()
var_temp = scope.find_var(self.name)
if var_temp is None:
raise ValueError("Can not find Variable '{}' in the Scope.".format(
self.name))
t = var_temp.get_tensor()
if hasattr(value, 'shape'):
if isinstance(value.shape, (MethodType, FunctionType)):
value_shape = value.shape()
else:
value_shape = value.shape
if list(t.shape()) != list(value_shape):
raise ValueError(
"{} expected a shape {}, but the received shape is {}.".
format(self.name, list(t.shape()), list(value_shape)))
p = t._place()
if p.is_cpu_place():
place = core.CPUPlace()
elif p.is_cuda_pinned_place():
place = core.CUDAPinnedPlace()
elif p.is_xpu_place():
p = core.Place()
p.set_place(t._place())
place = core.XPUPlace(p.xpu_device_id())
else:
p = core.Place()
p.set_place(t._place())
place = core.CUDAPlace(p.gpu_device_id())
t.set(value, place)
def get_all_op_protos(): def get_all_op_protos():
""" """
...@@ -5319,6 +5442,173 @@ class Program(object): ...@@ -5319,6 +5442,173 @@ class Program(object):
parameters.extend(each_block.all_parameters()) parameters.extend(each_block.all_parameters())
return parameters return parameters
def state_dict(self, mode='all', scope=None):
"""
Get parameters and persistable buffers of program as a dict. The key is the name of the parameter or the name of the buffer.
The value is the tensor of this variable in the given scope.
.. note::
This function MUST called after run start_up_program
Args:
mode(str, optional): Source of the obtained parameters and buffers.
'opt' : The return value only contains the variable in the optimizer.
'param' : The return value only contains the variable in the network, not the variable in the optimizer.
'all' : The return value contains the variable in the network and optimizer.
Default: 'all'
scope(Scope, optional) : If scope is None, state_dict will be set to global scope
obtained through 'paddle.static.global_scope()'. Otherwise, value will be set to scope.
Default: None
Retruns:
dict: a dict contains the parameters and persistable buffers.
Examples:
.. code-block:: python
import paddle
import paddle.static as static
paddle.enable_static()
x = static.data(name="x", shape=[10, 10], dtype='float32')
y = static.nn.fc(x, 10)
z = static.nn.fc(y, 10)
place = paddle.CPUPlace()
exe = static.Executor(place)
exe.run(static.default_startup_program())
prog = static.default_main_program()
path = "./temp/model.pdparams"
paddle.save(prog.state_dict(), path)
"""
# The 'framework' is a low-level module, and 'executor'
# can not be imported at the begainning of this file.
# Therefore, the above two modules are dynamically imported.
from .executor import global_scope
if scope is not None and not isinstance(scope, core._Scope):
raise TypeError(
"`scope` should be None or `paddle.static.Scope'` type, but received {}.".
format(type(scope)))
if scope is None:
scope = global_scope()
if not isinstance(mode, str):
raise TypeError("Type of `mode` should be string, but received {}.".
format(type(mode)))
def is_parameter(var):
return isinstance(var, Parameter)
def is_persistable(var):
if var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \
var.desc.type() == core.VarDesc.VarType.FETCH_LIST or \
var.desc.type() == core.VarDesc.VarType.READER:
return False
return var.persistable
def is_belong_to_optimizer(var):
if not (isinstance(var, Parameter) or var.desc.need_check_feed()):
return is_persistable(var)
return False
def condition(var):
if mode == 'param':
return is_parameter(var)
elif mode == 'opt':
return is_belong_to_optimizer(var)
elif mode == 'all':
return is_parameter(var) or is_belong_to_optimizer(var)
else:
raise ValueError(
"`mode` string should be 'param', 'opt' or 'all', but received {}.".
format(mode))
var_list = filter(condition, self.list_vars())
state_dict = dict()
for var in var_list:
var_temp = scope.find_var(var.name)
if var_temp is None:
raise ValueError(
"Can not find Variable '{}' in the scope. Make sure it is initialized".
format(var.name))
state_dict[var.name] = var_temp.get_tensor()
return state_dict
def set_state_dict(self, state_dict, scope=None):
"""
Set parameters and persistable buffers in state_dict to program.
An exception will throw if shape or dtype of the parameters is not match.
.. note::
This function MUST called after run start_up_program
Args:
state_dict(dict): the dict store parameters and persistable buffers.
The key is the name of the parameter or the name of the buffer.
The value is the tensor of this variable in the given scope.
scope(Scope, optional) : If scope is None, state_dict will be set to global scope
obtained through 'paddle.static.global_scope()'. Otherwise, value will be set to scope.
Default: None
Returns:
None
Examples:
.. code-block:: python
import paddle
import paddle.static as static
paddle.enable_static()
x = static.data(name="x", shape=[10, 10], dtype='float32')
y = static.nn.fc(x, 10)
z = static.nn.fc(y, 10)
place = paddle.CPUPlace()
exe = static.Executor(place)
exe.run(static.default_startup_program())
prog = static.default_main_program()
path = "./temp/model.pdparams"
paddle.save(prog.state_dict(), path)
state_dict_load = paddle.load(path)
prog.set_state_dict(state_dict_load)
"""
if not isinstance(state_dict, dict):
raise TypeError(
"Type of `state_dict` should be dict, but received {}.".format(
type(state_dict)))
vars_dict = {var.name: var for var in self.list_vars()}
condition = True if 'StructuredToParameterName@@' in state_dict else False
for name, value in state_dict.items():
if condition:
if name == "StructuredToParameterName@@":
continue
if name in state_dict['StructuredToParameterName@@']:
name = state_dict['StructuredToParameterName@@'][name]
if name in vars_dict:
try:
vars_dict[name].set_value(value, scope)
except ValueError as err:
warnings.warn(
("Skip loading for '{}'. ".format(name) + str(err)))
except TypeError as err:
warnings.warn(
("Skip loading for '{}'. ".format(name) + str(err)))
else:
warnings.warn((
"Skip loading for '{0}'. Because '{0}' not in the program.".
format(name)))
@six.add_metaclass(ParameterMetaClass) @six.add_metaclass(ParameterMetaClass)
class Parameter(Variable): class Parameter(Variable):
......
...@@ -1765,7 +1765,30 @@ def _pack_loaded_dict(load_obj): ...@@ -1765,7 +1765,30 @@ def _pack_loaded_dict(load_obj):
@static_only @static_only
def save(program, model_path, pickle_protocol=2): def _legacy_save(param_dict, model_path, protocol=2):
def get_tensor(var):
if isinstance(var, core.VarBase):
return var.numpy()
elif isinstance(var, core.LoDTensor):
return np.array(var)
return var
param_dict = {name: get_tensor(param_dict[name]) for name in param_dict}
# When value of dict is lager than 4GB ,there is a Bug on 'MAC python3'
if sys.platform == 'darwin' and sys.version_info.major == 3:
pickle_bytes = pickle.dumps(param_dict, protocol=protocol)
with open(model_path, 'wb') as f:
max_bytes = 2**30
for i in range(0, len(pickle_bytes), max_bytes):
f.write(pickle_bytes[i:i + max_bytes])
else:
with open(model_path, 'wb') as f:
pickle.dump(param_dict, f, protocol=protocol)
@static_only
def save(program, model_path, protocol=2, **configs):
""" """
:api_attr: Static Graph :api_attr: Static Graph
...@@ -1778,8 +1801,9 @@ def save(program, model_path, pickle_protocol=2): ...@@ -1778,8 +1801,9 @@ def save(program, model_path, pickle_protocol=2):
Args: Args:
program(Program) : The program to saved. program(Program) : The program to saved.
model_path(str): the file prefix to save the program. The format is "dirname/file_prefix". If file_prefix is empty str. A exception will be raised model_path(str): the file prefix to save the program. The format is "dirname/file_prefix". If file_prefix is empty str. A exception will be raised
pickle_protocol(int, optional): The protocol version of pickle module must be greater than 1 and less than 5. protocol(int, optional): The protocol version of pickle module must be greater than 1 and less than 5.
Default: 2 Default: 2
configs(dict, optional) : optional keyword arguments.
Returns: Returns:
None None
...@@ -1807,14 +1831,19 @@ def save(program, model_path, pickle_protocol=2): ...@@ -1807,14 +1831,19 @@ def save(program, model_path, pickle_protocol=2):
base_name = os.path.basename(model_path) base_name = os.path.basename(model_path)
assert base_name != "", \ assert base_name != "", \
"The input model_path MUST be format of dirname/filename [dirname\\filename in Windows system], but received model_path is empty string." "The input model_path MUST be format of dirname/filename [dirname\\filename in Windows system], but received model_path is empty string."
if 'pickle_protocol' in configs:
protocol = configs['pickle_protocol']
warnings.warn(
"'pickle_protocol' is a deprecated argument. Please use 'protocol' instead."
)
if not isinstance(pickle_protocol, int): if not isinstance(protocol, int):
raise ValueError("The 'protocol' MUST be `int`, but received {}".format( raise ValueError("The 'protocol' MUST be `int`, but received {}".format(
type(pickle_protocol))) type(protocol)))
if pickle_protocol < 2 or pickle_protocol > 4: if protocol < 2 or protocol > 4:
raise ValueError("Expected 1<'protocol'<5, but received protocol={}". raise ValueError("Expected 1<'protocol'<5, but received protocol={}".
format(pickle_protocol)) format(protocol))
dir_name = os.path.dirname(model_path) dir_name = os.path.dirname(model_path)
if dir_name and not os.path.exists(dir_name): if dir_name and not os.path.exists(dir_name):
...@@ -1827,26 +1856,25 @@ def save(program, model_path, pickle_protocol=2): ...@@ -1827,26 +1856,25 @@ def save(program, model_path, pickle_protocol=2):
parameter_list = list(filter(is_parameter, program.list_vars())) parameter_list = list(filter(is_parameter, program.list_vars()))
param_dict = {p.name: get_tensor(p) for p in parameter_list} param_dict = {p.name: get_tensor(p) for p in parameter_list}
param_dict = _unpack_saved_dict(param_dict, pickle_protocol) param_dict = _unpack_saved_dict(param_dict, protocol)
# When value of dict is lager than 4GB ,there is a Bug on 'MAC python3.5/6' # When value of dict is lager than 4GB ,there is a Bug on 'MAC python3'
if sys.platform == 'darwin' and sys.version_info.major == 3 and ( if sys.platform == 'darwin' and sys.version_info.major == 3:
sys.version_info.minor == 5 or sys.version_info.minor == 6): pickle_bytes = pickle.dumps(param_dict, protocol=protocol)
pickle_bytes = pickle.dumps(param_dict, protocol=pickle_protocol)
with open(model_path + ".pdparams", 'wb') as f: with open(model_path + ".pdparams", 'wb') as f:
max_bytes = 2**30 max_bytes = 2**30
for i in range(0, len(pickle_bytes), max_bytes): for i in range(0, len(pickle_bytes), max_bytes):
f.write(pickle_bytes[i:i + max_bytes]) f.write(pickle_bytes[i:i + max_bytes])
else: else:
with open(model_path + ".pdparams", 'wb') as f: with open(model_path + ".pdparams", 'wb') as f:
pickle.dump(param_dict, f, protocol=pickle_protocol) pickle.dump(param_dict, f, protocol=protocol)
optimizer_var_list = list( optimizer_var_list = list(
filter(is_belong_to_optimizer, program.list_vars())) filter(is_belong_to_optimizer, program.list_vars()))
opt_dict = {p.name: get_tensor(p) for p in optimizer_var_list} opt_dict = {p.name: get_tensor(p) for p in optimizer_var_list}
with open(model_path + ".pdopt", 'wb') as f: with open(model_path + ".pdopt", 'wb') as f:
pickle.dump(opt_dict, f, protocol=pickle_protocol) pickle.dump(opt_dict, f, protocol=protocol)
main_program = program.clone() main_program = program.clone()
program.desc.flush() program.desc.flush()
...@@ -1857,6 +1885,17 @@ def save(program, model_path, pickle_protocol=2): ...@@ -1857,6 +1885,17 @@ def save(program, model_path, pickle_protocol=2):
f.write(program.desc.serialize_to_string()) f.write(program.desc.serialize_to_string())
def _pickle_loads_mac(path, f):
pickle_bytes = bytearray(0)
file_size = os.path.getsize(path)
max_bytes = 2**30
for _ in range(0, file_size, max_bytes):
pickle_bytes += f.read(max_bytes)
load_result = pickle.loads(pickle_bytes) if six.PY2 else pickle.loads(
pickle_bytes, encoding='latin1')
return load_result
@static_only @static_only
def load(program, model_path, executor=None, var_list=None): def load(program, model_path, executor=None, var_list=None):
""" """
...@@ -2016,8 +2055,13 @@ def load(program, model_path, executor=None, var_list=None): ...@@ -2016,8 +2055,13 @@ def load(program, model_path, executor=None, var_list=None):
global_scope(), global_scope(),
executor._default_executor) executor._default_executor)
with open(parameter_file_name, 'rb') as f: with open(parameter_file_name, 'rb') as f:
load_dict = pickle.load(f) if six.PY2 else pickle.load(
f, encoding='latin1') # When value of dict is lager than 4GB ,there is a Bug on 'MAC python3'
if sys.platform == 'darwin' and sys.version_info.major == 3:
load_dict = _pickle_loads_mac(parameter_file_name, f)
else:
load_dict = pickle.load(f) if six.PY2 else pickle.load(
f, encoding='latin1')
load_dict = _pack_loaded_dict(load_dict) load_dict = _pack_loaded_dict(load_dict)
for v in parameter_list: for v in parameter_list:
assert v.name in load_dict, \ assert v.name in load_dict, \
...@@ -2196,8 +2240,12 @@ def load_program_state(model_path, var_list=None): ...@@ -2196,8 +2240,12 @@ def load_program_state(model_path, var_list=None):
"Parameter file [{}] not exits".format(parameter_file_name) "Parameter file [{}] not exits".format(parameter_file_name)
with open(parameter_file_name, 'rb') as f: with open(parameter_file_name, 'rb') as f:
para_dict = pickle.load(f) if six.PY2 else pickle.load( # When value of dict is lager than 4GB ,there is a Bug on 'MAC python3'
f, encoding='latin1') if sys.platform == 'darwin' and sys.version_info.major == 3:
para_dict = _pickle_loads_mac(parameter_file_name, f)
else:
para_dict = pickle.load(f) if six.PY2 else pickle.load(
f, encoding='latin1')
para_dict = _pack_loaded_dict(para_dict) para_dict = _pack_loaded_dict(para_dict)
opt_file_name = model_prefix + ".pdopt" opt_file_name = model_prefix + ".pdopt"
......
...@@ -726,7 +726,7 @@ if (WIN32) ...@@ -726,7 +726,7 @@ if (WIN32)
set_tests_properties(test_paddle_save_load PROPERTIES TIMEOUT 250) set_tests_properties(test_paddle_save_load PROPERTIES TIMEOUT 250)
else() else()
set_tests_properties(test_static_save_load_large PROPERTIES TIMEOUT 600) set_tests_properties(test_static_save_load_large PROPERTIES TIMEOUT 600)
set_tests_properties(test_paddle_save_load PROPERTIES TIMEOUT 150) set_tests_properties(test_paddle_save_load PROPERTIES TIMEOUT 250)
endif() endif()
set_tests_properties(test_imperative_selected_rows_to_lod_tensor PROPERTIES TIMEOUT 120) set_tests_properties(test_imperative_selected_rows_to_lod_tensor PROPERTIES TIMEOUT 120)
set_tests_properties(test_index_select_op PROPERTIES TIMEOUT 120) set_tests_properties(test_index_select_op PROPERTIES TIMEOUT 120)
......
...@@ -930,7 +930,7 @@ class TestDygraphPtbRnn(unittest.TestCase): ...@@ -930,7 +930,7 @@ class TestDygraphPtbRnn(unittest.TestCase):
paddle.save(state_dict, os.path.join('saved_dy', 'emb_dy.pdparams')) paddle.save(state_dict, os.path.join('saved_dy', 'emb_dy.pdparams'))
para_state_dict = paddle.load( para_state_dict = paddle.load(
os.path.join('saved_dy', 'emb_dy.pdparams')) os.path.join('saved_dy', 'emb_dy.pdparams'), return_numpy=True)
para_state_dict['weight'] = np.expand_dims( para_state_dict['weight'] = np.expand_dims(
para_state_dict['weight'], axis=-1) para_state_dict['weight'], axis=-1)
......
...@@ -18,10 +18,15 @@ import unittest ...@@ -18,10 +18,15 @@ import unittest
import numpy as np import numpy as np
import os import os
import sys import sys
import six
import paddle import paddle
import paddle.nn as nn import paddle.nn as nn
import paddle.optimizer as opt import paddle.optimizer as opt
import paddle.fluid as fluid
from paddle.fluid.optimizer import Adam
import paddle.fluid.framework as framework
from test_imperative_base import new_program_scope
BATCH_SIZE = 16 BATCH_SIZE = 16
BATCH_NUM = 4 BATCH_NUM = 4
...@@ -31,7 +36,10 @@ SEED = 10 ...@@ -31,7 +36,10 @@ SEED = 10
IMAGE_SIZE = 784 IMAGE_SIZE = 784
CLASS_NUM = 10 CLASS_NUM = 10
LARGE_PARAM = 2**26 if six.PY2:
LARGE_PARAM = 2**2
else:
LARGE_PARAM = 2**26
def random_batch_reader(): def random_batch_reader():
...@@ -95,15 +103,22 @@ class TestSaveLoadLargeParameters(unittest.TestCase): ...@@ -95,15 +103,22 @@ class TestSaveLoadLargeParameters(unittest.TestCase):
path = os.path.join("test_paddle_save_load_large_param_save", path = os.path.join("test_paddle_save_load_large_param_save",
"layer.pdparams") "layer.pdparams")
paddle.save(layer.state_dict(), path) if six.PY2:
protocol = 2
else:
protocol = 4
paddle.save(save_dict, path, protocol=protocol)
dict_load = paddle.load(path) dict_load = paddle.load(path)
# compare results before and after saving # compare results before and after saving
for key, value in save_dict.items(): for key, value in save_dict.items():
self.assertTrue(np.array_equal(dict_load[key], value.numpy())) self.assertTrue(
np.array_equal(dict_load[key].numpy(), value.numpy()))
class TestSaveLoadPickle(unittest.TestCase): class TestSaveLoadPickle(unittest.TestCase):
def test_pickle_protocol(self): def test_pickle_protocol(self):
# enable dygraph mode
paddle.disable_static()
# create network # create network
layer = LinearNet() layer = LinearNet()
save_dict = layer.state_dict() save_dict = layer.state_dict()
...@@ -124,11 +139,236 @@ class TestSaveLoadPickle(unittest.TestCase): ...@@ -124,11 +139,236 @@ class TestSaveLoadPickle(unittest.TestCase):
if sys.version_info.major >= 3 and sys.version_info.minor >= 4: if sys.version_info.major >= 3 and sys.version_info.minor >= 4:
protocols += [3, 4] protocols += [3, 4]
for protocol in protocols: for protocol in protocols:
paddle.save(save_dict, path, protocol) paddle.save(save_dict, path, pickle_protocol=protocol)
dict_load = paddle.load(path) dict_load = paddle.load(path)
# compare results before and after saving # compare results before and after saving
for key, value in save_dict.items(): for key, value in save_dict.items():
self.assertTrue(np.array_equal(dict_load[key], value.numpy())) self.assertTrue(
np.array_equal(dict_load[key].numpy(), value.numpy()))
class TestSaveLoadAny(unittest.TestCase):
def set_zero(self, prog, place, scope=None):
if scope is None:
scope = fluid.global_scope()
for var in prog.list_vars():
if isinstance(var, framework.Parameter) or var.persistable:
ten = scope.find_var(var.name).get_tensor()
if ten is not None:
ten.set(np.zeros_like(np.array(ten)), place)
new_t = np.array(scope.find_var(var.name).get_tensor())
self.assertTrue(np.sum(np.abs(new_t)) == 0)
def replace_static_save(self, program, model_path, pickle_protocol=2):
with self.assertRaises(TypeError):
program.state_dict(1)
with self.assertRaises(TypeError):
program.state_dict(scope=1)
with self.assertRaises(ValueError):
program.state_dict('x')
state_dict_param = program.state_dict('param')
paddle.save(state_dict_param, model_path + '.pdparams')
state_dict_opt = program.state_dict('opt')
paddle.save(state_dict_opt, model_path + '.pdopt')
state_dict_all = program.state_dict()
paddle.save(state_dict_opt, model_path + '.pdall')
def replace_static_load(self, program, model_path):
with self.assertRaises(TypeError):
program.set_state_dict(1)
state_dict_param = paddle.load(model_path + '.pdparams')
state_dict_param['fake_var_name.@@'] = np.random.randn(1, 2)
state_dict_param['static_x'] = 'UserWarning'
program.set_state_dict(state_dict_param)
state_dict_param['static_x'] = np.random.randn(1, 2)
program.set_state_dict(state_dict_param)
program.set_state_dict(state_dict_param)
state_dict_opt = paddle.load(model_path + '.pdopt')
program.set_state_dict(state_dict_opt)
def test_replace_static_save_load(self):
paddle.enable_static()
with new_program_scope():
x = paddle.static.data(
name="static_x", shape=[None, IMAGE_SIZE], dtype='float32')
z = paddle.static.nn.fc(x, 10)
z = paddle.static.nn.fc(z, 10, bias_attr=False)
loss = fluid.layers.reduce_mean(z)
opt = Adam(learning_rate=1e-3)
opt.minimize(loss)
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
exe.run(paddle.static.default_startup_program())
prog = paddle.static.default_main_program()
fake_inputs = np.random.randn(2, IMAGE_SIZE).astype('float32')
exe.run(prog, feed={'static_x': fake_inputs}, fetch_list=[loss])
base_map = {}
for var in prog.list_vars():
if isinstance(var, framework.Parameter) or var.persistable:
t = np.array(fluid.global_scope().find_var(var.name)
.get_tensor())
base_map[var.name] = t
path = os.path.join("test_replace_static_save_load", "model")
# paddle.save, legacy paddle.fluid.load
self.replace_static_save(prog, path)
self.set_zero(prog, place)
paddle.fluid.io.load(prog, path)
for var in prog.list_vars():
if isinstance(var, framework.Parameter) or var.persistable:
new_t = np.array(fluid.global_scope().find_var(var.name)
.get_tensor())
base_t = base_map[var.name]
self.assertTrue(np.array_equal(new_t, np.array(base_t)))
# legacy paddle.fluid.save, paddle.load
paddle.fluid.io.save(prog, path)
self.set_zero(prog, place)
self.replace_static_load(prog, path)
for var in prog.list_vars():
if isinstance(var, framework.Parameter) or var.persistable:
new_t = np.array(fluid.global_scope().find_var(var.name)
.get_tensor())
base_t = base_map[var.name]
self.assertTrue(np.array_equal(new_t, base_t))
# test for return tensor
path_vars = 'test_replace_save_load_return_tensor_static/model'
for var in prog.list_vars():
if var.persistable:
tensor = var.get_value(fluid.global_scope())
paddle.save(tensor, os.path.join(path_vars, var.name))
with self.assertRaises(TypeError):
var.get_value('fluid.global_scope()')
with self.assertRaises(ValueError):
x.get_value()
with self.assertRaises(TypeError):
x.set_value('1')
fake_data = np.zeros([3, 2, 1, 2, 3])
with self.assertRaises(TypeError):
x.set_value(fake_data, '1')
with self.assertRaises(ValueError):
x.set_value(fake_data)
with self.assertRaises(ValueError):
var.set_value(fake_data)
# set var to zero
self.set_zero(prog, place)
for var in prog.list_vars():
if var.persistable:
tensor = paddle.load(
os.path.join(path_vars, var.name), return_numpy=False)
var.set_value(tensor)
new_t = np.array(fluid.global_scope().find_var(var.name)
.get_tensor())
base_t = base_map[var.name]
self.assertTrue(np.array_equal(new_t, base_t))
def test_paddle_save_load_v2(self):
paddle.disable_static()
layer = LinearNet()
state_dict = layer.state_dict()
path = 'paddle_save_load_v2/model.pdparams'
with self.assertRaises(TypeError):
paddle.save(state_dict, path, use_binary_format='False')
# legacy paddle.save, paddle.load
paddle.framework.io._legacy_save(state_dict, path)
load_dict_tensor = paddle.load(path, return_numpy=False)
# legacy paddle.load, paddle.save
paddle.save(state_dict, path)
load_dict_np = paddle.framework.io._legacy_load(path)
for k, v in state_dict.items():
self.assertTrue(
np.array_equal(v.numpy(), load_dict_tensor[k].numpy()))
self.assertTrue(np.array_equal(v.numpy(), load_dict_np[k]))
def test_single_pickle_var_dygraph(self):
# enable dygraph mode
paddle.disable_static()
layer = LinearNet()
path = 'paddle_save_load_v2/var_dygraph'
tensor = layer._linear.weight
with self.assertRaises(ValueError):
paddle.save(tensor, path, pickle_protocol='3')
with self.assertRaises(ValueError):
paddle.save(tensor, path, pickle_protocol=5)
paddle.save(tensor, path)
t_dygraph = paddle.load(path)
np_dygraph = paddle.load(path, return_numpy=True)
self.assertTrue(isinstance(t_dygraph, paddle.fluid.core.VarBase))
self.assertTrue(np.array_equal(tensor.numpy(), np_dygraph))
self.assertTrue(np.array_equal(tensor.numpy(), t_dygraph.numpy()))
paddle.enable_static()
lod_static = paddle.load(path)
np_static = paddle.load(path, return_numpy=True)
self.assertTrue(isinstance(lod_static, paddle.fluid.core.LoDTensor))
self.assertTrue(np.array_equal(tensor.numpy(), np_static))
self.assertTrue(np.array_equal(tensor.numpy(), np.array(lod_static)))
def test_single_pickle_var_static(self):
# enable static mode
paddle.enable_static()
with new_program_scope():
# create network
x = paddle.static.data(
name="x", shape=[None, IMAGE_SIZE], dtype='float32')
z = paddle.static.nn.fc(x, 128)
loss = fluid.layers.reduce_mean(z)
place = fluid.CPUPlace(
) if not paddle.fluid.core.is_compiled_with_cuda(
) else fluid.CUDAPlace(0)
exe = paddle.static.Executor(place)
exe.run(paddle.static.default_startup_program())
prog = paddle.static.default_main_program()
for var in prog.list_vars():
if list(var.shape) == [IMAGE_SIZE, 128]:
tensor = var.get_value()
break
scope = fluid.global_scope()
origin_tensor = np.array(tensor)
path = 'test_single_pickle_var_static/var'
paddle.save(tensor, path)
self.set_zero(prog, place, scope)
# static load
lod_static = paddle.load(path)
np_static = paddle.load(path, return_numpy=True)
# set_tensor(np.ndarray)
var.set_value(np_static, scope)
self.assertTrue(np.array_equal(origin_tensor, np.array(tensor)))
# set_tensor(LoDTensor)
self.set_zero(prog, place, scope)
var.set_value(lod_static, scope)
self.assertTrue(np.array_equal(origin_tensor, np.array(tensor)))
# enable dygraph mode
paddle.disable_static()
var_dygraph = paddle.load(path)
np_dygraph = paddle.load(path, return_numpy=True)
self.assertTrue(np.array_equal(np.array(tensor), np_dygraph))
self.assertTrue(np.array_equal(np.array(tensor), var_dygraph.numpy()))
def test_dygraph_save_static_load(self):
inps = np.random.randn(1, IMAGE_SIZE).astype('float32')
path = 'test_dygraph_save_static_load/dy-static.pdparams'
paddle.disable_static()
with paddle.utils.unique_name.guard():
layer = LinearNet()
state_dict_dy = layer.state_dict()
paddle.save(state_dict_dy, path)
paddle.enable_static()
with new_program_scope():
layer = LinearNet()
data = paddle.static.data(
name='x_static_save', shape=(None, IMAGE_SIZE), dtype='float32')
y_static = layer(data)
program = paddle.static.default_main_program()
place = fluid.CPUPlace(
) if not paddle.fluid.core.is_compiled_with_cuda(
) else fluid.CUDAPlace(0)
exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(paddle.static.default_startup_program())
state_dict = paddle.load(path, keep_name_table=True)
program.set_state_dict(state_dict)
state_dict_param = program.state_dict("param")
for name, tensor in state_dict_dy.items():
self.assertTrue(
np.array_equal(tensor.numpy(),
np.array(state_dict_param[tensor.name])))
class TestSaveLoad(unittest.TestCase): class TestSaveLoad(unittest.TestCase):
...@@ -158,7 +398,9 @@ class TestSaveLoad(unittest.TestCase): ...@@ -158,7 +398,9 @@ class TestSaveLoad(unittest.TestCase):
def check_load_state_dict(self, orig_dict, load_dict): def check_load_state_dict(self, orig_dict, load_dict):
for var_name, value in orig_dict.items(): for var_name, value in orig_dict.items():
self.assertTrue(np.array_equal(value.numpy(), load_dict[var_name])) load_value = load_dict[var_name].numpy() if hasattr(
load_dict[var_name], 'numpy') else np.array(load_dict[var_name])
self.assertTrue(np.array_equal(value.numpy(), load_value))
def test_save_load(self): def test_save_load(self):
layer, opt = self.build_and_train_model() layer, opt = self.build_and_train_model()
......
...@@ -25,12 +25,17 @@ import six ...@@ -25,12 +25,17 @@ import six
import pickle import pickle
import os import os
# Python2.x no longer supports saving and loading large parameters.
if six.PY2:
LARGE_PARAM = 2
else:
LARGE_PARAM = 2**26
class TestStaticSaveLoadLargeParameters(unittest.TestCase): class TestStaticSaveLoadLargeParameters(unittest.TestCase):
def test_large_parameters_static_save(self): def test_large_parameters_static_save(self):
# enable static mode # enable static mode
paddle.enable_static() paddle.enable_static()
LARGE_PARAM = 2**26
with new_program_scope(): with new_program_scope():
# create network # create network
x = paddle.static.data( x = paddle.static.data(
...@@ -54,7 +59,11 @@ class TestStaticSaveLoadLargeParameters(unittest.TestCase): ...@@ -54,7 +59,11 @@ class TestStaticSaveLoadLargeParameters(unittest.TestCase):
path = os.path.join("test_static_save_load_large_param", path = os.path.join("test_static_save_load_large_param",
"static_save") "static_save")
paddle.fluid.save(prog, path) if six.PY2:
protocol = 2
else:
protocol = 4
paddle.fluid.save(prog, path, pickle_protocol=protocol)
# set var to zero # set var to zero
for var in prog.list_vars(): for var in prog.list_vars():
if isinstance(var, framework.Parameter) or var.persistable: if isinstance(var, framework.Parameter) or var.persistable:
...@@ -92,3 +101,7 @@ class TestStaticSaveLoadLargeParameters(unittest.TestCase): ...@@ -92,3 +101,7 @@ class TestStaticSaveLoadLargeParameters(unittest.TestCase):
.get_tensor()) .get_tensor())
base_t = base_map[var.name] base_t = base_map[var.name]
self.assertTrue(np.array_equal(new_t, base_t)) self.assertTrue(np.array_equal(new_t, base_t))
if __name__ == '__main__':
unittest.main()
...@@ -190,7 +190,6 @@ class TestVariable(unittest.TestCase): ...@@ -190,7 +190,6 @@ class TestVariable(unittest.TestCase):
with fluid.dygraph.guard(): with fluid.dygraph.guard():
self.assertRaises(AssertionError, var.detach) self.assertRaises(AssertionError, var.detach)
self.assertRaises(AssertionError, var.numpy) self.assertRaises(AssertionError, var.numpy)
self.assertRaises(AssertionError, var.set_value, None)
self.assertRaises(AssertionError, var.backward) self.assertRaises(AssertionError, var.backward)
self.assertRaises(AssertionError, var.gradient) self.assertRaises(AssertionError, var.gradient)
self.assertRaises(AssertionError, var.clear_gradient) self.assertRaises(AssertionError, var.clear_gradient)
......
...@@ -22,13 +22,18 @@ import warnings ...@@ -22,13 +22,18 @@ import warnings
import sys import sys
import numpy as np import numpy as np
if not six.PY2:
import copyreg
import paddle import paddle
# deprecated module import # deprecated module import
from paddle import fluid from paddle import fluid
from paddle.fluid import core from paddle.fluid import core
from paddle.fluid.io import _unpack_saved_dict, _pack_loaded_dict from paddle.fluid.io import _unpack_saved_dict, _pack_loaded_dict, _pickle_loads_mac
from paddle.fluid.framework import Variable, _varbase_creator, _dygraph_tracer from paddle.fluid.io import _legacy_save as _legacy_static_save
from paddle.fluid.framework import Variable, _varbase_creator, _dygraph_tracer, in_dygraph_mode, ParamBase, _current_expected_place
from paddle.fluid.dygraph.jit import _SaveLoadConfig from paddle.fluid.dygraph.jit import _SaveLoadConfig
from paddle.fluid.dygraph.io import _construct_program_holders, _construct_params_and_buffers from paddle.fluid.dygraph.io import _construct_program_holders, _construct_params_and_buffers
from paddle.fluid.dygraph.io import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX, INFER_PARAMS_INFO_SUFFIX from paddle.fluid.dygraph.io import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX, INFER_PARAMS_INFO_SUFFIX
...@@ -181,7 +186,9 @@ def _build_load_path_and_config(path, config): ...@@ -181,7 +186,9 @@ def _build_load_path_and_config(path, config):
def _parse_load_config(configs): def _parse_load_config(configs):
supported_configs = ['model_filename', 'params_filename', 'keep_name_table'] supported_configs = [
'model_filename', 'params_filename', 'keep_name_table', 'return_numpy'
]
# input check # input check
for key in configs: for key in configs:
...@@ -195,16 +202,158 @@ def _parse_load_config(configs): ...@@ -195,16 +202,158 @@ def _parse_load_config(configs):
inner_config.model_filename = configs.get('model_filename', None) inner_config.model_filename = configs.get('model_filename', None)
inner_config.params_filename = configs.get('params_filename', None) inner_config.params_filename = configs.get('params_filename', None)
inner_config.keep_name_table = configs.get('keep_name_table', None) inner_config.keep_name_table = configs.get('keep_name_table', None)
inner_config.return_numpy = configs.get('return_numpy', False)
return inner_config return inner_config
def save(obj, path, pickle_protocol=2): def _parse_save_config(configs):
supported_configs = ['use_binary_format', 'pickle_protocol']
# input check
for key in configs:
if key not in supported_configs:
raise ValueError(
"The additional config (%s) of `paddle.save` is not supported."
% key)
# construct inner config
inner_config = _SaveLoadConfig()
inner_config.use_binary_format = configs.get('use_binary_format', False)
inner_config.pickle_protocol = configs.get('pickle_protocol', None)
return inner_config
def _pickle_save(obj, f, protocol):
# TODO(weixin):add support for BytesIO.
if not isinstance(protocol, int):
raise ValueError("The 'protocol' MUST be `int`, but received {}".format(
type(protocol)))
if protocol < 2 or protocol > 4:
raise ValueError("Expected 1<'protocol'<5, but received protocol={}".
format(protocol))
if not isinstance(obj, (core.LoDTensor, core.VarBase)):
raise NotImplementedError(
"Support 'paddle.Tensor' or 'paddle.core.LoDTensor', but received {}.".
format(type(obj)))
def reudce_varbase(self):
data = self.numpy()
name = self.name
return (tuple, ((name, data), ))
def reduce_LoDTensor(self):
data = np.array(self)
return (eval, ('data', {'data': data}))
def add_dispatch_table():
# This is not a good method, because the pickle module has been modified.
pickle.dispatch_table[core.VarBase] = reudce_varbase
pickle.dispatch_table[ParamBase] = reudce_varbase
pickle.dispatch_table[core.LoDTensor] = reduce_LoDTensor
def pop_dispatch_table():
pickle.dispatch_table.pop(core.VarBase)
pickle.dispatch_table.pop(core.LoDTensor)
pickle.dispatch_table.pop(ParamBase)
# When value of dict is lager than 4GB ,there is a Bug on 'MAC python3'
if sys.platform == 'darwin' and sys.version_info.major == 3:
add_dispatch_table()
pickle_bytes = pickle.dumps(obj)
pop_dispatch_table()
max_bytes = 2**30
for i in range(0, len(pickle_bytes), max_bytes):
f.write(pickle_bytes[i:i + max_bytes])
else:
if six.PY2:
add_dispatch_table()
pickle_bytes = pickle.dump(obj, f, protocol)
pop_dispatch_table()
else:
pickler = pickle.Pickler(f, protocol)
pickler.dispatch_table = copyreg.dispatch_table.copy()
pickler.dispatch_table[core.VarBase] = reudce_varbase
pickler.dispatch_table[core.LoDTensor] = reduce_LoDTensor
pickler.dispatch_table[ParamBase] = reudce_varbase
pickler.dump(obj)
def _use_legacy(obj):
# TODO(weixin):If `obj` is any object, the judgment condition should be more precise.
if not isinstance(obj, dict):
return False
return True
def _transformed_from_varbase(obj):
# In paddle2.1 version, VarBase is saved as tuple(tensor.name, tensor.numpy()).
# When executing paddle.load, use this function to determine whether to restore to VarBase/LoDTensor.
if isinstance(obj, tuple) and len(obj) == 2:
if six.PY2:
name_types = (str, unicode)
else:
name_types = str
if isinstance(obj[0], name_types) and isinstance(obj[1], np.ndarray):
return True
return False
def _transformed_from_lodtensor(obj):
# In paddle2.1 version, LoDTensor is saved as np.array(tensor).
# When executing paddle.load, use this function to determine whether to restore to VarBase/LoDTensor.
if isinstance(obj, np.ndarray):
return True
return False
def _to_LodTensor(ndarray):
if not isinstance(ndarray, np.ndarray):
raise TypeError(
'Type of `ndarray` should be numpy.ndarray, but received {}.'.
format(type(ndarray)))
t = core.LoDTensor()
place = _current_expected_place()
t.set(ndarray, place)
return t
def _tuple_to_tensor(obj, return_numpy):
if return_numpy:
return obj[1]
if in_dygraph_mode():
t = paddle.to_tensor(obj[1])
# This function does modify the name of return value.
# Loading the same variable multiple times may cause the same name.
t.name = obj[0]
return t
else:
return _to_LodTensor(obj[1])
def _ndarray_to_tensor(obj, return_numpy):
if return_numpy:
return obj
if in_dygraph_mode():
return paddle.to_tensor(obj)
else:
return _to_LodTensor(obj)
def save(obj, path, protocol=2, **configs):
''' '''
Save an object to the specified path. Save an object to the specified path.
.. note:: .. note::
Now only supports save ``state_dict`` of Layer or Optimizer. Now supports saving ``state_dict`` of Layer or Optimizer, Tensor.
.. note:: .. note::
Different from ``paddle.jit.save``, since the save result of ``paddle.save`` is a single file, Different from ``paddle.jit.save``, since the save result of ``paddle.save`` is a single file,
...@@ -219,8 +368,12 @@ def save(obj, path, pickle_protocol=2): ...@@ -219,8 +368,12 @@ def save(obj, path, pickle_protocol=2):
obj(Object) : The object to be saved. obj(Object) : The object to be saved.
path(str) : The path of the object to be saved. path(str) : The path of the object to be saved.
If saved in the current directory, the input path string will be used as the file name. If saved in the current directory, the input path string will be used as the file name.
pickle_protocol(int, optional): The protocol version of pickle module must be greater than 1 and less than 5. protocol(int, optional): The protocol version of pickle module must be greater than 1 and less than 5.
Default: 2 Default: 2
**configs(dict, optional): optional keyword arguments. The following options are currently supported:
use_binary_format(bool): When the saved object is static graph variable, you can specify ``use_binary_for_var``.
If True, save the file in the c++ binary format when saving a single static graph variable; otherwise, save it in pickle format.
Default: False
Returns: Returns:
None None
...@@ -228,20 +381,91 @@ def save(obj, path, pickle_protocol=2): ...@@ -228,20 +381,91 @@ def save(obj, path, pickle_protocol=2):
Examples: Examples:
.. code-block:: python .. code-block:: python
# example 1: dynamic graph
import paddle import paddle
emb = paddle.nn.Embedding(10, 10) emb = paddle.nn.Embedding(10, 10)
layer_state_dict = emb.state_dict() layer_state_dict = emb.state_dict()
# save state_dict of emb
paddle.save(layer_state_dict, "emb.pdparams") paddle.save(layer_state_dict, "emb.pdparams")
scheduler = paddle.optimizer.lr.NoamDecay(
scheduler = paddle.optimizer.lr.NoamDecay(
d_model=0.01, warmup_steps=100, verbose=True) d_model=0.01, warmup_steps=100, verbose=True)
adam = paddle.optimizer.Adam( adam = paddle.optimizer.Adam(
learning_rate=scheduler, learning_rate=scheduler,
parameters=emb.parameters()) parameters=emb.parameters())
opt_state_dict = adam.state_dict() opt_state_dict = adam.state_dict()
# save state_dict of optimizer
paddle.save(opt_state_dict, "adam.pdopt") paddle.save(opt_state_dict, "adam.pdopt")
# save weight of emb
paddle.save(emb.weight, "emb.weight.pdtensor")
# example 2: static graph
import paddle
import paddle.static as static
paddle.enable_static()
# create network
x = paddle.static.data(name="x", shape=[None, 224], dtype='float32')
z = paddle.static.nn.fc(x, 10)
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
exe.run(paddle.static.default_startup_program())
prog = paddle.static.default_main_program()
for var in prog.list_vars():
if list(var.shape) == [224, 10]:
tensor = var.get_tensor()
break
# save/load tensor
path_tensor = 'temp/tensor.pdtensor'
paddle.save(tensor, path_tensor)
# save/load state_dict
path_state_dict = 'temp/model.pdparams'
paddle.save(prog.state_dict("param"), path_tensor)
''' '''
# 1. input check
filename = os.path.basename(path)
if filename == "":
raise ValueError("The input path MUST be format of dirname/filename "
"[dirname\\filename in Windows system], but received "
"filename is empty string.")
# 2. save object
dirname = os.path.dirname(path)
if dirname and not os.path.exists(dirname):
os.makedirs(dirname)
config = _parse_save_config(configs)
if not isinstance(config.use_binary_format, bool):
raise TypeError(
"Type of `use_binary_format` should be bool, but received {}.".
format(type(config.use_binary_format)))
# `protocol` need to be used, `pickle_protocol` is a deprecated arg.
if config.pickle_protocol is not None:
protocol = config.pickle_protocol
warnings.warn(
"'pickle_protocol' is a deprecated argument. Please use 'protocol' instead."
)
if _use_legacy(obj):
if in_dygraph_mode():
_legacy_save(obj, path, protocol)
else:
_legacy_static_save(obj, path, protocol)
else:
# save single variable
with open(path, 'wb') as f:
_pickle_save(obj, f, protocol)
def _legacy_save(obj, path, protocol=2):
# 1. input check # 1. input check
if not isinstance(obj, dict): if not isinstance(obj, dict):
raise NotImplementedError( raise NotImplementedError(
...@@ -257,13 +481,13 @@ def save(obj, path, pickle_protocol=2): ...@@ -257,13 +481,13 @@ def save(obj, path, pickle_protocol=2):
"[dirname\\filename in Windows system], but received " "[dirname\\filename in Windows system], but received "
"filename is empty string.") "filename is empty string.")
if not isinstance(pickle_protocol, int): if not isinstance(protocol, int):
raise ValueError("The 'protocol' MUST be `int`, but received {}".format( raise ValueError("The 'protocol' MUST be `int`, but received {}".format(
type(pickle_protocol))) type(protocol)))
if pickle_protocol < 2 or pickle_protocol > 4: if protocol < 2 or protocol > 4:
raise ValueError("Expected 1<'protocol'<5, but received protocol={}". raise ValueError("Expected 1<'protocol'<5, but received protocol={}".
format(pickle_protocol)) format(protocol))
# 2. save object # 2. save object
dirname = os.path.dirname(path) dirname = os.path.dirname(path)
...@@ -274,19 +498,18 @@ def save(obj, path, pickle_protocol=2): ...@@ -274,19 +498,18 @@ def save(obj, path, pickle_protocol=2):
if isinstance(obj, dict): if isinstance(obj, dict):
saved_obj = _build_saved_state_dict(obj) saved_obj = _build_saved_state_dict(obj)
saved_obj = _unpack_saved_dict(saved_obj, pickle_protocol) saved_obj = _unpack_saved_dict(saved_obj, protocol)
# When value of dict is lager than 4GB ,there is a Bug on 'MAC python3.5/6' # When value of dict is lager than 4GB ,there is a Bug on 'MAC python3'
if sys.platform == 'darwin' and sys.version_info.major == 3 and ( if sys.platform == 'darwin' and sys.version_info.major == 3:
sys.version_info.minor == 5 or sys.version_info.minor == 6): pickle_bytes = pickle.dumps(saved_obj, protocol=protocol)
pickle_bytes = pickle.dumps(saved_obj, protocol=pickle_protocol)
with open(path, 'wb') as f: with open(path, 'wb') as f:
max_bytes = 2**30 max_bytes = 2**30
for i in range(0, len(pickle_bytes), max_bytes): for i in range(0, len(pickle_bytes), max_bytes):
f.write(pickle_bytes[i:i + max_bytes]) f.write(pickle_bytes[i:i + max_bytes])
else: else:
with open(path, 'wb') as f: with open(path, 'wb') as f:
pickle.dump(saved_obj, f, protocol=pickle_protocol) pickle.dump(saved_obj, f, protocol=protocol)
def load(path, **configs): def load(path, **configs):
...@@ -294,7 +517,7 @@ def load(path, **configs): ...@@ -294,7 +517,7 @@ def load(path, **configs):
Load an object can be used in paddle from specified path. Load an object can be used in paddle from specified path.
.. note:: .. note::
Now only supports load ``state_dict`` of Layer or Optimizer. Now supports load ``state_dict`` of Layer or Optimizer, Tensor.
.. note:: .. note::
In order to use the model parameters saved by paddle more efficiently, In order to use the model parameters saved by paddle more efficiently,
...@@ -331,7 +554,9 @@ def load(path, **configs): ...@@ -331,7 +554,9 @@ def load(path, **configs):
``save_inference_model`` save format. Default file name is :code:`__model__` . ``save_inference_model`` save format. Default file name is :code:`__model__` .
(2) params_filename (str): The persistable variables file name of the paddle 1.x (2) params_filename (str): The persistable variables file name of the paddle 1.x
``save_inference_model`` save format. No default file name, save variables separately ``save_inference_model`` save format. No default file name, save variables separately
by default. by default.
(3) return_numpy(bool): If specified as True, return tensor as numpy.ndarray, otherwise return tensor as paddle.Tensor.
Default False.
Returns: Returns:
Object(Object): a target object can be used in paddle Object(Object): a target object can be used in paddle
...@@ -341,20 +566,115 @@ def load(path, **configs): ...@@ -341,20 +566,115 @@ def load(path, **configs):
import paddle import paddle
# example 1: dynamic graph
import paddle
emb = paddle.nn.Embedding(10, 10) emb = paddle.nn.Embedding(10, 10)
layer_state_dict = emb.state_dict() layer_state_dict = emb.state_dict()
# save state_dict of emb
paddle.save(layer_state_dict, "emb.pdparams") paddle.save(layer_state_dict, "emb.pdparams")
scheduler = paddle.optimizer.lr.NoamDecay(
scheduler = paddle.optimizer.lr.NoamDecay(
d_model=0.01, warmup_steps=100, verbose=True) d_model=0.01, warmup_steps=100, verbose=True)
adam = paddle.optimizer.Adam( adam = paddle.optimizer.Adam(
learning_rate=scheduler, learning_rate=scheduler,
parameters=emb.parameters()) parameters=emb.parameters())
opt_state_dict = adam.state_dict() opt_state_dict = adam.state_dict()
# save state_dict of optimizer
paddle.save(opt_state_dict, "adam.pdopt") paddle.save(opt_state_dict, "adam.pdopt")
# save weight of emb
paddle.save(emb.weight, "emb.weight.pdtensor")
# load state_dict of emb
load_layer_state_dict = paddle.load("emb.pdparams") load_layer_state_dict = paddle.load("emb.pdparams")
# load state_dict of optimizer
load_opt_state_dict = paddle.load("adam.pdopt") load_opt_state_dict = paddle.load("adam.pdopt")
# load weight of emb
load_weight = paddle.load("emb.weight.pdtensor")
# example 2: static graph
import paddle
import paddle.static as static
paddle.enable_static()
# create network
x = paddle.static.data(name="x", shape=[None, 224], dtype='float32')
z = paddle.static.nn.fc(x, 10)
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
exe.run(paddle.static.default_startup_program())
prog = paddle.static.default_main_program()
for var in prog.list_vars():
if list(var.shape) == [224, 10]:
tensor = var.get_tensor()
break
# save/load tensor
path_tensor = 'temp/tensor.pdtensor'
paddle.save(tensor, path_tensor)
load_tensor = paddle.load(path_tensor)
# save/load state_dict
path_state_dict = 'temp/model.pdparams'
paddle.save(prog.state_dict("param"), path_tensor)
load_state_dict = paddle.load(path_tensor)
''' '''
if os.path.isfile(path):
config = _parse_load_config(configs)
with open(path, 'rb') as f:
# When value of dict is lager than 4GB ,there is a Bug on 'MAC python3'
if sys.platform == 'darwin' and sys.version_info.major == 3:
load_result = _pickle_loads_mac(path, f)
else:
load_result = pickle.load(f) if six.PY2 else pickle.load(
f, encoding='latin1')
# TODO(weixin):If `obj` is any object, the judgment condition should be more precise.
if isinstance(load_result, dict):
if isinstance(load_result, dict):
load_result = _pack_loaded_dict(load_result)
# paddle2.0: paddle.save/load
if "StructuredToParameterName@@" in load_result:
for key in load_result["StructuredToParameterName@@"]:
load_result[key] = _ndarray_to_tensor(
load_result[key], config.return_numpy)
if not config.keep_name_table and "StructuredToParameterName@@" in load_result:
del load_result["StructuredToParameterName@@"]
else:
# paddle2.1 static.save/load
for key in load_result:
load_result[key] = _ndarray_to_tensor(
load_result[key], config.return_numpy)
else:
# TODO(weixin): support complex objects such as layer.
# If `obj` is any object, the judgment condition should be more precise.
if _transformed_from_lodtensor(load_result):
load_result = _ndarray_to_tensor(load_result,
config.return_numpy)
elif _transformed_from_varbase(load_result):
load_result = _tuple_to_tensor(load_result,
config.return_numpy)
else:
raise NotImplementedError(
'Only support tensor and state_dict, but received {}.'.
format(type(load_result)))
else:
load_result = _legacy_load(path, **configs)
return load_result
def _legacy_load(path, **configs):
load_result = None load_result = None
config = _parse_load_config(configs) config = _parse_load_config(configs)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册