From 730dcaf48f6b1e0e561860eb503ceef9a9498b59 Mon Sep 17 00:00:00 2001
From: Haohongxiang <86215757+haohongxiang@users.noreply.github.com>
Date: Thu, 7 Oct 2021 22:06:21 +0800
Subject: [PATCH] fix bugs in HybridParallelClipGrad of hybrid_parallel_optimizer (#36237)

* fix bugs in HybridParallelClipGrad of hybrid_parallel_optimizer

* update

* update
---
 .../hybrid_parallel_optimizer.py | 38 +++++++++++++------
 1 file changed, 27 insertions(+), 11 deletions(-)

diff --git a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py
index 581fbc5153a..b00ef2cdcb0 100755
--- a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py
@@ -50,7 +50,8 @@ class HybridParallelClipGrad:
     @imperative_base.no_grad
     def _dygraph_clip(self, params_grads):
         params_and_grads = []
-        sum_square_list = []
+        sum_square_list_dist = []
+        sum_square_list_not_dist = []
         for p, g in params_grads:
             if g is None:
                 continue
@@ -62,18 +63,33 @@ class HybridParallelClipGrad:
                 merge_grad = layers.get_tensor_from_selected_rows(merge_grad)
             square = layers.square(merge_grad)
             sum_square = layers.reduce_sum(square)
-            sum_square_list.append(sum_square)
+
+            if p.is_distributed:
+                sum_square_list_dist.append(sum_square)
+            else:
+                sum_square_list_not_dist.append(sum_square)
 
         # all parameters have been filterd out
-        if len(sum_square_list) == 0:
+        if len(sum_square_list_dist) + len(sum_square_list_not_dist) == 0:
             return params_grads
 
-        global_norm_var = layers.concat(sum_square_list)
-        global_norm_var = layers.reduce_sum(global_norm_var)
-        # add all reduce to get global norm in world size
-        paddle.distributed.all_reduce(global_norm_var,
-                                      self._hcg.get_check_parallel_group())
-        global_norm_var = layers.sqrt(global_norm_var)
+        global_norm_var_dist = layers.concat(sum_square_list_dist) if len(
+            sum_square_list_dist) != 0 else layers.concat(
+                [paddle.to_tensor([0.])])
+        global_norm_var_dist = layers.reduce_sum(global_norm_var_dist)
+        global_norm_var_not_dist = layers.concat(
+            sum_square_list_not_dist) if len(
+                sum_square_list_not_dist) != 0 else layers.concat(
+                    [paddle.to_tensor([0.])])
+        global_norm_var_not_dist = layers.reduce_sum(global_norm_var_not_dist)
+
+        # add all reduce to get global norm of distributed params_and_grads in world size
+        # all reduce is not needed while getting global norm of non-distributed params_and_grads
+        paddle.distributed.all_reduce(
+            global_norm_var_dist, group=self._hcg.get_check_parallel_group())
+
+        global_norm_var = layers.sqrt(global_norm_var_dist +
+                                      global_norm_var_not_dist)
 
         max_global_norm = layers.fill_constant(
             shape=[1], dtype=global_norm_var.dtype, value=self.clip_norm)
@@ -96,7 +112,7 @@ class HybridParallelClipGrad:
         return getattr(self._clip, item)
 
     def __call__(self, params_grads):
-        return self._clip(params_grads)
+        return self._dygraph_clip(params_grads)
 
 
 class HybridParallelOptimizer:
@@ -112,7 +128,7 @@ class HybridParallelOptimizer:
         self._need_dp = (self._hcg.get_data_parallel_world_size() > 1)
 
         # NOTE(shenliang03): Because of the pure DataParallel mode, the gradient synchronization
-        # is achieved through reducer, so there is no need to call fuse_allreduce in oprimizer.
+        # is achieved through reducer, so there is no need to call fuse_allreduce in optimizer.
         self._dp_enable = not self._use_dp_mode and self._need_dp
 
         self._sharding_enable = (
-- 
GitLab
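
For readers skimming the patch, the following is a minimal, self-contained sketch (plain Python, illustrative names only, not Paddle's actual API) of the clipping scheme the change implements: squared gradient norms are bucketed into distributed and non-distributed parameters, only the distributed bucket is all-reduced across ranks, and the combined global norm is used to scale every gradient.

# Illustrative sketch only: names (clip_by_global_norm, all_reduce_sum) are
# hypothetical and the "tensors" are plain Python lists, not Paddle tensors.
import math

def clip_by_global_norm(params_grads, clip_norm, all_reduce_sum=lambda x: x):
    # params_grads: list of (is_distributed, grad) pairs, grad being a list of floats.
    # all_reduce_sum: stand-in for the cross-rank sum (paddle.distributed.all_reduce
    # in the real code); the identity function when running on a single rank.
    sum_square_dist = 0.0      # squared norms of distributed (sharded) params
    sum_square_not_dist = 0.0  # squared norms of replicated params
    for is_distributed, grad in params_grads:
        sq = sum(x * x for x in grad)
        if is_distributed:
            sum_square_dist += sq
        else:
            sum_square_not_dist += sq

    # Only the distributed bucket is reduced across ranks; each rank already
    # holds the full contribution of the replicated parameters locally.
    sum_square_dist = all_reduce_sum(sum_square_dist)

    global_norm = math.sqrt(sum_square_dist + sum_square_not_dist)
    scale = min(1.0, clip_norm / (global_norm + 1e-6))
    return [(d, [x * scale for x in g]) for d, g in params_grads]

# Single-rank example: one distributed and one replicated gradient.
print(clip_by_global_norm([(True, [3.0, 4.0]), (False, [0.0, 12.0])], clip_norm=1.0))

Splitting the two buckets this way avoids counting replicated parameters once per rank in the all-reduce, which appears to be the global-norm inflation the patch title refers to; the paddle.to_tensor([0.]) fallbacks in the real code additionally keep the concat from failing when one bucket is empty.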