未验证 提交 727b28d7 编写于 作者: W WeiXin 提交者: GitHub

paddle.save/load support nested structure and layer (#32446)

* support save/load binary format tensor

* Fix error when create cudaplace

* Fix error when create cudaplace

* Fix error when create cudaplace

* get devive context from pool.

* move define of 'SerializeToStream' and 'DeserializeFromStream' to 'lod_tensor.cc' and 'selected_rows.cc'.

* support complex object

* improve coverage.

* improve coverage

* improve coverage.

* fix a bug.

* polish API

* save/load program

* paddle.save/load: layer

* deal with conflict

* if PY2, block test_paddle_save_load.TestSaveLoadLayer

* polish code.

* polish code

* edit unnittest

* The condition for object to be identified as state_dict becomes strict

* use 'core._cuda_synchronize'
上级 74824fdd
......@@ -235,11 +235,6 @@ def _pickle_save(obj, f, protocol):
raise ValueError("Expected 1<'protocol'<5, but received protocol={}".
format(protocol))
if not isinstance(obj, (core.LoDTensor, core.VarBase)):
raise NotImplementedError(
"Support 'paddle.Tensor' or 'paddle.core.LoDTensor', but received {}.".
format(type(obj)))
def reudce_varbase(self):
data = self.numpy()
name = self.name
......@@ -287,11 +282,48 @@ def _pickle_save(obj, f, protocol):
pickler.dump(obj)
def _use_legacy(obj):
# TODO(weixin):If `obj` is any object, the judgment condition should be more precise.
if not isinstance(obj, dict):
def _contain_x(obj, condition_func):
if isinstance(obj, core.SelectedRows):
raise NotImplementedError(
"`paddle.save` do not support saving 'SelectedRows'.")
if condition_func(obj):
return True
elif type(obj) in (dict, collections.OrderedDict, list, tuple):
if type(obj) in (dict, collections.OrderedDict):
keys = list(obj.keys())
else:
keys = range(len(obj))
flag = False
for key in keys:
flag |= _contain_x(obj[key], condition_func)
if flag:
return True
return flag
else:
return False
return True
def _is_state_dict(obj):
if isinstance(obj, dict):
def condition(obj):
return isinstance(obj, (core.Layer, Program, core.VarBase,
core.LoDTensor, core.SelectedRows))
# If the value of a dict is a core.VarBase/LoDTensor or a dict
# that does not contain a paddle type(Layer, Program, VarBase, LoDTensor, SelectedRows),
# the dict is considered to be a state_ dict.
for key, value in obj.items():
if isinstance(value, dict):
for k, v in value.items():
if _contain_x(v, condition):
return False
elif not isinstance(value, (core.VarBase, core.LoDTensor)):
return False
return True
return False
def _transformed_from_varbase(obj):
......@@ -348,6 +380,76 @@ def _ndarray_to_tensor(obj, return_numpy):
return _to_LodTensor(obj)
def _lod_tensor2varbase(tensor):
return_var = _varbase_creator()
return_var.value().get_tensor().set(tensor, _current_expected_place())
return return_var
def _parse_every_object(obj, condition_func, convert_func):
if condition_func(obj):
return convert_func(obj)
elif type(obj) in (dict, collections.OrderedDict, list):
if type(obj) == list:
keys = range(len(obj))
else:
keys = list(obj.keys())
for key in keys:
if condition_func(obj[key]):
obj[key] = convert_func(obj[key])
else:
obj[key] = _parse_every_object(obj[key], condition_func,
convert_func)
return obj
elif type(obj) == tuple:
return tuple(
_parse_every_object(list(obj), condition_func, convert_func))
elif type(obj) == set:
return set(_parse_every_object(list(obj), condition_func, convert_func))
else:
if isinstance(obj, collections.Iterable) and not isinstance(obj, (
str, np.ndarray, core.VarBase, core.LoDTensor)):
raise NotImplementedError(
"The iteratable objects supported are tuple, list, dict, OrderedDict, string. But received {}.".
format(type(obj)))
return obj
def _parse_load_result(obj, return_numpy):
def is_layer(obj):
return isinstance(obj, core.Layer)
def parse_layer(obj):
temp_dict = _parse_load_result(obj.__dict__, False)
obj.__dict__.update(temp_dict)
return obj
if _contain_x(obj, is_layer):
if not in_dygraph_mode():
raise ValueError(
"Layer can only be loaded in dynamic graph mode, but now in static graph mode."
)
_parse_every_object(obj, is_layer, parse_layer)
def tuple_to_tensor(obj):
return _tuple_to_tensor(obj, return_numpy=return_numpy)
def ndarray_to_tensor(obj):
return _ndarray_to_tensor(obj, return_numpy=return_numpy)
# tuple(name, ndarry) was converted from varbase of paddle2.1,
# and all tuple(name, ndarry) are converted to tensor.
if _contain_x(obj, _transformed_from_varbase):
return _parse_every_object(obj, _transformed_from_varbase,
tuple_to_tensor)
# If there is no tuple(name, ndary), it is considered to be saved by paddle2.0
# or converted from LoDTensor, and all ndarrays are converted to tensor.
else:
return _parse_every_object(obj, _transformed_from_lodtensor,
ndarray_to_tensor)
def _save_lod_tensor(tensor, file_name):
if not tensor._is_initialized():
raise ValueError("The saved tensor is not initialized.")
......@@ -383,6 +485,8 @@ def _save_binary_var(obj, path):
_save_lod_tensor(obj, path)
elif isinstance(obj, core.SelectedRows):
_save_selected_rows(obj, path)
elif isinstance(obj, core.VarBase):
_save_lod_tensor(obj.value().get_tensor(), path)
else:
# Since the concept of 'Tensor' is only exposed to users, the error message can only contain tensor instead of 'LoDTensor' or 'SelectedRows'
raise NotImplementedError(
......@@ -498,32 +602,20 @@ def save(obj, path, protocol=2, **configs):
warnings.warn(
"'pickle_protocol' is a deprecated argument. Please use 'protocol' instead."
)
if isinstance(obj, Program):
obj.desc.flush()
with open(path, "wb") as f:
f.write(obj.desc.serialize_to_string())
elif _use_legacy(obj):
elif _is_state_dict(obj):
if in_dygraph_mode():
_legacy_save(obj, path, protocol)
else:
_legacy_static_save(obj, path, protocol)
else:
# `protocol` need to be used, `pickle_protocol` is a deprecated arg.
if config.pickle_protocol is not None:
protocol = config.pickle_protocol
warnings.warn(
"'pickle_protocol' is a deprecated argument. Please use 'protocol' instead."
)
if _use_legacy(obj):
if in_dygraph_mode():
_legacy_save(obj, path, protocol)
else:
_legacy_static_save(obj, path, protocol)
else:
# save single variable
with open(path, 'wb') as f:
_pickle_save(obj, f, protocol)
with open(path, 'wb') as f:
_pickle_save(obj, f, protocol)
def _legacy_save(obj, path, protocol=2):
......@@ -703,8 +795,7 @@ def load(path, **configs):
# TODO(weixin):If `obj` is any object, the judgment condition should be more precise.
if isinstance(load_result, dict):
if isinstance(load_result, dict):
load_result = _pack_loaded_dict(load_result)
load_result = _pack_loaded_dict(load_result)
# paddle2.0: paddle.save/load
if "StructuredToParameterName@@" in load_result:
......@@ -716,23 +807,12 @@ def load(path, **configs):
del load_result["StructuredToParameterName@@"]
else:
# paddle2.1 static.save/load
for key in load_result:
load_result[key] = _ndarray_to_tensor(
load_result[key], config.return_numpy)
load_result = _parse_load_result(load_result,
config.return_numpy)
else:
# TODO(weixin): support complex objects such as layer.
# If `obj` is any object, the judgment condition should be more precise.
if _transformed_from_lodtensor(load_result):
load_result = _ndarray_to_tensor(load_result,
config.return_numpy)
elif _transformed_from_varbase(load_result):
load_result = _tuple_to_tensor(load_result,
config.return_numpy)
else:
raise NotImplementedError(
'Only support tensor and state_dict, but received {}.'.
format(type(load_result)))
load_result = _parse_load_result(load_result,
config.return_numpy)
except exception_type as msg_pickle:
try:
......@@ -741,7 +821,12 @@ def load(path, **configs):
except:
try:
tensor, _ = _load_lod_tensor(path)
return tensor
if config.return_numpy:
return np.array(tensor)
else:
if in_dygraph_mode():
return _lod_tensor2varbase(tensor)
return tensor
except:
try:
with open(path, "rb") as f:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册