diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 695c91fea819f57a12ec760d3eeb4965da6c23de..22f31a340364f931f98979af7d2c6b1f680f837e 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -5540,6 +5540,18 @@ class ParamBase(core.VarBase): core.varbase_copy(self, new_param, device, blocking) return new_param + def __reduce__(self): + value = self.numpy() + state = (self.name, self.persistable, self.stop_gradient) + return ParamBase, (self.shape, self.dtype), (self.__dict__, value, + state) + + def __setstate__(self, state): + self.__dict__.update(state[0]) + t = self.value().get_tensor() + t.set(state[1], _current_expected_place()) + self.name, self.persistable, self.stop_gradient = state[2] + __repr__ = __str__ diff --git a/python/paddle/fluid/tests/unittests/test_paddle_save_load.py b/python/paddle/fluid/tests/unittests/test_paddle_save_load.py index 594d0db035c6a5f71a5e07ca9547e66cfe58771e..fe8692a38814e92bba93fb82f103a77b13cd9153 100644 --- a/python/paddle/fluid/tests/unittests/test_paddle_save_load.py +++ b/python/paddle/fluid/tests/unittests/test_paddle_save_load.py @@ -935,21 +935,17 @@ class TestSaveLoadLayer(unittest.TestCase): layer2 = LinearNet() layer1.eval() layer2.eval() + origin_layer = (layer1, layer2) origin = (layer1(inps), layer2(inps)) path = "test_save_load_layer_/layer.pdmodel" - paddle.save((layer1, layer2), path) - - # static - paddle.enable_static() - with self.assertRaises(ValueError): - paddle.load(path) - # dygraph - paddle.disable_static() + paddle.save(origin_layer, path) loaded_layer = paddle.load(path) loaded_result = [l(inps) for l in loaded_layer] for i in range(len(origin)): self.assertTrue((origin[i] - loaded_result[i]).abs().max() < 1e-10) + for k, v in origin_layer[i]._linear.weight.__dict__.items(): + self.assertTrue(v == loaded_layer[i]._linear.weight.__dict__[k]) if __name__ == '__main__': diff --git a/python/paddle/framework/io.py b/python/paddle/framework/io.py index 5f1ffa81eab17b720f9f02a9d55a8720d64aa27d..d02d078d547deaeb29b59a21e28c4971ae5b9a2d 100644 --- a/python/paddle/framework/io.py +++ b/python/paddle/framework/io.py @@ -233,9 +233,13 @@ def _pickle_save(obj, f, protocol): raise ValueError("Expected 1<'protocol'<5, but received protocol={}". format(protocol)) - def reudce_varbase(self): + list_params = set() + + def reduce_varbase(self): data = self.numpy() name = self.name + if name in list_params: + return self.__reduce__() return (tuple, ((name, data), )) @@ -244,16 +248,43 @@ def _pickle_save(obj, f, protocol): return (eval, ('data', {'data': data})) + def reduce_Layer(self): + is_param_or_layer = lambda v: isinstance(v, ParamBase) or isinstance(v, core.Layer) + + def collect_params(param_or_layer): + if isinstance(param_or_layer, ParamBase): + list_params.add(param_or_layer.name) + else: + # param_or_layer is layer + _parse_every_object(param_or_layer.__dict__, is_param_or_layer, + collect_params) + return param_or_layer + + _parse_every_object(self.__dict__, is_param_or_layer, collect_params) + return self.__reduce_ex__(protocol) + + dispatch_table_layer = dict() + + def create_layer_dispatch_table(layer): + dispatch_table_layer[layer.__class__] = reduce_Layer + return layer + + _parse_every_object(obj, lambda v: isinstance(v, core.Layer), + create_layer_dispatch_table) + def add_dispatch_table(): # This is not a good method, because the pickle module has been modified. - pickle.dispatch_table[core.VarBase] = reudce_varbase - pickle.dispatch_table[ParamBase] = reudce_varbase + pickle.dispatch_table[core.VarBase] = reduce_varbase + pickle.dispatch_table[ParamBase] = reduce_varbase pickle.dispatch_table[core.LoDTensor] = reduce_LoDTensor + pickle.dispatch_table.update(dispatch_table_layer) def pop_dispatch_table(): pickle.dispatch_table.pop(core.VarBase) pickle.dispatch_table.pop(core.LoDTensor) pickle.dispatch_table.pop(ParamBase) + for k in dispatch_table_layer: + pickle.dispatch_table.pop(k) # When value of dict is lager than 4GB ,there is a Bug on 'MAC python3' if sys.platform == 'darwin' and sys.version_info.major == 3: @@ -273,10 +304,10 @@ def _pickle_save(obj, f, protocol): pickler = pickle.Pickler(f, protocol) pickler.dispatch_table = copyreg.dispatch_table.copy() - pickler.dispatch_table[core.VarBase] = reudce_varbase + pickler.dispatch_table[core.VarBase] = reduce_varbase pickler.dispatch_table[core.LoDTensor] = reduce_LoDTensor - pickler.dispatch_table[ParamBase] = reudce_varbase - + pickler.dispatch_table[ParamBase] = reduce_varbase + pickler.dispatch_table.update(dispatch_table_layer) pickler.dump(obj)