diff --git a/python/paddle/fluid/dygraph/layers.py b/python/paddle/fluid/dygraph/layers.py index 4a60bdc4c72d307bd94d174ec3301c397a4bfe70..4c37a378e0aaefba07cb61a8b8806019fefc04f1 100644 --- a/python/paddle/fluid/dygraph/layers.py +++ b/python/paddle/fluid/dygraph/layers.py @@ -1279,8 +1279,36 @@ class Layer(object): def _obtain_parameters_buffers(self, destination=None, include_sublayers=True, - structured_name_prefix="", - include_non_persistable_buffer=False): + structured_name_prefix=""): + """ + The difference from state_dict() is that state_dict_hook will not be called, + but the original types of parameters and buffers will be maintained. + """ + if destination is None: + destination = collections.OrderedDict() + for name, data in self._parameters.items(): + if data is not None: + destination[structured_name_prefix + name] = data + for name, buffer in self._buffers.items(): + if buffer is not None and name not in self._non_persistable_buffer_names_set: + destination[structured_name_prefix + name] = buffer + + if include_sublayers: + for layer_name, layer_item in self._sub_layers.items(): + if layer_item is not None: + destination_temp = destination.copy() + destination_temp.update( + layer_item._obtain_parameters_buffers( + destination_temp, include_sublayers, + structured_name_prefix + layer_name + ".")) + destination = destination_temp + return destination + + def _state_dict_impl(self, + destination=None, + include_sublayers=True, + structured_name_prefix="", + include_non_persistable_buffer=False): """ Get all parameters and persistable buffers of current layer and its sub-layers. And set them into a dict @@ -1313,16 +1341,6 @@ class Layer(object): structured_name_prefix + layer_name + ".", include_non_persistable_buffer)) destination = destination_temp - return destination - - def _state_dict_impl(self, - destination=None, - include_sublayers=True, - structured_name_prefix="", - include_non_persistable_buffer=False): - destination = self._obtain_parameters_buffers( - destination, include_sublayers, structured_name_prefix, - include_non_persistable_buffer) for state_dict_hook in self._state_dict_hooks.values(): hook_result = state_dict_hook(destination) if hook_result is not None: