Unverified commit 6da9db50, authored by zhenhailiu, committed by GitHub

fix group_shard3_get_all_parameter (#55572)

Parent db921ae9
@@ -651,6 +651,11 @@ class GroupShardedStage3(nn.Layer):
         for param in trainable_params:
             t_flow.full_param[param.name][0]._share_buffer_to(param)
+        # a _allgather_buffer call should be matched with a _release_param call later,
+        # but the _allgather_buffer call here has no match.
+        # TODO(liuzhenhai): set a flag here and release full param before forward pass of the first layer,
+        # when _allgather_buffer is called for get_all_parameters and convert2cpu is false
+
         self._optim._parameter_list = self._ori_parameter_list
         self._optim._param_groups = self._ori_param_groups
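The comment added above documents an invariant worth spelling out: every `_allgather_buffer` call is expected to be paired with a later `_release_param` call that drops the gathered storage once its users are done. A minimal sketch of that reference-counting contract, assuming hypothetical `gather_param`/`release_param` helpers (the `status` and `use_count` attributes mirror the parameter attributes used in this file, but the helpers themselves are illustrative, not Paddle's API):

```python
# Illustrative sketch of the gather/release pairing the TODO describes.
# gather_param/release_param are hypothetical names; status and use_count
# mirror the parameter attributes used in this diff.
def gather_param(param):
    if param.status == "all":
        # already gathered: just take another reference
        param.use_count += 1
    else:
        # ...allgather the shards into the full tensor here...
        param.status = "all"
        param.use_count = 1


def release_param(param):
    param.use_count -= 1
    if param.use_count == 0:
        # last user is done: drop the gathered storage, keep the local shard
        param.status = "part"
```

`get_all_parameters` with `convert2cpu` false takes a reference it never gives back, which is exactly the leak the TODO proposes to close by releasing the full parameters before the first layer's forward pass.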
@@ -924,14 +929,11 @@ class TaskFlow:
     def __init__(
         self,
-        full_param={},
-        full_grad={},
-        use_calc={},
         callback=None,
     ):
-        self.full_param = full_param
-        self.full_grad = full_grad
-        self.use_calc = use_calc
+        self.full_param = {}
+        self.full_grad = {}
+        self.use_calc = {}
         self.callback = callback
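The removed `full_param={}` / `full_grad={}` / `use_calc={}` defaults are the classic Python mutable-default-argument pitfall: the default dict is created once, when `__init__` is defined, and is then shared by every `TaskFlow` that relies on it. A minimal standalone demonstration (the `Holder` classes are hypothetical, for illustration only):

```python
# Minimal demonstration of the mutable-default-argument pitfall this hunk
# removes. Holder/FixedHolder are hypothetical, not Paddle code.
class Holder:
    def __init__(self, full_param={}):  # default dict created once, at def time
        self.full_param = full_param


a = Holder()
b = Holder()
a.full_param["w"] = 1
print(b.full_param)  # {'w': 1} -- a and b share the same dict


# The fixed pattern, as in the new TaskFlow.__init__: allocate a fresh
# dict per instance inside the body.
class FixedHolder:
    def __init__(self):
        self.full_param = {}


c = FixedHolder()
d = FixedHolder()
c.full_param["w"] = 1
print(d.full_param)  # {} -- each instance owns its own dict
```

Allocating the dicts inside the body, as the new `__init__` does, gives each `TaskFlow` its own state instead of one shared across all instances.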
@@ -1004,6 +1006,9 @@ def _allgather_buffer(
     offload=False,
     convert2cpu=False,
 ):
+    if convert2cpu:
+        assert sync_wait
+
     for param in trainable_params:
         if param.status == "all":
             param.use_count += 1
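The new assertion encodes a precondition rather than new behavior: copying the gathered tensor to the CPU is only valid once the allgather task has completed, which the `sync_wait=True` path guarantees via `task.wait()`. A hedged sketch of the ordering the assertion protects (`safe_convert2cpu` is a hypothetical helper; `_device2cpu` and `_clear_data` are the calls used later in this diff):

```python
# Hypothetical helper showing why convert2cpu requires sync_wait: the
# device-to-host copy must not race with the in-flight allgather.
def safe_convert2cpu(full_param, task):
    task.wait()                                # allgather must finish first
    cpu_full_param = _device2cpu(full_param)   # copy now sees the final data
    full_param._clear_data()                   # release the GPU storage
    return cpu_full_param
```

With `sync_wait=False` the copy could read a buffer the collective is still writing, so failing fast with an assert is the safer contract.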
@@ -1020,20 +1025,22 @@ def _allgather_buffer(
         if sync_wait:
             with paddle.amp.auto_cast(enable=False):
                 task.wait()
-                full_param._slice(0, param._numel())._share_buffer_to(param)
-                param.fw_storage._clear()
-                param.fw_storage = None
-                param.status = "all"
-                param.use_count += 1
+                if convert2cpu:
+                    # status is not changed
+                    cpu_full_param = _device2cpu(
+                        full_param._slice(0, param._numel())
+                    )
+                    full_param._clear_data()
+                    del full_param
+                    full_param = cpu_full_param
+                    task = None
+                else:
+                    full_param._slice(0, param._numel())._share_buffer_to(param)
+                    param.fw_storage._clear()
+                    param.fw_storage = None
+                    param.status = "all"
+                    param.use_count += 1
         task_flow.full_param[param.name] = (full_param, task)
-        # parameter converts to cpu
-        if convert2cpu:
-            p_name = param.name
-            param = _device2cpu(param)
-            del task_flow.full_param[p_name]
-            task_flow.full_param[p_name] = (param, None)
     return task_flow
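Net effect of this hunk: when `convert2cpu` is set, the gathered slice is copied to host memory inside the synchronized branch, the GPU buffer is cleared immediately, and `task_flow` ends up holding a CPU tensor with no pending task, instead of the old flow that stored the GPU buffer first and only afterwards swapped in a CPU copy of the parameter. A minimal usage sketch of what this enables, assuming `model` is a `GroupShardedStage3`-wrapped layer exposing `get_all_parameters(convert2cpu=...)` as the commit title suggests (the `is_cpu_place()` check is one way to verify placement, not part of this patch):

```python
# After the fix, gathering with convert2cpu=True should leave every full
# parameter on the host and free the corresponding GPU buffers.
model.get_all_parameters(convert2cpu=True)
for p in model.parameters():
    assert p.place.is_cpu_place()  # full parameter now lives on CPU
```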