Unverified commit 34dfb0ec, authored by Baibaifan, committed by GitHub

fix_sharding_grad_clip (#40601)

Parent e52ffb70
@@ -89,7 +89,7 @@ class ShardingClipGrad:
             global_norm_fp16 = paddle.cast(
                 global_norm_fp16, dtype=paddle.float32)
 
-        # global norm of non-distributed FP16 params_and_grads for slice parameter
+        # global norm of non-distributed FP16 params_and_grads for unslice parameter
         if len(unslice_params_fp16) == 0:
             global_unslice_fp16 = paddle.to_tensor([0.], dtype=paddle.float32)
         else:
@@ -104,21 +104,20 @@ class ShardingClipGrad:
                 [0.], dtype=paddle.float32)
         global_norm_fp32 = layers.reduce_sum(global_norm_fp32)
 
-        # global norm of non-distributed FP32 params_and_grads for slice parameter
+        # global norm of non-distributed FP32 params_and_grads for unslice parameter
         global_unslice_fp32 = layers.concat(unslice_params_fp32) if len(
             unslice_params_fp32) != 0 else paddle.to_tensor(
                 [0.], dtype=paddle.float32)
         global_unslice_fp32 = layers.reduce_sum(global_unslice_fp32)
         global_unslice_var = global_unslice_fp16 + global_unslice_fp32
 
-        global_norm_var = global_norm_fp16 + global_norm_fp32
+        global_norm_var = global_norm_fp16 + global_norm_fp32 + 1.0 / self._group.nranks * global_unslice_var
 
         # add all reduce to get global norm of distributed params_and_grads
         dev_id = int(self._device.split(":")[1])
         with device_guard(dev_id, "gpu"):
             paddle.distributed.all_reduce(global_norm_var, group=self._group)
 
-        global_norm_var += global_unslice_var
         global_norm_var = layers.sqrt(global_norm_var)
         max_global_norm = layers.fill_constant(
             shape=[1], dtype=global_norm_var.dtype, value=self.clip_norm)
......
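What the change computes, reading the diff above: before this commit each rank all-reduced only the squared norms of its distributed (sliced) gradients and then added its local unsliced contribution afterwards, giving sum_r(dist_r) + unslice_local on each rank; after this commit each rank folds 1.0 / nranks of the unsliced contribution into the value that is all-reduced, giving sum_r(dist_r) + (1/nranks) * sum_r(unslice_r) on every rank. The sketch below is a minimal plain-Python illustration of that arithmetic, not Paddle code; the per-rank values dist_sq and unslice_sq are hypothetical placeholders.

import math

# Hypothetical per-rank squared norms (placeholders, not taken from Paddle):
# dist_sq[r]    - squared norm of the distributed (sliced) grads on rank r
# unslice_sq[r] - squared norm of the non-distributed (unsliced) grads on rank r
nranks = 4
dist_sq = [1.0, 2.0, 3.0, 4.0]
unslice_sq = [9.0, 9.5, 8.5, 9.0]

# Old path: all_reduce sums only the distributed part; each rank then adds its
# own unsliced part, so ranks can disagree when unslice_sq differs across ranks.
old_norm = [math.sqrt(sum(dist_sq) + unslice_sq[r]) for r in range(nranks)]

# New path (this commit): 1/nranks of the unsliced part is folded into the value
# being all-reduced, so the sum over ranks contributes its average exactly once
# and every rank ends up with the same global norm.
reduced = sum(dist_sq[r] + unslice_sq[r] / nranks for r in range(nranks))
new_norm = [math.sqrt(reduced)] * nranks

print(old_norm)  # may differ from rank to rank
print(new_norm)  # identical on every rank

If the unsliced gradients were already identical on every rank, the two formulas would agree; the rewritten form keeps the clip coefficient consistent across ranks even when they are not.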