未验证 提交 7d4ce5b3 编写于 作者: S ShenLiang 提交者: GitHub

fix bug of fp16 (#38838)

上级 3a23c1a2
...@@ -1279,8 +1279,36 @@ class Layer(object): ...@@ -1279,8 +1279,36 @@ class Layer(object):
def _obtain_parameters_buffers(self, def _obtain_parameters_buffers(self,
destination=None, destination=None,
include_sublayers=True, include_sublayers=True,
structured_name_prefix="", structured_name_prefix=""):
include_non_persistable_buffer=False): """
The difference from state_dict() is that state_dict_hook will not be called,
but the original types of parameters and buffers will be maintained.
"""
if destination is None:
destination = collections.OrderedDict()
for name, data in self._parameters.items():
if data is not None:
destination[structured_name_prefix + name] = data
for name, buffer in self._buffers.items():
if buffer is not None and name not in self._non_persistable_buffer_names_set:
destination[structured_name_prefix + name] = buffer
if include_sublayers:
for layer_name, layer_item in self._sub_layers.items():
if layer_item is not None:
destination_temp = destination.copy()
destination_temp.update(
layer_item._obtain_parameters_buffers(
destination_temp, include_sublayers,
structured_name_prefix + layer_name + "."))
destination = destination_temp
return destination
def _state_dict_impl(self,
destination=None,
include_sublayers=True,
structured_name_prefix="",
include_non_persistable_buffer=False):
""" """
Get all parameters and persistable buffers of current layer and its sub-layers. And set them into a dict Get all parameters and persistable buffers of current layer and its sub-layers. And set them into a dict
...@@ -1313,16 +1341,6 @@ class Layer(object): ...@@ -1313,16 +1341,6 @@ class Layer(object):
structured_name_prefix + layer_name + ".", structured_name_prefix + layer_name + ".",
include_non_persistable_buffer)) include_non_persistable_buffer))
destination = destination_temp destination = destination_temp
return destination
def _state_dict_impl(self,
destination=None,
include_sublayers=True,
structured_name_prefix="",
include_non_persistable_buffer=False):
destination = self._obtain_parameters_buffers(
destination, include_sublayers, structured_name_prefix,
include_non_persistable_buffer)
for state_dict_hook in self._state_dict_hooks.values(): for state_dict_hook in self._state_dict_hooks.values():
hook_result = state_dict_hook(destination) hook_result = state_dict_hook(destination)
if hook_result is not None: if hook_result is not None:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册