From 34dfb0ec4401806c84a1f336b9ebb484e2dbe68a Mon Sep 17 00:00:00 2001
From: Baibaifan <39549453+Baibaifan@users.noreply.github.com>
Date: Fri, 18 Mar 2022 19:54:00 +0800
Subject: [PATCH] fix_sharding_grad_clip (#40601)

---
 .../fleet/meta_parallel/sharding/sharding_utils.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_utils.py b/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_utils.py
index 89b59254e5..6a30276e02 100644
--- a/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_utils.py
+++ b/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_utils.py
@@ -89,7 +89,7 @@ class ShardingClipGrad:
             global_norm_fp16 = paddle.cast(
                 global_norm_fp16, dtype=paddle.float32)
 
-        # global norm of non-distributed FP16 params_and_grads for slice parameter
+        # global norm of non-distributed FP16 params_and_grads for unslice parameter
         if len(unslice_params_fp16) == 0:
             global_unslice_fp16 = paddle.to_tensor([0.], dtype=paddle.float32)
         else:
@@ -104,21 +104,20 @@ class ShardingClipGrad:
                 [0.], dtype=paddle.float32)
         global_norm_fp32 = layers.reduce_sum(global_norm_fp32)
 
-        # global norm of non-distributed FP32 params_and_grads for slice parameter
+        # global norm of non-distributed FP32 params_and_grads for unslice parameter
         global_unslice_fp32 = layers.concat(unslice_params_fp32) if len(
             unslice_params_fp32) != 0 else paddle.to_tensor(
                 [0.], dtype=paddle.float32)
         global_unslice_fp32 = layers.reduce_sum(global_unslice_fp32)
         global_unslice_var = global_unslice_fp16 + global_unslice_fp32
 
-        global_norm_var = global_norm_fp16 + global_norm_fp32
+        global_norm_var = global_norm_fp16 + global_norm_fp32 + 1.0 / self._group.nranks * global_unslice_var
 
         # add all reduce to get global norm of distributed params_and_grads
         dev_id = int(self._device.split(":")[1])
         with device_guard(dev_id, "gpu"):
             paddle.distributed.all_reduce(global_norm_var, group=self._group)
 
-        global_norm_var += global_unslice_var
         global_norm_var = layers.sqrt(global_norm_var)
         max_global_norm = layers.fill_constant(
             shape=[1], dtype=global_norm_var.dtype, value=self.clip_norm)
--
GitLab
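
A minimal sketch of the aggregation this patch changes, under hypothetical
per-rank squared-norm values, with the all_reduce simulated as a plain sum
(numpy stands in for paddle here; nothing below is part of the diff):

import numpy as np

nranks = 4
# hypothetical squared-norm contributions per rank
sliced_sq = np.array([1.0, 2.0, 3.0, 4.0])    # sharded params: each rank holds a different slice
unsliced_sq = np.full(nranks, 5.0)            # unsliced params: replicated, same value on every rank

# before the patch: all_reduce only the sliced part, then add the local
# unsliced norm on each rank after the all_reduce
old_global_sq = sliced_sq.sum() + unsliced_sq[0]

# after the patch: fold 1/nranks of the unsliced norm into every rank's
# contribution, so the single all_reduce counts it exactly once
new_global_sq = (sliced_sq + unsliced_sq / nranks).sum()

print(np.sqrt(old_global_sq), np.sqrt(new_global_sq))  # identical when unsliced_sq agrees across ranks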