From 65f7fa0dbeccc5be8e6f9a6cfad422fff60659ea Mon Sep 17 00:00:00 2001
From: zhangbo9674 <82555433+zhangbo9674@users.noreply.github.com>
Date: Mon, 27 Dec 2021 19:01:25 +0800
Subject: [PATCH] Refine clip_by_global_norm (#38209)

* refine clip

* delete unused code

* refine logic for clip
---
 python/paddle/fluid/clip.py | 37 +++++++++++++++++++++----------------
 1 file changed, 21 insertions(+), 16 deletions(-)

diff --git a/python/paddle/fluid/clip.py b/python/paddle/fluid/clip.py
index aab5080236..082a72af79 100644
--- a/python/paddle/fluid/clip.py
+++ b/python/paddle/fluid/clip.py
@@ -479,29 +479,30 @@ class ClipGradByGlobalNorm(ClipGradBase):
         sum_dtype = 'float64' if len(sum_square_list) > 0 else "float32"
         global_norm_var = []
         if len(sum_square_list_fp16) > 0:
-            global_norm_var_fp16 = layers.concat(sum_square_list_fp16)
-            global_norm_var_fp16 = layers.reduce_sum(global_norm_var_fp16)
+            global_norm_var_fp16 = paddle.add_n(sum_square_list_fp16)
             global_norm_var.append(global_norm_var_fp16.astype(sum_dtype))
         if len(sum_square_list_fp32) > 0:
-            global_norm_var_fp32 = layers.concat(sum_square_list_fp32)
-            global_norm_var_fp32 = layers.reduce_sum(global_norm_var_fp32)
+            global_norm_var_fp32 = paddle.add_n(sum_square_list_fp32)
             if sum_dtype == 'float32':
                 global_norm_var.append(global_norm_var_fp32)
             else:
                 global_norm_var.append(global_norm_var_fp32.astype(sum_dtype))
         if len(sum_square_list) > 0:
-            global_norm_var_fp64 = layers.concat(sum_square_list)
-            global_norm_var_fp64 = layers.reduce_sum(global_norm_var_fp64)
+            global_norm_var_fp64 = paddle.add_n(sum_square_list)
             global_norm_var.append(global_norm_var_fp64)
-        global_norm_var = layers.concat(global_norm_var)
-        global_norm_var = layers.reduce_sum(global_norm_var)
+        global_norm_var = paddle.add_n(global_norm_var)
         global_norm_var = layers.sqrt(global_norm_var)
         max_global_norm = layers.fill_constant(
             shape=[1], dtype=global_norm_var.dtype, value=self.clip_norm)
-        clip_var = layers.elementwise_div(
-            x=max_global_norm,
-            y=layers.elementwise_max(
-                x=global_norm_var, y=max_global_norm))
+
+        # only when global_norm_var > max_global_norm, grad need clip
+        need_clip = False
+        if global_norm_var > max_global_norm:
+            need_clip = True
+
+        if need_clip:
+            clip_var = layers.elementwise_div(
+                x=max_global_norm, y=global_norm_var)
         for p, g in params_grads:
             if g is None:
                 continue
@@ -509,10 +510,14 @@ class ClipGradByGlobalNorm(ClipGradBase):
                 params_and_grads.append((p, g))
                 continue
             # TODO(wangxi): use inplace elementwise_mul
-            clip_input = (clip_var.astype('float16')
-                          if g.dtype == core.VarDesc.VarType.FP16 else clip_var)
-            new_grad = layers.elementwise_mul(x=g, y=clip_input)
-            params_and_grads.append((p, new_grad))
+            if need_clip:
+                clip_input = (clip_var.astype('float16')
+                              if g.dtype == core.VarDesc.VarType.FP16 else
+                              clip_var)
+                new_grad = layers.elementwise_mul(x=g, y=clip_input)
+                params_and_grads.append((p, new_grad))
+            else:
+                params_and_grads.append((p, g))
         return params_and_grads
--
GitLab
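
Editor's note: the patch makes two refinements to the dygraph path of ClipGradByGlobalNorm. First, it sums the per-tensor squared norms with paddle.add_n instead of concatenating them into one tensor and reducing it. Second, it skips the per-gradient elementwise_mul entirely when the global norm is already within clip_norm. Below is a minimal sketch of that logic against the public paddle API, assuming dygraph mode; the helper name clip_grads_by_global_norm and the plain-Python control flow are illustrative assumptions, not code from the patch (which operates on fluid layers and buckets fp16/fp32/fp64 gradients separately).

import paddle

def clip_grads_by_global_norm(params_grads, clip_norm):
    # Hypothetical helper illustrating the patched logic; not a Paddle API.
    # Sum of squared L2 norms of every gradient. paddle.add_n sums a list
    # of tensors directly, which is what the patch uses in place of the
    # old concat + reduce_sum pair.
    sum_squares = [
        paddle.sum(paddle.square(g)) for _, g in params_grads if g is not None
    ]
    if not sum_squares:
        return params_grads
    global_norm = paddle.sqrt(paddle.add_n(sum_squares))

    # Fast path introduced by the patch: when the global norm is within
    # the limit, return the gradients untouched instead of multiplying
    # each one by a scale factor of ~1.0.
    if float(global_norm) <= clip_norm:
        return params_grads

    scale = clip_norm / global_norm
    return [(p, g if g is None else g * scale) for p, g in params_grads]

The add_n change avoids allocating a concatenated buffer just to reduce it, and the need_clip branch saves one elementwise_mul launch per gradient on steps where no clipping is needed.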