未验证 提交 ecb8c184 编写于 作者: S ShenLiang 提交者: GitHub

[BugFix]Fix bug in obtaining parameters_buffers in layers (#38563)

* fix bug of dp in pfp16

* fix topo
上级 2fb1fc0d
......@@ -1271,11 +1271,11 @@ class Layer(object):
self._state_dict_hooks[hook_remove_helper._hook_id] = hook
return hook_remove_helper
def _state_dict_impl(self,
destination=None,
include_sublayers=True,
structured_name_prefix="",
include_non_persistable_buffer=False):
def _obtain_parameters_buffers(self,
destination=None,
include_sublayers=True,
structured_name_prefix="",
include_non_persistable_buffer=False):
"""
Get all parameters and persistable buffers of current layer and its sub-layers. And set them into a dict
......@@ -1308,7 +1308,16 @@ class Layer(object):
structured_name_prefix + layer_name + ".",
include_non_persistable_buffer))
destination = destination_temp
return destination
def _state_dict_impl(self,
destination=None,
include_sublayers=True,
structured_name_prefix="",
include_non_persistable_buffer=False):
destination = self._obtain_parameters_buffers(
destination, include_sublayers, structured_name_prefix,
include_non_persistable_buffer)
for state_dict_hook in self._state_dict_hooks.values():
hook_result = state_dict_hook(destination)
if hook_result is not None:
......
......@@ -356,8 +356,7 @@ def sync_params_buffers(model,
src_rank=0,
is_model_parallel=False):
model_vars = []
params_buffers = model.parameters() + model.buffers()
for param in params_buffers:
for _, param in model._obtain_parameters_buffers().items():
if not isinstance(param, core.VarBase):
raise TypeError("The data type of '%s' must be Varbase" %
param.name)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册