diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage3.py b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage3.py
index f6b86ce736d7858895c8fa6db0f9feaf807b5dcb..b9ca53aeef0a10f818bd8915dc7b60f982a6be9f 100644
--- a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage3.py
+++ b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage3.py
@@ -651,6 +651,11 @@ class GroupShardedStage3(nn.Layer):
         for param in trainable_params:
             t_flow.full_param[param.name][0]._share_buffer_to(param)
 
+        # An _allgather_buffer call should be matched by a later _release_param call;
+        # the _allgather_buffer call here has no match.
+        # TODO(liuzhenhai): when _allgather_buffer is called for get_all_parameters and
+        # convert2cpu is false, set a flag here and release the full parameters before the forward pass of the first layer.
+
         self._optim._parameter_list = self._ori_parameter_list
         self._optim._param_groups = self._ori_param_groups
 
@@ -924,14 +929,11 @@ class TaskFlow:
 
     def __init__(
         self,
-        full_param={},
-        full_grad={},
-        use_calc={},
         callback=None,
     ):
-        self.full_param = full_param
-        self.full_grad = full_grad
-        self.use_calc = use_calc
+        self.full_param = {}
+        self.full_grad = {}
+        self.use_calc = {}
         self.callback = callback
 
@@ -1004,6 +1006,9 @@ def _allgather_buffer(
     offload=False,
     convert2cpu=False,
 ):
+    if convert2cpu:
+        assert sync_wait
+
     for param in trainable_params:
         if param.status == "all":
             param.use_count += 1
@@ -1020,20 +1025,22 @@
         if sync_wait:
             with paddle.amp.auto_cast(enable=False):
                 task.wait()
-                full_param._slice(0, param._numel())._share_buffer_to(param)
-                param.fw_storage._clear()
-                param.fw_storage = None
-                param.status = "all"
-                param.use_count += 1
+                if convert2cpu:
+                    # the param status is intentionally left unchanged here
+                    cpu_full_param = _device2cpu(
+                        full_param._slice(0, param._numel())
+                    )
+                    full_param._clear_data()
+                    del full_param
+                    full_param = cpu_full_param
+                    task = None
+                else:
+                    full_param._slice(0, param._numel())._share_buffer_to(param)
+                    param.fw_storage._clear()
+                    param.fw_storage = None
+                    param.status = "all"
+                    param.use_count += 1
         task_flow.full_param[param.name] = (full_param, task)
-
-        # parameter converts to cpu
-        if convert2cpu:
-            p_name = param.name
-            param = _device2cpu(param)
-            del task_flow.full_param[p_name]
-            task_flow.full_param[p_name] = (param, None)
-
     return task_flow
 
 
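
Note on the first hunk: the new comment records an invariant of the stage-3 sharding flow, that every `_allgather_buffer` call should eventually be paired with a `_release_param` call so the gathered full parameter can be freed again. A minimal sketch of that reference-counting contract; `gather_full_param` and `release_full_param` below are hypothetical stand-ins, not Paddle APIs:

```python
def gather_full_param(name, use_count):
    # First gather materializes the full parameter on this rank;
    # subsequent gathers only bump the reference count.
    if use_count.get(name, 0) == 0:
        print(f"allgather {name}: full param now resident")
    use_count[name] = use_count.get(name, 0) + 1

def release_full_param(name, use_count):
    # Every release must match an earlier gather; the buffer is only
    # freed when the last user lets go.
    assert use_count.get(name, 0) > 0, f"unmatched release of {name}"
    use_count[name] -= 1
    if use_count[name] == 0:
        print(f"release {name}: full param freed, shard kept")

use_count = {}
gather_full_param("w1", use_count)   # e.g. get_all_parameters()
gather_full_param("w1", use_count)   # e.g. pre-forward gather for layer 1
release_full_param("w1", use_count)  # post-forward release for layer 1
# The get_all_parameters gather has no matching release, so the full
# parameter stays resident -- exactly the imbalance the TODO describes.
assert use_count["w1"] == 1
```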
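
The `TaskFlow` hunk fixes Python's classic mutable-default-argument pitfall: a `{}` default is evaluated once, at `def` time, so every `TaskFlow()` constructed without arguments shared the same three dicts across instances. A self-contained illustration (the class names are illustrative only):

```python
class Broken:
    def __init__(self, full_param={}):  # one dict, created at def time
        self.full_param = full_param

class Fixed:
    def __init__(self):
        self.full_param = {}  # a fresh dict per instance

a, b = Broken(), Broken()
a.full_param["w1"] = "gathered"
print(b.full_param)  # {'w1': 'gathered'} -- b silently sees a's state

c, d = Fixed(), Fixed()
c.full_param["w1"] = "gathered"
print(d.full_param)  # {} -- instances are independent, as intended
```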
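
In the last hunk, the CPU conversion moves inside the `sync_wait` branch: the gathered buffer can only be copied to host after `task.wait()` completes (hence the new `assert sync_wait` guard), and the GPU allocation is released explicitly with `_clear_data()` before the name `full_param` is rebound to the CPU copy. A hedged sketch of that copy-clear-rebind sequence, assuming `_device2cpu` copies device memory to host; the `Buffer` class is a stand-in, not Paddle's tensor API:

```python
class Buffer:
    """Stand-in for a device tensor; not Paddle's API."""

    def __init__(self, data, place):
        self.data, self.place = data, place

    def to_cpu(self):
        return Buffer(list(self.data), "cpu")  # host copy of the payload

    def clear_data(self):
        self.data = None  # free the device allocation eagerly

gpu_full_param = Buffer([1.0, 2.0, 3.0], "gpu")

# 1) copy device -> host while the gathered buffer is valid
#    (this is why convert2cpu asserts sync_wait first)
cpu_full_param = gpu_full_param.to_cpu()
# 2) release device memory explicitly rather than waiting for GC
gpu_full_param.clear_data()
# 3) rebind, and store (full_param, None) since no collective is pending
full_param, task = cpu_full_param, None

assert full_param.place == "cpu" and full_param.data == [1.0, 2.0, 3.0]
```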