Unverified · Commit 34dfb0ec · authored by Baibaifan · committed by GitHub

fix_sharding_grad_clip (#40601)

Parent e52ffb70
@@ -89,7 +89,7 @@ class ShardingClipGrad:
             global_norm_fp16 = paddle.cast(
                 global_norm_fp16, dtype=paddle.float32)
 
-        # global norm of non-distributed FP16 params_and_grads for slice parameter
+        # global norm of non-distributed FP16 params_and_grads for unslice parameter
         if len(unslice_params_fp16) == 0:
             global_unslice_fp16 = paddle.to_tensor([0.], dtype=paddle.float32)
         else:
@@ -104,21 +104,20 @@ class ShardingClipGrad:
                 [0.], dtype=paddle.float32)
         global_norm_fp32 = layers.reduce_sum(global_norm_fp32)
 
-        # global norm of non-distributed FP32 params_and_grads for slice parameter
+        # global norm of non-distributed FP32 params_and_grads for unslice parameter
         global_unslice_fp32 = layers.concat(unslice_params_fp32) if len(
             unslice_params_fp32) != 0 else paddle.to_tensor(
                 [0.], dtype=paddle.float32)
         global_unslice_fp32 = layers.reduce_sum(global_unslice_fp32)
         global_unslice_var = global_unslice_fp16 + global_unslice_fp32
 
-        global_norm_var = global_norm_fp16 + global_norm_fp32
+        global_norm_var = global_norm_fp16 + global_norm_fp32 + 1.0 / self._group.nranks * global_unslice_var
 
         # add all reduce to get global norm of distributed params_and_grads
         dev_id = int(self._device.split(":")[1])
         with device_guard(dev_id, "gpu"):
             paddle.distributed.all_reduce(global_norm_var, group=self._group)
 
-        global_norm_var += global_unslice_var
         global_norm_var = layers.sqrt(global_norm_var)
         max_global_norm = layers.fill_constant(
             shape=[1], dtype=global_norm_var.dtype, value=self.clip_norm)
...
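What the change does, read from the diff: the squared norm of the non-sharded ("unslice") parameters is now folded into `global_norm_var` as `1.0 / self._group.nranks * global_unslice_var` before the `all_reduce`, rather than being added to the reduced value afterwards. Assuming every rank in the sharding group holds the same unslice parameters and gradients, scaling their contribution by `1 / nranks` before the sum-reduce leaves it counted exactly once in the global norm. A minimal plain-Python sketch of that accounting (the rank count and norm values below are made up for illustration):

```python
# Illustration only: why pre-scaling the replicated (unslice) squared norm by
# 1/nranks keeps it counted exactly once after a sum all_reduce.
nranks = 4                               # assumed size of the sharding group
sliced_sq_norms = [1.0, 2.0, 3.0, 4.0]   # per-rank squared norms of sharded grads
unslice_sq_norm = 5.0                    # squared norm of the replicated grads,
                                         # identical on every rank

# Each rank contributes its sharded norm plus unslice_sq_norm / nranks;
# the sum all_reduce then adds these per-rank values together.
reduced = sum(s + unslice_sq_norm / nranks for s in sliced_sq_norms)

# The reduced value equals "all sharded norms + the unslice norm counted once".
assert abs(reduced - (sum(sliced_sq_norms) + unslice_sq_norm)) < 1e-12
```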