From ecb8c1847e9e0226644a8f05310a1de206ab32dd Mon Sep 17 00:00:00 2001 From: ShenLiang <1422485404@qq.com> Date: Wed, 29 Dec 2021 18:54:27 +0800 Subject: [PATCH] [BugFix]Fix bug in obtaining parameters_buffers in layers (#38563) * fix bug of dp in pfp16 * fix topo --- python/paddle/fluid/dygraph/layers.py | 19 ++++++++++++++----- python/paddle/fluid/dygraph/parallel.py | 3 +-- 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/python/paddle/fluid/dygraph/layers.py b/python/paddle/fluid/dygraph/layers.py index caa30df5c1..1fd408c465 100644 --- a/python/paddle/fluid/dygraph/layers.py +++ b/python/paddle/fluid/dygraph/layers.py @@ -1271,11 +1271,11 @@ class Layer(object): self._state_dict_hooks[hook_remove_helper._hook_id] = hook return hook_remove_helper - def _state_dict_impl(self, - destination=None, - include_sublayers=True, - structured_name_prefix="", - include_non_persistable_buffer=False): + def _obtain_parameters_buffers(self, + destination=None, + include_sublayers=True, + structured_name_prefix="", + include_non_persistable_buffer=False): """ Get all parameters and persistable buffers of current layer and its sub-layers. And set them into a dict @@ -1308,7 +1308,16 @@ class Layer(object): structured_name_prefix + layer_name + ".", include_non_persistable_buffer)) destination = destination_temp + return destination + def _state_dict_impl(self, + destination=None, + include_sublayers=True, + structured_name_prefix="", + include_non_persistable_buffer=False): + destination = self._obtain_parameters_buffers( + destination, include_sublayers, structured_name_prefix, + include_non_persistable_buffer) for state_dict_hook in self._state_dict_hooks.values(): hook_result = state_dict_hook(destination) if hook_result is not None: diff --git a/python/paddle/fluid/dygraph/parallel.py b/python/paddle/fluid/dygraph/parallel.py index 809a4d385f..ddb86848f8 100644 --- a/python/paddle/fluid/dygraph/parallel.py +++ b/python/paddle/fluid/dygraph/parallel.py @@ -356,8 +356,7 @@ def sync_params_buffers(model, src_rank=0, is_model_parallel=False): model_vars = [] - params_buffers = model.parameters() + model.buffers() - for param in params_buffers: + for _, param in model._obtain_parameters_buffers().items(): if not isinstance(param, core.VarBase): raise TypeError("The data type of '%s' must be Varbase" % param.name) -- GitLab