diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py
index f9221f4bb7621a659a405ec352df17eb2d287133..4a5d6c85f855b3e6b5ec97addcd873f6c9e68d96 100644
--- a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py
+++ b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py
@@ -69,6 +69,7 @@ class GroupShardedOptimizerStage2(Optimizer):
                  offload=False,
                  device="gpu",
                  pertrain_sync_models=True,
+                 dp_group=None,
                  **kw):
 
         super().__init__(learning_rate=optim._learning_rate, parameters=params)
@@ -121,6 +122,8 @@ class GroupShardedOptimizerStage2(Optimizer):
         self._group = new_group(
             _get_global_group().ranks) if group is None else group
 
+        # Only the combination of stage2 and dp hybrid parallelism is supported for now.
+        self._dp_group = dp_group
         self.world_size = self._group.nranks
         self._rank = self._group.rank
         self._global_root_rank = self._group.ranks[0]
@@ -172,6 +175,12 @@ class GroupShardedOptimizerStage2(Optimizer):
                       group=self._group,
                       sync_op=True)
 
+            if self._dp_group:
+                broadcast(p,
+                          src=self._dp_group.ranks[0],
+                          group=self._dp_group,
+                          sync_op=True)
+
     def _update_task(self, task):
         if self._reduce_overlap:
             assert task is not None
diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage2.py b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage2.py
index a2177df7c516b76afdd970d3484b4c9fb9f7ac4c..573e0b597c8fb1ad527862986be10e87a4f74733 100644
--- a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage2.py
+++ b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage2.py
@@ -65,7 +65,8 @@ class GroupShardedStage2(nn.Layer):
                  sync_buffers=False,
                  buffer_max_size=2**23,  #8MB
                  auto_refresh_trainable=True,
-                 device="gpu"):
+                 device="gpu",
+                 dp_group=None):
         super().__init__()
 
         # training options
@@ -91,6 +92,8 @@ class GroupShardedStage2(nn.Layer):
             0]  # picking ranks index 0 as the reference
         self._default_device = device
 
+        self._dp_group = dp_group
+
         # Global statistical parameters
         self._all_params = []
         for optim in self._sharding_optimizers:
@@ -201,24 +204,29 @@ class GroupShardedStage2(nn.Layer):
         """
         Before the gradient accumulation, scale the gradient.
         """
+
+        if self._dp_group is None:
+            scale_factor = self._world_size_scaling
+        else:
+            scale_factor = 1.0 / (self._group.nranks * self._dp_group.nranks)
+
         # Scale grad storages
         for dtype in self._grad_storages.keys():
             if not self._offload and self._rank in self._grad_storages[
                     dtype].keys():
                 self._grad_storages[dtype][self._rank].buffer.scale_(
-                    scale=self._world_size_scaling)
+                    scale=scale_factor)
 
         # Scale grads of params
         with paddle.no_grad():
             for param in self._trainable_params:
                 if param.name in self._param_grads and param.grad is not None:
-                    param.grad.scale_(scale=self._world_size_scaling)
+                    param.grad.scale_(scale=scale_factor)
                 # param._reset_grad_inplace_version(True)
 
         # Scale grads of master params with offload strategy
         if self._offload:
-            self._sharding_optimizers[0]._offload_scale_grad(
-                self._world_size_scaling)
+            self._sharding_optimizers[0]._offload_scale_grad(scale_factor)
 
     def _init_internal_storage(self, needs_fresh):
         """
@@ -288,6 +296,12 @@ class GroupShardedStage2(nn.Layer):
                                  self._group,
                                  sync_op=True)
 
+            if self._dp_group:
+                collective.broadcast(buffer,
+                                     self._dp_group.ranks[0],
+                                     self._dp_group,
+                                     sync_op=True)
+
     def __getattr__(self, name):
         """Forward missing attributes to wrapped layer."""
         try:
@@ -355,6 +369,13 @@ class GroupShardedStage2(nn.Layer):
                                       group=self._group,
                                       sync_op=not self._reduce_overlap))
 
+                if self._dp_group:
+                    assert not self._comm_overlap, 'dp + stage2 hybrid parallelism only supports synchronous communication for now, due to the new communication lib.'
+                    # TODO(wuhuachao): overlap the dp + stage2 communication after the new communication lib is upgraded.
+                    collective.all_reduce(tensor=param.grad,
+                                          group=self._dp_group,
+                                          sync_op=True)
+
                 # Clear the task flow and trigger callback to clear the redundant gradient
                 # self._clear_task_flow()
 
@@ -405,6 +426,13 @@ class GroupShardedStage2(nn.Layer):
                             group=self._group,
                             sync_op=not self._reduce_overlap))
 
+                    if self._dp_group:
+                        assert not self._comm_overlap, 'dp + stage2 hybrid parallelism only supports synchronous communication for now, due to the new communication lib.'
+                        # TODO(wuhuachao): overlap the dp + stage2 communication after the new communication lib is upgraded.
+                        collective.all_reduce(tensor=grad_storage.buffer,
+                                              group=self._dp_group,
+                                              sync_op=True)
+
                     cleanup()
 
                     # Clear the task flow and trigger callback to clear the redundant gradient
diff --git a/python/paddle/distributed/sharding/group_sharded.py b/python/paddle/distributed/sharding/group_sharded.py
index 144813f5585a9fb8abc46688a312ddc072a33933..1474f639547fb6ae1cb785ba238e1e6f628c8625 100644
--- a/python/paddle/distributed/sharding/group_sharded.py
+++ b/python/paddle/distributed/sharding/group_sharded.py
@@ -45,7 +45,8 @@ def group_sharded_parallel(model,
                            sync_buffers=False,
                            buffer_max_size=2**23,
                            segment_size=2**20,
-                           sync_comm=False):
+                           sync_comm=False,
+                           dp_group=None):
     """
     Use group_sharded_parallel can perform group shared configuration on the model, optimizer and GradScaler. Level has three string options, 'os', 'os_g' and 'p_g_os' corresponds to three different usage scenarios: optimizer state segmentation, optimizer state + gradient segmentation, and parameter + gradient + optimizer state segmentation. Usually, optimizer state + gradient segmentation is actually a re optimization of optimizer state segmentation, so optimizer state + gradient segmentation can be used to realize optimizer state segmentation.
 
@@ -61,6 +62,7 @@ def group_sharded_parallel(model,
        buffer_max_size (int, optional): The max size of the buffer used to integrate gradient in `os_g`. The larger the size, the more GPU memory will be used. Defaults to 2**23, which means that the dimension of the buffer is 2**23.
        segment_size (int, optional): The smallest size of parameter to be sharded in `p_g_os`. Defaults to 2**20, indicating that the dimension of the minimum segmented parameter is 2**20.
        sync_comm (bool, optional): Whether to use synchronous communication, only in `p_g_os` used. Defaults to False, indicating that asynchronous communication is used.
+       dp_group (Group, optional): The dp communication group. Currently only the combination of stage2 and dp hybrid parallelism is supported. Defaults to None.
 
     Returns:
        model: A wrapper for group sharded given model.
@@ -123,12 +125,14 @@ def group_sharded_parallel(model,
                 params=optimizer._parameter_list,
                 optim=optimizer,
                 group=group,
-                offload=offload)
+                offload=offload,
+                dp_group=dp_group)
             model = GroupShardedStage2(model,
                                        optimizer,
                                        group=group,
                                        sync_buffers=sync_buffers,
-                                       buffer_max_size=buffer_max_size)
+                                       buffer_max_size=buffer_max_size,
+                                       dp_group=dp_group)
         else:
             optimizer = ShardingOptimizerStage2(params=model.parameters(),
                                                 optim=optimizer,
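
Not part of the patch: a minimal usage sketch of the new `dp_group` argument, assuming a hypothetical 4-GPU layout (two stage2 sharding groups of two ranks each, bridged by two data-parallel groups), a toy model, and the `os_g` level. The topology, model, and hyperparameters are illustrative only; run with something like `python -m paddle.distributed.launch --gpus 0,1,2,3 demo.py`.

import paddle
from paddle.distributed import get_rank, init_parallel_env, new_group
from paddle.distributed.sharding import group_sharded_parallel

init_parallel_env()
rank = get_rank()

# Hypothetical 2x2 topology: ranks {0,1} and {2,3} each form a stage2 sharding
# group; ranks {0,2} and {1,3} form the data-parallel groups across them.
# Every rank creates every group so that group ids stay consistent.
sharding_group, dp_group = None, None
for ranks in ([0, 1], [2, 3]):
    g = new_group(ranks)
    if rank in ranks:
        sharding_group = g
for ranks in ([0, 2], [1, 3]):
    g = new_group(ranks)
    if rank in ranks:
        dp_group = g

# Toy model / optimizer / scaler, purely for illustration.
model = paddle.nn.Sequential(paddle.nn.Linear(64, 64), paddle.nn.Linear(64, 10))
optimizer = paddle.optimizer.AdamW(learning_rate=1e-3,
                                   parameters=model.parameters())
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

# Shard optimizer states and gradients ("os_g") inside `sharding_group`, and let
# the wrappers broadcast parameters / all-reduce gradients across `dp_group`.
model, optimizer, scaler = group_sharded_parallel(model,
                                                  optimizer,
                                                  "os_g",
                                                  scaler=scaler,
                                                  group=sharding_group,
                                                  dp_group=dp_group)

# One training step.
img = paddle.randn([8, 64], dtype="float32")
label = paddle.randint(0, 10, shape=[8, 1], dtype="int64")
with paddle.amp.auto_cast():
    loss = paddle.nn.functional.cross_entropy(model(img), label)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
optimizer.clear_grad()

With this wiring, each gradient is first reduced to its owning rank inside the sharding group and then all-reduced synchronously across `dp_group`, which is why the patch rescales gradients by `1.0 / (group.nranks * dp_group.nranks)` when `dp_group` is set.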