未验证 提交 28521e0f 编写于 作者: W WeiXin 提交者: GitHub

Save all the information of 'ParamBase' in 'Layer'. (#33500)

* Save all the information of 'ParamBase' in 'Layer'.

* edit unittest
上级 009a163c
...@@ -5540,6 +5540,18 @@ class ParamBase(core.VarBase): ...@@ -5540,6 +5540,18 @@ class ParamBase(core.VarBase):
core.varbase_copy(self, new_param, device, blocking) core.varbase_copy(self, new_param, device, blocking)
return new_param return new_param
def __reduce__(self):
value = self.numpy()
state = (self.name, self.persistable, self.stop_gradient)
return ParamBase, (self.shape, self.dtype), (self.__dict__, value,
state)
def __setstate__(self, state):
self.__dict__.update(state[0])
t = self.value().get_tensor()
t.set(state[1], _current_expected_place())
self.name, self.persistable, self.stop_gradient = state[2]
__repr__ = __str__ __repr__ = __str__
......
...@@ -935,21 +935,17 @@ class TestSaveLoadLayer(unittest.TestCase): ...@@ -935,21 +935,17 @@ class TestSaveLoadLayer(unittest.TestCase):
layer2 = LinearNet() layer2 = LinearNet()
layer1.eval() layer1.eval()
layer2.eval() layer2.eval()
origin_layer = (layer1, layer2)
origin = (layer1(inps), layer2(inps)) origin = (layer1(inps), layer2(inps))
path = "test_save_load_layer_/layer.pdmodel" path = "test_save_load_layer_/layer.pdmodel"
paddle.save((layer1, layer2), path) paddle.save(origin_layer, path)
# static
paddle.enable_static()
with self.assertRaises(ValueError):
paddle.load(path)
# dygraph
paddle.disable_static()
loaded_layer = paddle.load(path) loaded_layer = paddle.load(path)
loaded_result = [l(inps) for l in loaded_layer] loaded_result = [l(inps) for l in loaded_layer]
for i in range(len(origin)): for i in range(len(origin)):
self.assertTrue((origin[i] - loaded_result[i]).abs().max() < 1e-10) self.assertTrue((origin[i] - loaded_result[i]).abs().max() < 1e-10)
for k, v in origin_layer[i]._linear.weight.__dict__.items():
self.assertTrue(v == loaded_layer[i]._linear.weight.__dict__[k])
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -233,9 +233,13 @@ def _pickle_save(obj, f, protocol): ...@@ -233,9 +233,13 @@ def _pickle_save(obj, f, protocol):
raise ValueError("Expected 1<'protocol'<5, but received protocol={}". raise ValueError("Expected 1<'protocol'<5, but received protocol={}".
format(protocol)) format(protocol))
def reudce_varbase(self): list_params = set()
def reduce_varbase(self):
data = self.numpy() data = self.numpy()
name = self.name name = self.name
if name in list_params:
return self.__reduce__()
return (tuple, ((name, data), )) return (tuple, ((name, data), ))
...@@ -244,16 +248,43 @@ def _pickle_save(obj, f, protocol): ...@@ -244,16 +248,43 @@ def _pickle_save(obj, f, protocol):
return (eval, ('data', {'data': data})) return (eval, ('data', {'data': data}))
def reduce_Layer(self):
is_param_or_layer = lambda v: isinstance(v, ParamBase) or isinstance(v, core.Layer)
def collect_params(param_or_layer):
if isinstance(param_or_layer, ParamBase):
list_params.add(param_or_layer.name)
else:
# param_or_layer is layer
_parse_every_object(param_or_layer.__dict__, is_param_or_layer,
collect_params)
return param_or_layer
_parse_every_object(self.__dict__, is_param_or_layer, collect_params)
return self.__reduce_ex__(protocol)
dispatch_table_layer = dict()
def create_layer_dispatch_table(layer):
dispatch_table_layer[layer.__class__] = reduce_Layer
return layer
_parse_every_object(obj, lambda v: isinstance(v, core.Layer),
create_layer_dispatch_table)
def add_dispatch_table(): def add_dispatch_table():
# This is not a good method, because the pickle module has been modified. # This is not a good method, because the pickle module has been modified.
pickle.dispatch_table[core.VarBase] = reudce_varbase pickle.dispatch_table[core.VarBase] = reduce_varbase
pickle.dispatch_table[ParamBase] = reudce_varbase pickle.dispatch_table[ParamBase] = reduce_varbase
pickle.dispatch_table[core.LoDTensor] = reduce_LoDTensor pickle.dispatch_table[core.LoDTensor] = reduce_LoDTensor
pickle.dispatch_table.update(dispatch_table_layer)
def pop_dispatch_table(): def pop_dispatch_table():
pickle.dispatch_table.pop(core.VarBase) pickle.dispatch_table.pop(core.VarBase)
pickle.dispatch_table.pop(core.LoDTensor) pickle.dispatch_table.pop(core.LoDTensor)
pickle.dispatch_table.pop(ParamBase) pickle.dispatch_table.pop(ParamBase)
for k in dispatch_table_layer:
pickle.dispatch_table.pop(k)
# When value of dict is lager than 4GB ,there is a Bug on 'MAC python3' # When value of dict is lager than 4GB ,there is a Bug on 'MAC python3'
if sys.platform == 'darwin' and sys.version_info.major == 3: if sys.platform == 'darwin' and sys.version_info.major == 3:
...@@ -273,10 +304,10 @@ def _pickle_save(obj, f, protocol): ...@@ -273,10 +304,10 @@ def _pickle_save(obj, f, protocol):
pickler = pickle.Pickler(f, protocol) pickler = pickle.Pickler(f, protocol)
pickler.dispatch_table = copyreg.dispatch_table.copy() pickler.dispatch_table = copyreg.dispatch_table.copy()
pickler.dispatch_table[core.VarBase] = reudce_varbase pickler.dispatch_table[core.VarBase] = reduce_varbase
pickler.dispatch_table[core.LoDTensor] = reduce_LoDTensor pickler.dispatch_table[core.LoDTensor] = reduce_LoDTensor
pickler.dispatch_table[ParamBase] = reudce_varbase pickler.dispatch_table[ParamBase] = reduce_varbase
pickler.dispatch_table.update(dispatch_table_layer)
pickler.dump(obj) pickler.dump(obj)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册