diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage3.py b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage3.py
index ad4d53cb08254e0a41934405b3a963756a7c1053..3d3debb252d400ddf3962f064682cf1b829af131 100644
--- a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage3.py
+++ b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage3.py
@@ -85,6 +85,7 @@ class GroupShardedStage3(nn.Layer):
         pertrain_sync_models=True,
         offload=False,
         sync_comm=False,
+        dp_group=None,
     ):
         super().__init__()
 
@@ -120,6 +121,7 @@ class GroupShardedStage3(nn.Layer):
             if group is None
             else group
         )
+        self._dp_group = dp_group
         self._world_size_scaling = 1.0 / self._group.nranks
         assert (
             self._group.nranks > 1
@@ -201,6 +203,13 @@ class GroupShardedStage3(nn.Layer):
             dist.broadcast(
                 p, src=self._global_root_rank, group=self._group, sync_op=True
             )
+            if self._dp_group is not None and self._dp_group.nranks > 1:
+                dist.broadcast(
+                    p,
+                    src=self._dp_group.ranks[0],
+                    group=self._dp_group,
+                    sync_op=True,
+                )
 
     def _clear_gradients(self):
         assert len(self._trainable_params.keys()) > 0
@@ -502,6 +511,13 @@ class GroupShardedStage3(nn.Layer):
             dist.broadcast(
                 buffer, self._global_root_rank, self._group, sync_op=True
             )
+            if self._dp_group is not None and self._dp_group.nranks > 1:
+                dist.broadcast(
+                    buffer,
+                    self._dp_group.ranks[0],
+                    self._dp_group,
+                    sync_op=True,
+                )
 
     def __getattr__(self, name):
         """Forward missing attributes to wrapped layer."""
@@ -528,12 +544,7 @@ class GroupShardedStage3(nn.Layer):
             assert hasattr(
                 param, "fw_storage"
             ), "Find {} don't have fw_storage attribute".format(param.name)
-            # Gradient average
-            if self._offload:
-                with device_guard():
-                    param.bw_storage.scale_(scale=self._world_size_scaling)
-            else:
-                param.bw_storage.scale_(scale=self._world_size_scaling)
+
             param.fw_storage = _VarBaseWrapper(param)
             assert param.fw_storage.grad is None
             param.fw_storage._copy_gradient_from(param.bw_storage)
@@ -543,6 +554,12 @@ class GroupShardedStage3(nn.Layer):
         for grad_storage in self._grad_storages.values():
             grad_storage.buffer.scale_(scale=self._world_size_scaling)
             dist.all_reduce(tensor=grad_storage.buffer, group=self._group)
+            if self._dp_group is not None and self._dp_group.nranks > 1:
+                grad_storage.buffer.scale_(scale=(1.0 / self._dp_group.nranks))
+                dist.all_reduce(
+                    tensor=grad_storage.buffer, group=self._dp_group
+                )
+
         if self._offload:
             for param in list(self._unslice_params):
                 param._clear_data()
@@ -609,7 +626,11 @@ class GroupShardedStage3(nn.Layer):
             if param.name in self._task_flow.full_grad.keys():
                 full_grad = self._task_flow.full_grad[param.name]
                 # Only support sync allreduce current rank's layer now
+                full_grad.scale_(scale=self._world_size_scaling)
                 dist.all_reduce(tensor=full_grad, group=self._group)
+                if self._dp_group is not None and self._dp_group.nranks > 1:
+                    full_grad.scale_(scale=1.0 / self._dp_group.nranks)
+                    dist.all_reduce(tensor=full_grad, group=self._dp_group)
 
                 start, end = self._param2buffer[param.name][self._rank]
                 if param.bw_storage is None:
diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_utils.py b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_utils.py
index d845f3b78c6345e3de360f2a68bd1af4e40e0269..620540fea58761f8930b33bd8d65f6bafc7ff369 100644
--- a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_utils.py
+++ b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_utils.py
@@ -119,11 +119,7 @@ class GroupShardedClipGrad:
         global_unslice_fp32 = paddle.sum(global_unslice_fp32)
         global_unslice_var = global_unslice_fp16 + global_unslice_fp32
 
-        global_norm_var = (
-            global_norm_fp16
-            + global_norm_fp32
-            + 1.0 / self._group.nranks * global_unslice_var
-        )
+        global_norm_var = global_norm_fp16 + global_norm_fp32
 
         # add all reduce to get global norm of distributed params_and_grads
         dev_id = int(self._device.split(":")[1])
@@ -133,7 +129,7 @@ class GroupShardedClipGrad:
         with device_guard(dev_id, self._device.split(":")[0]):
             paddle.distributed.all_reduce(global_norm_var, group=self._group)
 
-        global_norm_var = paddle.sqrt(global_norm_var)
+        global_norm_var = paddle.sqrt(global_norm_var + global_unslice_var)
         max_global_norm = layers.fill_constant(
             shape=[1], dtype=global_norm_var.dtype, value=self.clip_norm
         )
@@ -150,9 +146,9 @@ class GroupShardedClipGrad:
             origin_state = g.stop_gradient
             g.stop_gradient = True
             if p.dtype == paddle.float16:
-                g.scale_(clip_var_fp16.item())
+                g.scale_(clip_var_fp16)
             else:
-                g.scale_(clip_var.item())
+                g.scale_(clip_var)
             g.stop_gradient = origin_state
             # p._reset_grad_inplace_version(True)
 
diff --git a/python/paddle/distributed/sharding/group_sharded.py b/python/paddle/distributed/sharding/group_sharded.py
index 9951594c2f5cb595756538cb069638bf4efc35d8..703e14f4a3adb146682e61a43f6c0bdad8d85fe0 100644
--- a/python/paddle/distributed/sharding/group_sharded.py
+++ b/python/paddle/distributed/sharding/group_sharded.py
@@ -79,7 +79,7 @@ def group_sharded_parallel(
        buffer_max_size (int, optional): The max size of the buffer used to integrate gradient in `os_g`. The larger the size, the more GPU memory will be used. Defaults to 2**23, which means that the dimension of the buffer is 2**23.
        segment_size (int, optional): The smallest size of parameter to be sharded in `p_g_os`. Defaults to 2**20, indicating that the dimension of the minimum segmented parameter is 2**20.
        sync_comm (bool, optional): Whether to use synchronous communication, only in `p_g_os` used. Defaults to False, indicating that asynchronous communication is used.
-        dp_group(Group, optional): dp communication group, only support to combine stage2 and dp hybrid communication now.
+        dp_group(Group, optional): dp communication group, support to combine stage2 or stage3 with dp hybrid communication.
 
     Returns:
         model: A wrapper for group sharded given model.
@@ -192,6 +192,7 @@ def group_sharded_parallel(
             segment_size=segment_size,
             offload=offload,
             sync_comm=sync_comm,
+            dp_group=dp_group,
             device=device,
         )
     else:
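
For reference, below is a minimal usage sketch of the new dp_group argument combined with stage3 ("p_g_os"). It is not part of the patch: the 2-way dp x 4-way sharding layout, the placeholder model, and the use of fleet's hybrid communicate group to obtain the two process groups are illustrative assumptions; only group_sharded_parallel and its dp_group parameter come from this diff.

# Hedged sketch (not part of this patch): sharding stage3 plus an outer
# data-parallel group via the new `dp_group` argument. Launch with
# `python -m paddle.distributed.launch` on 8 GPUs for the assumed layout.
import paddle
from paddle.distributed import fleet
from paddle.distributed.sharding import group_sharded_parallel

strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {
    "dp_degree": 2,        # assumed: 2 replicas of the sharded group
    "mp_degree": 1,
    "pp_degree": 1,
    "sharding_degree": 4,  # assumed: parameters sharded across 4 ranks
}
fleet.init(is_collective=True, strategy=strategy)
hcg = fleet.get_hybrid_communicate_group()

model = paddle.nn.Linear(1024, 1024)  # placeholder model
optimizer = paddle.optimizer.AdamW(parameters=model.parameters())

# Shard within the sharding group; with this patch, parameters, buffers and
# gradients are additionally broadcast/all-reduced across `dp_group`.
model, optimizer, scaler = group_sharded_parallel(
    model,
    optimizer,
    level="p_g_os",
    group=hcg.get_sharding_parallel_group(),
    dp_group=hcg.get_data_parallel_group(),
)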