diff --git a/python/paddle/fluid/dygraph/layers.py b/python/paddle/fluid/dygraph/layers.py index caa30df5c1ffd4e6ce2043b7952db2194b3d5c76..1fd408c465efe7abaa0b339e4362e87e84388f29 100644 --- a/python/paddle/fluid/dygraph/layers.py +++ b/python/paddle/fluid/dygraph/layers.py @@ -1271,11 +1271,11 @@ class Layer(object): self._state_dict_hooks[hook_remove_helper._hook_id] = hook return hook_remove_helper - def _state_dict_impl(self, - destination=None, - include_sublayers=True, - structured_name_prefix="", - include_non_persistable_buffer=False): + def _obtain_parameters_buffers(self, + destination=None, + include_sublayers=True, + structured_name_prefix="", + include_non_persistable_buffer=False): """ Get all parameters and persistable buffers of current layer and its sub-layers. And set them into a dict @@ -1308,7 +1308,16 @@ class Layer(object): structured_name_prefix + layer_name + ".", include_non_persistable_buffer)) destination = destination_temp + return destination + def _state_dict_impl(self, + destination=None, + include_sublayers=True, + structured_name_prefix="", + include_non_persistable_buffer=False): + destination = self._obtain_parameters_buffers( + destination, include_sublayers, structured_name_prefix, + include_non_persistable_buffer) for state_dict_hook in self._state_dict_hooks.values(): hook_result = state_dict_hook(destination) if hook_result is not None: diff --git a/python/paddle/fluid/dygraph/parallel.py b/python/paddle/fluid/dygraph/parallel.py index 809a4d385ff9812bd0a913cdfbfd86822a68479d..ddb86848f842a85acc12dca1044a594c484c06fe 100644 --- a/python/paddle/fluid/dygraph/parallel.py +++ b/python/paddle/fluid/dygraph/parallel.py @@ -356,8 +356,7 @@ def sync_params_buffers(model, src_rank=0, is_model_parallel=False): model_vars = [] - params_buffers = model.parameters() + model.buffers() - for param in params_buffers: + for _, param in model._obtain_parameters_buffers().items(): if not isinstance(param, core.VarBase): raise TypeError("The data type of '%s' must be Varbase" % param.name)