Unverified commit 65f7fa0d, authored by zhangbo9674, committed by GitHub

Refine clip_by_global_norm (#38209)

* refine clip

* delete unused code

* refine logic for clip
Parent: e8e47581
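In ClipGradByGlobalNorm, each gradient g_i is scaled by clip_norm / max(global_norm, clip_norm), where global_norm = sqrt(sum_i ||g_i||^2). This commit changes two things in the dygraph path: the per-dtype squared norms are accumulated with a single paddle.add_n instead of concat followed by reduce_sum, and the clipping coefficient is computed and applied only when global_norm actually exceeds clip_norm, so unclipped gradients pass through untouched. Below is a minimal standalone sketch of the accumulation change, with made-up tensor values (only the public paddle.add_n, paddle.concat, paddle.sum and paddle.sqrt APIs are assumed):

import paddle

# hypothetical per-parameter squared L2 norms, each a one-element tensor
sum_square_list = [paddle.to_tensor([4.0]),
                   paddle.to_tensor([9.0]),
                   paddle.to_tensor([12.0])]

# old path: concatenate the one-element tensors, then reduce over the result
global_norm_sq_old = paddle.sum(paddle.concat(sum_square_list))

# new path in this commit: accumulate the tensors directly with add_n
global_norm_sq_new = paddle.add_n(sum_square_list)

# both give the same squared global norm; sqrt(4 + 9 + 12) = 5.0
global_norm = paddle.sqrt(global_norm_sq_new)
print(float(global_norm_sq_old), float(global_norm_sq_new), float(global_norm))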
@@ -479,29 +479,30 @@ class ClipGradByGlobalNorm(ClipGradBase):
         sum_dtype = 'float64' if len(sum_square_list) > 0 else "float32"
         global_norm_var = []
         if len(sum_square_list_fp16) > 0:
-            global_norm_var_fp16 = layers.concat(sum_square_list_fp16)
-            global_norm_var_fp16 = layers.reduce_sum(global_norm_var_fp16)
+            global_norm_var_fp16 = paddle.add_n(sum_square_list_fp16)
             global_norm_var.append(global_norm_var_fp16.astype(sum_dtype))
         if len(sum_square_list_fp32) > 0:
-            global_norm_var_fp32 = layers.concat(sum_square_list_fp32)
-            global_norm_var_fp32 = layers.reduce_sum(global_norm_var_fp32)
+            global_norm_var_fp32 = paddle.add_n(sum_square_list_fp32)
             if sum_dtype == 'float32':
                 global_norm_var.append(global_norm_var_fp32)
             else:
                 global_norm_var.append(global_norm_var_fp32.astype(sum_dtype))
         if len(sum_square_list) > 0:
-            global_norm_var_fp64 = layers.concat(sum_square_list)
-            global_norm_var_fp64 = layers.reduce_sum(global_norm_var_fp64)
+            global_norm_var_fp64 = paddle.add_n(sum_square_list)
             global_norm_var.append(global_norm_var_fp64)
-        global_norm_var = layers.concat(global_norm_var)
-        global_norm_var = layers.reduce_sum(global_norm_var)
+        global_norm_var = paddle.add_n(global_norm_var)
         global_norm_var = layers.sqrt(global_norm_var)
         max_global_norm = layers.fill_constant(
             shape=[1], dtype=global_norm_var.dtype, value=self.clip_norm)
-        clip_var = layers.elementwise_div(
-            x=max_global_norm,
-            y=layers.elementwise_max(
-                x=global_norm_var, y=max_global_norm))
+        # only when global_norm_var > max_global_norm, grad need clip
+        need_clip = False
+        if global_norm_var > max_global_norm:
+            need_clip = True
+        if need_clip:
+            clip_var = layers.elementwise_div(
+                x=max_global_norm, y=global_norm_var)
         for p, g in params_grads:
             if g is None:
                 continue
@@ -509,10 +510,14 @@ class ClipGradByGlobalNorm(ClipGradBase):
                 params_and_grads.append((p, g))
                 continue
             # TODO(wangxi): use inplace elementwise_mul
-            clip_input = (clip_var.astype('float16')
-                          if g.dtype == core.VarDesc.VarType.FP16 else clip_var)
-            new_grad = layers.elementwise_mul(x=g, y=clip_input)
-            params_and_grads.append((p, new_grad))
+            if need_clip:
+                clip_input = (clip_var.astype('float16')
+                              if g.dtype == core.VarDesc.VarType.FP16 else
+                              clip_var)
+                new_grad = layers.elementwise_mul(x=g, y=clip_input)
+                params_and_grads.append((p, new_grad))
+            else:
+                params_and_grads.append((p, g))
         return params_and_grads
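For reference, a hedged end-to-end usage sketch showing where ClipGradByGlobalNorm is applied from user code (the toy model, data, and clip_norm value are invented for illustration); with the refined logic above, the divide/multiply work only happens when the global gradient norm exceeds clip_norm:

import paddle

# toy network and an arbitrary clip_norm value, purely for illustration
linear = paddle.nn.Linear(10, 10)
clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
opt = paddle.optimizer.SGD(learning_rate=0.1,
                           parameters=linear.parameters(),
                           grad_clip=clip)

x = paddle.rand([4, 10])
loss = linear(x).mean()
loss.backward()
opt.step()        # gradients are rescaled here only if their global norm > 1.0
opt.clear_grad()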