未验证 提交 6da9db50 编写于 作者: zhenhailiu's avatar zhenhailiu 提交者: GitHub

fix group_shard3_get_all_parameter (#55572)

上级 db921ae9
...@@ -651,6 +651,11 @@ class GroupShardedStage3(nn.Layer): ...@@ -651,6 +651,11 @@ class GroupShardedStage3(nn.Layer):
for param in trainable_params: for param in trainable_params:
t_flow.full_param[param.name][0]._share_buffer_to(param) t_flow.full_param[param.name][0]._share_buffer_to(param)
# a _allgather_buffer call should be matched with a _release_param call later,
# but the _allgather_buffer call here has no match.
# TODO(liuzhenhai): set a flag here and release full param before forward pass of the first layer,
# when _allgather_buffer is called for get_all_parameters and convert2cpu is false
self._optim._parameter_list = self._ori_parameter_list self._optim._parameter_list = self._ori_parameter_list
self._optim._param_groups = self._ori_param_groups self._optim._param_groups = self._ori_param_groups
...@@ -924,14 +929,11 @@ class TaskFlow: ...@@ -924,14 +929,11 @@ class TaskFlow:
def __init__( def __init__(
self, self,
full_param={},
full_grad={},
use_calc={},
callback=None, callback=None,
): ):
self.full_param = full_param self.full_param = {}
self.full_grad = full_grad self.full_grad = {}
self.use_calc = use_calc self.use_calc = {}
self.callback = callback self.callback = callback
...@@ -1004,6 +1006,9 @@ def _allgather_buffer( ...@@ -1004,6 +1006,9 @@ def _allgather_buffer(
offload=False, offload=False,
convert2cpu=False, convert2cpu=False,
): ):
if convert2cpu:
assert sync_wait
for param in trainable_params: for param in trainable_params:
if param.status == "all": if param.status == "all":
param.use_count += 1 param.use_count += 1
...@@ -1020,20 +1025,22 @@ def _allgather_buffer( ...@@ -1020,20 +1025,22 @@ def _allgather_buffer(
if sync_wait: if sync_wait:
with paddle.amp.auto_cast(enable=False): with paddle.amp.auto_cast(enable=False):
task.wait() task.wait()
full_param._slice(0, param._numel())._share_buffer_to(param) if convert2cpu:
param.fw_storage._clear() # status is not changed
param.fw_storage = None cpu_full_param = _device2cpu(
param.status = "all" full_param._slice(0, param._numel())
param.use_count += 1 )
full_param._clear_data()
del full_param
full_param = cpu_full_param
task = None
else:
full_param._slice(0, param._numel())._share_buffer_to(param)
param.fw_storage._clear()
param.fw_storage = None
param.status = "all"
param.use_count += 1
task_flow.full_param[param.name] = (full_param, task) task_flow.full_param[param.name] = (full_param, task)
# parameter converts to cpu
if convert2cpu:
p_name = param.name
param = _device2cpu(param)
del task_flow.full_param[p_name]
task_flow.full_param[p_name] = (param, None)
return task_flow return task_flow
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册