Unverified commit 65f7fa0d, authored by zhangbo9674, committed by GitHub

Refine clip_by_global_norm (#38209)

* refine clip

* delete unused code

* refine logic for clip
Parent e8e47581
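
For context, the rule this patch moves to can be stated independently of Paddle: per-dtype squared sums are reduced with add_n, combined into one global L2 norm, and gradients are rescaled only when that norm exceeds clip_norm (previously the scale was always computed via elementwise_max). The snippet below is a plain-NumPy sketch of that rule, not the Paddle implementation; the helper name and inputs are hypothetical.

import numpy as np

def clip_by_global_norm(grads, clip_norm):
    # Accumulate the squared L2 norm of every gradient in float64,
    # mirroring the per-dtype upcast in the patch, then take one sqrt.
    global_norm = np.sqrt(sum(np.sum(g.astype(np.float64) ** 2) for g in grads))
    # Only rescale when the global norm actually exceeds clip_norm,
    # which is the "need_clip" condition introduced by this change.
    if global_norm <= clip_norm:
        return grads
    scale = clip_norm / global_norm
    return [g * scale for g in grads]

# Example: combined norm is 5.0, so every gradient is scaled by 1/5.
grads = [np.array([3.0, 0.0]), np.array([0.0, 4.0])]
clipped = clip_by_global_norm(grads, clip_norm=1.0)
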
@@ -479,29 +479,30 @@ class ClipGradByGlobalNorm(ClipGradBase):
         sum_dtype = 'float64' if len(sum_square_list) > 0 else "float32"
         global_norm_var = []
         if len(sum_square_list_fp16) > 0:
-            global_norm_var_fp16 = layers.concat(sum_square_list_fp16)
-            global_norm_var_fp16 = layers.reduce_sum(global_norm_var_fp16)
+            global_norm_var_fp16 = paddle.add_n(sum_square_list_fp16)
             global_norm_var.append(global_norm_var_fp16.astype(sum_dtype))
         if len(sum_square_list_fp32) > 0:
-            global_norm_var_fp32 = layers.concat(sum_square_list_fp32)
-            global_norm_var_fp32 = layers.reduce_sum(global_norm_var_fp32)
+            global_norm_var_fp32 = paddle.add_n(sum_square_list_fp32)
             if sum_dtype == 'float32':
                 global_norm_var.append(global_norm_var_fp32)
             else:
                 global_norm_var.append(global_norm_var_fp32.astype(sum_dtype))
         if len(sum_square_list) > 0:
-            global_norm_var_fp64 = layers.concat(sum_square_list)
-            global_norm_var_fp64 = layers.reduce_sum(global_norm_var_fp64)
+            global_norm_var_fp64 = paddle.add_n(sum_square_list)
             global_norm_var.append(global_norm_var_fp64)
-        global_norm_var = layers.concat(global_norm_var)
-        global_norm_var = layers.reduce_sum(global_norm_var)
+        global_norm_var = paddle.add_n(global_norm_var)
         global_norm_var = layers.sqrt(global_norm_var)
         max_global_norm = layers.fill_constant(
             shape=[1], dtype=global_norm_var.dtype, value=self.clip_norm)
-        clip_var = layers.elementwise_div(
-            x=max_global_norm,
-            y=layers.elementwise_max(
-                x=global_norm_var, y=max_global_norm))
+
+        # only when global_norm_var > max_global_norm, grad need clip
+        need_clip = False
+        if global_norm_var > max_global_norm:
+            need_clip = True
+
+        if need_clip:
+            clip_var = layers.elementwise_div(
+                x=max_global_norm, y=global_norm_var)
         for p, g in params_grads:
             if g is None:
                 continue
@@ -509,10 +510,14 @@ class ClipGradByGlobalNorm(ClipGradBase):
                 params_and_grads.append((p, g))
                 continue
             # TODO(wangxi): use inplace elementwise_mul
-            clip_input = (clip_var.astype('float16')
-                          if g.dtype == core.VarDesc.VarType.FP16 else clip_var)
-            new_grad = layers.elementwise_mul(x=g, y=clip_input)
-            params_and_grads.append((p, new_grad))
+            if need_clip:
+                clip_input = (clip_var.astype('float16')
+                              if g.dtype == core.VarDesc.VarType.FP16 else
+                              clip_var)
+                new_grad = layers.elementwise_mul(x=g, y=clip_input)
+                params_and_grads.append((p, new_grad))
+            else:
+                params_and_grads.append((p, g))
         return params_and_grads
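
For reference, a minimal dygraph usage sketch of the public entry point built on ClipGradByGlobalNorm; the model, shapes, and clip_norm value below are hypothetical, and this assumes the paddle.nn.ClipGradByGlobalNorm class and the grad_clip optimizer argument available in Paddle 2.x.

import paddle

# Hypothetical model, used only to produce some gradients.
model = paddle.nn.Sequential(paddle.nn.Linear(4, 8), paddle.nn.Linear(8, 1))
clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
opt = paddle.optimizer.SGD(learning_rate=0.1,
                           parameters=model.parameters(),
                           grad_clip=clip)

x = paddle.rand([2, 4])
loss = model(x).mean()
loss.backward()
opt.step()        # with this patch, gradients are rescaled only if their global norm exceeds 1.0
opt.clear_grad()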