Unverified · Commit 8cbc75ca authored by Yiqun Liu and committed by GitHub

Cherry-pick the bf16 support of grad_clip from #51285. (#52816)

Parent 3869a3b4
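For context (not part of the commit itself): the public entry point touched by this change is `ClipGradByGlobalNorm`, which is normally attached to an optimizer through its `grad_clip` argument. Below is a minimal dygraph usage sketch with a toy `paddle.nn.Linear` model as a placeholder; the patch modifies the static-graph clipping path and only changes behavior when bf16 gradients are present (e.g. under bf16 mixed precision).

```python
import paddle

# Usage sketch: the model, shapes and learning rate are placeholders.
clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
model = paddle.nn.Linear(4, 4)
opt = paddle.optimizer.SGD(
    learning_rate=0.1, parameters=model.parameters(), grad_clip=clip
)

x = paddle.randn([8, 4])
loss = model(x).mean()
loss.backward()
opt.step()  # gradients are rescaled by clip_norm / max(clip_norm, global_norm)
```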
@@ -420,6 +420,20 @@ def _allow_pure_fp16_global_norm_clip(*args):
         return old_value
 
 
+_allow_pure_bf16_global_norm_clip_flag = False
+
+
+def _allow_pure_bf16_global_norm_clip(*args):
+    global _allow_pure_bf16_global_norm_clip_flag
+    if len(args) == 0:
+        return _allow_pure_bf16_global_norm_clip_flag
+    else:
+        assert len(args) == 1 and isinstance(args[0], bool)
+        old_value = _allow_pure_bf16_global_norm_clip_flag
+        _allow_pure_bf16_global_norm_clip_flag = args[0]
+        return old_value
+
+
 class ClipGradByGlobalNorm(ClipGradBase):
     r"""
     Given a list of Tensor :math:`t\_list` , calculate the global norm for the elements of all tensors in
@@ -584,6 +598,7 @@ class ClipGradByGlobalNorm(ClipGradBase):
         params_and_grads = []
         sum_square_list = []
         sum_square_list_fp16 = []
+        sum_square_list_bf16 = []
         sum_square_list_fp32 = []
         with framework.name_scope('gradient_clip'):
             for p, g in params_grads:
@@ -598,18 +613,27 @@ class ClipGradByGlobalNorm(ClipGradBase):
                         merge_grad = layers.get_tensor_from_selected_rows(
                             merge_grad
                         )
                     sum_square = _squared_l2_norm(merge_grad)
                     if sum_square.dtype == core.VarDesc.VarType.FP16:
                         sum_square_list_fp16.append(sum_square)
+                    elif sum_square.dtype == core.VarDesc.VarType.BF16:
+                        sum_square_list_bf16.append(sum_square)
                     elif sum_square.dtype == core.VarDesc.VarType.FP32:
                         sum_square_list_fp32.append(sum_square)
                     else:
                         sum_square_list.append(sum_square)
 
+            if len(sum_square_list_fp16) > 0 and len(sum_square_list_bf16) > 0:
+                raise NotSupportedError(
+                    'FP16 and BF16 are not supported at the same time.'
+                )
+
             # all parameters have been filterd out
             if (
                 len(sum_square_list)
                 + len(sum_square_list_fp16)
+                + len(sum_square_list_bf16)
                 + len(sum_square_list_fp32)
                 == 0
             ):
@@ -620,7 +644,7 @@ class ClipGradByGlobalNorm(ClipGradBase):
             global_norm_var = []
             if len(sum_square_list_fp16) > 0:
-                global_norm_var_fp16 = layers.sums(sum_square_list_fp16)
+                global_norm_var_fp16 = paddle.add_n(sum_square_list_fp16)
                 if (
                     sum_square_list_fp32
                     or sum_square_list
@@ -631,8 +655,20 @@ class ClipGradByGlobalNorm(ClipGradBase):
                     )
                 else:
                     global_norm_var.append(global_norm_var_fp16)
+            if len(sum_square_list_bf16) > 0:
+                global_norm_var_bf16 = paddle.add_n(sum_square_list_bf16)
+                if (
+                    sum_square_list_fp32
+                    or sum_square_list
+                    or not _allow_pure_bf16_global_norm_clip()
+                ):
+                    global_norm_var.append(
+                        global_norm_var_bf16.astype(sum_dtype)
+                    )
+                else:
+                    global_norm_var.append(global_norm_var_bf16)
             if len(sum_square_list_fp32) > 0:
-                global_norm_var_fp32 = layers.sums(sum_square_list_fp32)
+                global_norm_var_fp32 = paddle.add_n(sum_square_list_fp32)
                 if sum_dtype == 'float32':
                     global_norm_var.append(global_norm_var_fp32)
                 else:
@@ -641,23 +677,24 @@ class ClipGradByGlobalNorm(ClipGradBase):
                     )
             if len(sum_square_list) > 0:
                 # fp64
-                global_norm_var_other_dtype = layers.sums(sum_square_list)
+                global_norm_var_other_dtype = paddle.add_n(sum_square_list)
                 global_norm_var.append(global_norm_var_other_dtype)
 
             global_norm_var = (
-                layers.sums(global_norm_var)
+                paddle.add_n(global_norm_var)
                 if len(global_norm_var) > 1
                 else global_norm_var[0]
             )
-            global_norm_var = layers.sqrt(x=global_norm_var)
-            max_global_norm = layers.fill_constant(
-                shape=[1], dtype=global_norm_var.dtype, value=self.clip_norm
+            global_norm_var = paddle.sqrt(x=global_norm_var)
+            max_global_norm = paddle.full(
+                shape=[1],
+                dtype=global_norm_var.dtype,
+                fill_value=self.clip_norm,
             )
-            scale_var = layers.elementwise_div(
+            scale_var = paddle.divide(
                 x=max_global_norm,
-                y=layers.elementwise_max(
-                    x=max_global_norm, y=global_norm_var
-                ),
+                y=paddle.maximum(x=max_global_norm, y=global_norm_var),
             )
             param_new_grad_name_dict = dict()
             for p, g in params_grads:
@@ -671,9 +708,8 @@ class ClipGradByGlobalNorm(ClipGradBase):
                     new_g = _cast_to_mp_type_if_enabled(g)
                     # inplace
                     scale_input = (
-                        scale_var.astype('float16')
-                        if new_g.dtype == core.VarDesc.VarType.FP16
-                        and scale_var.dtype != core.VarDesc.VarType.FP16
+                        scale_var.astype(new_g.dtype)
+                        if scale_var.dtype != new_g.dtype
                         else scale_var
                     )
                     # NOTE(Yuang Liu): For pure dp with gradient merge, the p and g
......
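The arithmetic of the patched `_static_clip` path, sketched outside the framework internals: the squared L2 norm of every gradient is accumulated per dtype, fp16/bf16 partial sums are cast to the common `sum_dtype` with `astype` (unless the matching `_allow_pure_fp16_global_norm_clip()` / `_allow_pure_bf16_global_norm_clip()` switch is on and only that dtype is present), and every gradient is then scaled by `clip_norm / max(clip_norm, global_norm)`. The snippet below re-creates that computation with the same public ops the diff switches to; the demo gradients are plain float32 so it runs anywhere, and the dtype bucketing is described only in the comments.

```python
import paddle

clip_norm = 1.0
grads = [paddle.randn([3, 4]), paddle.randn([10])]

# In the patched code each squared norm lands in a per-dtype list
# (fp16 / bf16 / fp32 / other); mixing fp16 and bf16 raises an error.
sum_squares = [paddle.sum(paddle.square(g)) for g in grads]

global_norm = paddle.sqrt(paddle.add_n(sum_squares))
max_global_norm = paddle.full(
    shape=[1], dtype=global_norm.dtype, fill_value=clip_norm
)
scale = paddle.divide(
    x=max_global_norm, y=paddle.maximum(x=max_global_norm, y=global_norm)
)

# The real code multiplies each gradient by the scale cast to the gradient's
# dtype when the two differ (the generalized cast at the end of the diff).
clipped = [g * scale for g in grads]
```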