未验证 提交 e22701c4 编写于 作者: W WeiXin 提交者: GitHub

delete the function of saving layer object. (#33697)

* delete the function of saving layer object.

* edit doc of paddle.save/load and polish error message
上级 6df7ac72
......@@ -5578,18 +5578,6 @@ class ParamBase(core.VarBase):
core.varbase_copy(self, new_param, device, blocking)
return new_param
def __reduce__(self):
value = self.numpy()
state = (self.name, self.persistable, self.stop_gradient)
return ParamBase, (self.shape, self.dtype), (self.__dict__, value,
state)
def __setstate__(self, state):
self.__dict__.update(state[0])
t = self.value().get_tensor()
t.set(state[1], _current_expected_place())
self.name, self.persistable, self.stop_gradient = state[2]
__repr__ = __str__
......
......@@ -928,15 +928,9 @@ class TestSaveLoadLayer(unittest.TestCase):
origin_layer = (layer1, layer2)
origin = (layer1(inps), layer2(inps))
path = "test_save_load_layer_/layer.pdmodel"
with self.assertRaises(ValueError):
paddle.save(origin_layer, path)
loaded_layer = paddle.load(path)
loaded_result = [l(inps) for l in loaded_layer]
for i in range(len(origin)):
self.assertTrue((origin[i] - loaded_result[i]).abs().max() < 1e-10)
for k, v in origin_layer[i]._linear.weight.__dict__.items():
self.assertTrue(v == loaded_layer[i]._linear.weight.__dict__[k])
if __name__ == '__main__':
unittest.main()
......@@ -229,13 +229,9 @@ def _pickle_save(obj, f, protocol):
raise ValueError("Expected 1<'protocol'<5, but received protocol={}".
format(protocol))
list_params = set()
def reduce_varbase(self):
data = self.numpy()
name = self.name
if name in list_params:
return self.__reduce__()
return (tuple, ((name, data), ))
......@@ -245,19 +241,8 @@ def _pickle_save(obj, f, protocol):
return (eval, ('data', {'data': data}))
def reduce_Layer(self):
is_param_or_layer = lambda v: isinstance(v, ParamBase) or isinstance(v, core.Layer)
def collect_params(param_or_layer):
if isinstance(param_or_layer, ParamBase):
list_params.add(param_or_layer.name)
else:
# param_or_layer is layer
_parse_every_object(param_or_layer.__dict__, is_param_or_layer,
collect_params)
return param_or_layer
_parse_every_object(self.__dict__, is_param_or_layer, collect_params)
return self.__reduce_ex__(protocol)
raise ValueError(
"paddle do not support saving `paddle.nn.Layer` object.")
dispatch_table_layer = dict()
......@@ -567,7 +552,7 @@ def save(obj, path, protocol=4, **configs):
Save an object to the specified path.
.. note::
Now supports saving ``state_dict`` of Layer/Optimizer, Layer, Tensor and nested structure containing Tensor, Program.
Now supports saving ``state_dict`` of Layer/Optimizer, Tensor and nested structure containing Tensor, Program.
.. note::
Different from ``paddle.jit.save``, since the save result of ``paddle.save`` is a single file,
......@@ -783,7 +768,7 @@ def load(path, **configs):
Load an object can be used in paddle from specified path.
.. note::
Now supports loading ``state_dict`` of Layer/Optimizer, Layer, Tensor and nested structure containing Tensor, Program.
Now supports loading ``state_dict`` of Layer/Optimizer, Tensor and nested structure containing Tensor, Program.
.. note::
In order to use the model parameters saved by paddle more efficiently,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册