diff --git a/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py b/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py
index 8ae20c91bb4155ea73765b533ebad72357996dfb..f1756bc02055c20d43ea053cdef3169f35f789c0 100644
--- a/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py
+++ b/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py
@@ -18,10 +18,21 @@
 import numpy as np
 
 import paddle
 from paddle import _legacy_C_ops
-from paddle.distributed.parallel import _split_tensors
+from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_storage import (
+    GradStorage,
+)
 from paddle.fluid import core
 from paddle.framework import base as imperative_base
 
+alignment = {
+    "gpu": 256,
+}
+align = {
+    paddle.float16.value: 2,
+    paddle.bfloat16.value: 2,
+    paddle.float32.value: 4,
+}
+
 __all__ = []
 
@@ -120,26 +131,47 @@ def _all_gather(tensor, group=None, use_calc_stream=True):
     )
 
 
+def flatten_dense_tensors(parameters, use_main_grad=False):
+    _buffer_size = 0
+    _param2align = {}
+    dtype = paddle.float32 if use_main_grad else parameters[0].dtype
+
+    for param in parameters:
+        assert param.trainable, "param must be trainable..."
+        size = np.prod(param.shape) * align[dtype]
+        remaining = size % alignment["gpu"]
+        ali = 0 if remaining == 0 else alignment["gpu"] - remaining
+        align_ = ali // align[dtype]
+        _buffer_size += np.prod(param.shape) + align_
+        _param2align[param.name] = align_
+
+    # process gradient
+    grad_storage = GradStorage(
+        size=_buffer_size,
+        dtype=dtype,
+        device="gpu",
+        destination="0",
+        parm2align=_param2align,
+    )
+
+    for param in parameters:
+        grad_storage.add_grad(param, _param2align[param.name])
+
+    return grad_storage.buffer
+
+
 class FusedCommBuffer:
-    def __init__(
-        self,
-        id,
-        params,
-        comm_group,
-        acc_steps=1,
-        act=None,
-        dst=-1,
-    ):
+    def __init__(self, id, params, comm_group, acc_steps=1, act=None, dst=-1):
         self._id = id
         self._params = params
         self._acc_steps = acc_steps
         self._comm_group = comm_group
 
-        self._tasks = []
-        self._grads = []
+        use_main_grad = hasattr(self._params[0], "main_grad")
+
+        self._task = None
         self._params_step_dict = {}
         self._params_checked_in = 0
-        self._coalesced_grads_and_grad_vars = []
 
         self._act = act
         if self._act == HOOK_ACTION.ALL_REDUCE:
@@ -154,16 +186,16 @@ class FusedCommBuffer:
 
         self._init_step_dict()
 
+        self.grad_storage = flatten_dense_tensors(self._params, use_main_grad)
+
     def _init_step_dict(self):
         for p in self._params:
             self._params_step_dict[p.name] = 0
 
     def _reset_params_checked_in(self):
-        self._tasks.clear()
-        self._grads.clear()
+        self._task = None
         self._init_step_dict()
         self._params_checked_in = 0
-        self._coalesced_grads_and_grad_vars.clear()
 
     @property
     def _all_params_checked_in(self):
@@ -175,13 +207,6 @@
     def add_grad(self, param):
         assert param.name in self._params_step_dict
 
-        if self._params_step_dict[param.name] == 0:
-            if getattr(param, "main_grad", None) is not None:
-                assert param.grad is None
-                self._grads.append(param.main_grad)
-            else:
-                self._grads.append(param.grad)
-
         self._params_step_dict[param.name] += 1
 
         if self._params_step_dict[param.name] == self._acc_steps:
@@ -189,49 +214,33 @@
             self._params_checked_in += 1
             self._params_step_dict.pop(param.name)
 
         if self._all_params_checked_in:
-            self._fused_comm_grads()
+            self._comm_grads()
 
     @imperative_base.no_grad
-    def _fused_comm_grads(self):
+    def _comm_grads(self):
         assert self._all_params_checked_in
-        flattened_vars = []
-        g_var_shapes = []
 
-        for g_var in self._grads:
-            g_var_shapes.append(g_var.shape)
-            flattened_vars.append(
-                paddle.reshape(x=g_var, shape=[np.prod(g_var.shape)])
+        if self._act == HOOK_ACTION.ALL_REDUCE:
+            task = paddle.distributed.all_reduce(
+                self.grad_storage, group=self._comm_group, sync_op=False
             )
-
-        coalesced_grad = paddle.concat(flattened_vars)
-        self._coalesced_grads_and_grad_vars.append(
-            [coalesced_grad, self._grads, g_var_shapes]
-        )
-
-        for coalesced_grad, _, _ in self._coalesced_grads_and_grad_vars:
-            if self._act == HOOK_ACTION.ALL_REDUCE:
-                task = paddle.distributed.all_reduce(
-                    coalesced_grad, group=self._comm_group, sync_op=False
-                )
-            elif self._act == HOOK_ACTION.REDUCE:
-                task = paddle.distributed.reduce(
-                    coalesced_grad,
-                    dst=self._dst,
-                    group=self._comm_group,
-                    sync_op=False,
-                )
-            self._tasks.append(task)
+        elif self._act == HOOK_ACTION.REDUCE:
+            task = paddle.distributed.reduce(
+                self.grad_storage,
+                dst=self._dst,
+                group=self._comm_group,
+                sync_op=False,
+            )
+        self._task = task
 
     @imperative_base.no_grad
     def scale_and_split_grads(self):
-        for task in self._tasks:
-            task.wait()
+        assert self._task is not None
+        self._task.wait()
 
         scale_factor = 1.0 / self._comm_group.nranks
-        for coalesced_grad, _, _ in self._coalesced_grads_and_grad_vars:
-            coalesced_grad.scale_(scale_factor)
+        self.grad_storage.scale_(scale_factor)
 
-        _split_tensors(self._coalesced_grads_and_grad_vars)
         self._reset_params_checked_in()
 
diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_storage.py b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_storage.py
index 73e1b9a9781a5be50c0f2fd84c14108242411352..44c5995acc7963aad7d63deb1c1a868c37d75681 100644
--- a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_storage.py
+++ b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_storage.py
@@ -315,7 +315,12 @@ class GradStorage(InternalStorage):
         assert (
             param._numel() > 0
         ), "Cannot add a gradient to a released InternalStorage, please rebuild"
-        assert param.dtype == self.buffer.dtype
+
+        use_main_grad = hasattr(param, "main_grad")
+        if use_main_grad:
+            assert self.buffer.dtype == paddle.float32
+        else:
+            assert param.dtype == self.buffer.dtype
 
         grad_end = self._fill + param._numel()
         offset = grad_end + align
@@ -325,7 +330,10 @@ class GradStorage(InternalStorage):
         with device_guard(self.dev_id, self._device):
             tmp_var = self.buffer._slice(self._fill, grad_end)
             tmp_var.get_tensor()._set_dims(param.shape)
-            param._copy_gradient_from(tmp_var)
+            if not use_main_grad:
+                param._copy_gradient_from(tmp_var)
+            else:
+                param.main_grad = tmp_var
             del tmp_var
 
         self._fill = offset
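Note on the new `flatten_dense_tensors` helper above: it sizes one fused gradient buffer by padding each parameter's slot so the next slot starts on a 256-byte boundary (`alignment["gpu"]`), converting the byte padding back into element counts with the per-dtype sizes in `align`. Below is a minimal standalone sketch of that padding arithmetic only; the helper name, the plain-string dtype keys, and the example shapes are hypothetical, and only the 256-byte alignment and the 2/4-byte element sizes come from the diff.

```python
# Standalone sketch of the slot-padding arithmetic used by flatten_dense_tensors.
# Hypothetical helper; dtype keys and example shapes are illustrative only.
import numpy as np

GPU_ALIGNMENT = 256  # bytes, mirrors alignment["gpu"] in the diff
BYTES_PER_ELEM = {"float16": 2, "bfloat16": 2, "float32": 4}  # mirrors align


def padded_buffer_numel(shapes, dtype="float32"):
    """Return the fused buffer's total element count and the per-slot padding."""
    elem_size = BYTES_PER_ELEM[dtype]
    total, paddings = 0, []
    for shape in shapes:
        numel = int(np.prod(shape))
        size_bytes = numel * elem_size
        remaining = size_bytes % GPU_ALIGNMENT
        pad_bytes = 0 if remaining == 0 else GPU_ALIGNMENT - remaining
        pad_elems = pad_bytes // elem_size  # plays the role of align_ in the diff
        total += numel + pad_elems
        paddings.append(pad_elems)
    return total, paddings


if __name__ == "__main__":
    total, paddings = padded_buffer_numel([(1000, 3), (7,)], dtype="float32")
    print(total, paddings)  # 3072 [8, 57]
```

With every slot padded this way, GradStorage can hand each parameter a `_slice` view whose start offset stays 256-byte aligned, and the communication path can issue a single `all_reduce`/`reduce` over the whole `grad_storage` buffer instead of concatenating gradients and re-splitting them with `_split_tensors` afterwards.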