From 3addd568f49d8a3f04a772a4b90ddcf0f52b76cc Mon Sep 17 00:00:00 2001
From: Haohongxiang <86215757+haohongxiang@users.noreply.github.com>
Date: Thu, 10 Nov 2022 11:02:10 +0800
Subject: [PATCH] [Dygraph] Support grad division to nranks before reduce in sharding stage2 (#47764)

---
 .../sharding/group_sharded_optimizer_stage2.py | 15 +++++++++------
 .../sharding/group_sharded_stage2.py           | 15 ++++++++++++---
 .../fleet/dygraph_group_sharded_api_eager.py   | 10 ----------
 3 files changed, 21 insertions(+), 19 deletions(-)

diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py
index 6f98f5be22..38b0322561 100644
--- a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py
+++ b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py
@@ -498,12 +498,7 @@ class GroupShardedOptimizerStage2(Optimizer):
         with device_guard(self._rank, self.offload_device):
             self.offload_grads.buffer.zero_()
 
-    def step(self):
-        """
-        A wrapper for Optimizer's step function to finish the update operation of the optimizer.
-        """
-        # This method won't be called directly by opt.step()!
-        # The _redefine_opt_step() in class GroupShardedStage2 will wrap this function.
+    def _step(self):
         if self._broadcast_overlap:
             # Clear the pre forward hook in the optimizer step.
             for hook_remove in self._forward_pre_hook_remove_helper:
@@ -536,6 +531,14 @@ class GroupShardedOptimizerStage2(Optimizer):
         # Synchronize all the updated shards in between the ranks
         self._broadcast_params()
 
+    def step(self):
+        """
+        A wrapper for Optimizer's step function to finish the update operation of the optimizer.
+        """
+        # This method won't be called directly by opt.step()!
+        # The _redefine_opt_step() in class GroupShardedStage2 will wrap this function.
+        self._step()
+
     def minimize(self):
         raise RuntimeError(
             "optimizer.minimize() not support now, please use optimizer.step()"
diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage2.py b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage2.py
index b28ba66b67..05a25223e6 100644
--- a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage2.py
+++ b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage2.py
@@ -225,13 +225,13 @@ class GroupShardedStage2(nn.Layer):
 
     def _grad_scale(self):
         """
-        Before the gradient accumulation, scale the gradient.
+        Before the optimization, scale the gradients before allreduce of dp_group.
         """
 
         if self._dp_group is None or self._dp_group.nranks <= 1:
-            scale_factor = self._world_size_scaling
+            return
         else:
-            scale_factor = 1.0 / (self._group.nranks * self._dp_group.nranks)
+            scale_factor = 1.0 / (self._dp_group.nranks)
 
         # Scale grad storages
         for dtype in self._grad_storages.keys():
@@ -366,6 +366,13 @@ class GroupShardedStage2(nn.Layer):
             ), "Only support comm overlap strategy for single optimizer"
             self._sharding_optimizers[0]._set_reduce_overlap(reduce_overlap)
 
+    def _get_scaled_grad_fn(self):
+        @paddle.autograd.no_grad()
+        def scale(grad):
+            grad.scale_(self._world_size_scaling)
+
+        return scale
+
     def _get_reduce_fn(self, index, param, dst_rank):
         """
         There are two ways to reduce gradient.
@@ -510,6 +517,8 @@ class GroupShardedStage2(nn.Layer):
             return
 
         for index, param in enumerate(self._trainable_params):
+            param._register_grad_hook(self._get_scaled_grad_fn())
+
             dst_rank = self._trainable_param2rank[param.name]
 
             reduce_function = self._get_reduce_fn(index, param, dst_rank)
diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/dygraph_group_sharded_api_eager.py b/python/paddle/fluid/tests/unittests/collective/fleet/dygraph_group_sharded_api_eager.py
index 7b3ef6236f..04c13f358b 100644
--- a/python/paddle/fluid/tests/unittests/collective/fleet/dygraph_group_sharded_api_eager.py
+++ b/python/paddle/fluid/tests/unittests/collective/fleet/dygraph_group_sharded_api_eager.py
@@ -153,16 +153,6 @@ def test_sharding_api():
         list(range(paddle.distributed.get_world_size()))
     )
 
-    stage2_dp_params = train_mlp(
-        mlp1,
-        shard_level="os_g",
-        use_multi_precision=True,
-        output_dir=output_dir,
-        amp_level='O2',
-        sync_buffers=True,
-        dp_group=dp_group,
-    )
-
     # fp16
     stage2_params = train_mlp(
         mlp1,
-- 
GitLab
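
Note on the change above: the division by the sharding group size (nranks) is moved out of the post-backward _grad_scale pass and into a per-parameter gradient hook that scales each gradient before it is reduced, so _grad_scale now only handles the extra dp_group division (or returns early when there is no dp_group). The following framework-free sketch (plain Python; illustrative names only, not Paddle APIs) shows why pre-scaling by 1/nranks before a sum-reduce gives the same result as dividing the summed gradient afterwards, up to floating-point rounding:

    # Simulate one gradient value produced on each of four ranks.
    nranks = 4
    local_grads = [1.0, 2.0, 3.0, 4.0]

    # Previous behaviour: sum-reduce first, then divide by nranks.
    reduce_then_scale = sum(local_grads) / nranks

    # New behaviour: each rank scales its own gradient in a hook before the
    # reduce; the scaled values are then sum-reduced.
    scale_then_reduce = sum(g * (1.0 / nranks) for g in local_grads)

    assert abs(reduce_then_scale - scale_then_reduce) < 1e-12
    print(reduce_then_scale, scale_then_reduce)  # prints: 2.5 2.5

One commonly cited reason to scale before the reduce (an assumption here, not stated in this patch) is that the pre-scaled partial values stay smaller, which lowers the risk of overflow when gradients are communicated in fp16.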