Unverified commit 8cbc75ca, authored by Yiqun Liu, committed by GitHub

Cherry-pick the bf16 support in grad_clip from #51285. (#52816)

Parent 3869a3b4
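For context, a minimal usage sketch (not part of the commit): ClipGradByGlobalNorm is attached to an optimizer through its grad_clip argument, and with this change gradients held in bf16 get their own bucket when the global norm is accumulated. The diff below touches the static-graph clipping path, but the public entry point is the same.

    import paddle

    # Illustrative eager-mode usage; the commit itself only changes the clipping internals.
    model = paddle.nn.Linear(16, 4)
    clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
    opt = paddle.optimizer.SGD(
        learning_rate=0.1, parameters=model.parameters(), grad_clip=clip
    )

    x = paddle.randn([8, 16])
    loss = model(x).mean()
    loss.backward()
    opt.step()        # gradients are rescaled so that their global norm does not exceed 1.0
    opt.clear_grad()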
@@ -420,6 +420,20 @@ def _allow_pure_fp16_global_norm_clip(*args):
         return old_value
 
 
+_allow_pure_bf16_global_norm_clip_flag = False
+
+
+def _allow_pure_bf16_global_norm_clip(*args):
+    global _allow_pure_bf16_global_norm_clip_flag
+    if len(args) == 0:
+        return _allow_pure_bf16_global_norm_clip_flag
+    else:
+        assert len(args) == 1 and isinstance(args[0], bool)
+        old_value = _allow_pure_bf16_global_norm_clip_flag
+        _allow_pure_bf16_global_norm_clip_flag = args[0]
+        return old_value
+
+
 class ClipGradByGlobalNorm(ClipGradBase):
     r"""
     Given a list of Tensor :math:`t\_list` , calculate the global norm for the elements of all tensors in
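The helper added above doubles as getter and setter, mirroring the existing _allow_pure_fp16_global_norm_clip: called with no argument it returns the current flag, called with a single bool it installs the new value and returns the previous one. A minimal sketch of toggling it (the module path paddle.fluid.clip is an assumption based on the file being patched):

    from paddle.fluid import clip

    old = clip._allow_pure_bf16_global_norm_clip(True)        # enable; returns the previous value (False by default)
    assert clip._allow_pure_bf16_global_norm_clip() is True   # no-argument call reads the flag
    clip._allow_pure_bf16_global_norm_clip(old)                # restore the previous setting

When the flag is enabled and every gradient is bf16 (no fp32/fp64 partial sums), the accumulated bf16 norm is kept in bf16 instead of being cast to sum_dtype, as the later hunks show.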
@@ -584,6 +598,7 @@ class ClipGradByGlobalNorm(ClipGradBase):
         params_and_grads = []
         sum_square_list = []
         sum_square_list_fp16 = []
+        sum_square_list_bf16 = []
         sum_square_list_fp32 = []
         with framework.name_scope('gradient_clip'):
             for p, g in params_grads:
@@ -598,18 +613,27 @@ class ClipGradByGlobalNorm(ClipGradBase):
                         merge_grad = layers.get_tensor_from_selected_rows(
                             merge_grad
                         )
                     sum_square = _squared_l2_norm(merge_grad)
                     if sum_square.dtype == core.VarDesc.VarType.FP16:
                         sum_square_list_fp16.append(sum_square)
+                    elif sum_square.dtype == core.VarDesc.VarType.BF16:
+                        sum_square_list_bf16.append(sum_square)
                     elif sum_square.dtype == core.VarDesc.VarType.FP32:
                         sum_square_list_fp32.append(sum_square)
                     else:
                         sum_square_list.append(sum_square)
 
+            if len(sum_square_list_fp16) > 0 and len(sum_square_list_bf16) > 0:
+                raise NotSupportedError(
+                    'FP16 and BF16 are not supported at the same time.'
+                )
+
             # all parameters have been filterd out
             if (
                 len(sum_square_list)
                 + len(sum_square_list_fp16)
+                + len(sum_square_list_bf16)
                 + len(sum_square_list_fp32)
                 == 0
             ):
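In the hunk above, each gradient's squared L2 norm is routed into a per-dtype list (fp16, bf16, fp32, and a fallback list used for fp64), and mixing fp16 with bf16 gradients is rejected outright. A framework-free sketch of that routing rule (function name and dtype strings are illustrative, not from the commit):

    def bucket_squared_norms(squared_norms):
        # squared_norms: iterable of (dtype_name, value) pairs
        buckets = {"float16": [], "bfloat16": [], "float32": [], "other": []}
        for dtype, value in squared_norms:
            key = dtype if dtype in buckets else "other"
            buckets[key].append(value)
        # fp16 and bf16 partial sums cannot be combined, matching the new check above
        if buckets["float16"] and buckets["bfloat16"]:
            raise ValueError("FP16 and BF16 are not supported at the same time.")
        return buckets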
@@ -620,7 +644,7 @@ class ClipGradByGlobalNorm(ClipGradBase):
                 global_norm_var = []
                 if len(sum_square_list_fp16) > 0:
-                    global_norm_var_fp16 = layers.sums(sum_square_list_fp16)
+                    global_norm_var_fp16 = paddle.add_n(sum_square_list_fp16)
                     if (
                         sum_square_list_fp32
                         or sum_square_list
@@ -631,8 +655,20 @@ class ClipGradByGlobalNorm(ClipGradBase):
                         )
                     else:
                         global_norm_var.append(global_norm_var_fp16)
+                if len(sum_square_list_bf16) > 0:
+                    global_norm_var_bf16 = paddle.add_n(sum_square_list_bf16)
+                    if (
+                        sum_square_list_fp32
+                        or sum_square_list
+                        or not _allow_pure_bf16_global_norm_clip()
+                    ):
+                        global_norm_var.append(
+                            global_norm_var_bf16.astype(sum_dtype)
+                        )
+                    else:
+                        global_norm_var.append(global_norm_var_bf16)
                 if len(sum_square_list_fp32) > 0:
-                    global_norm_var_fp32 = layers.sums(sum_square_list_fp32)
+                    global_norm_var_fp32 = paddle.add_n(sum_square_list_fp32)
                     if sum_dtype == 'float32':
                         global_norm_var.append(global_norm_var_fp32)
                     else:
@@ -641,23 +677,24 @@ class ClipGradByGlobalNorm(ClipGradBase):
                         )
                 if len(sum_square_list) > 0:
                     # fp64
-                    global_norm_var_other_dtype = layers.sums(sum_square_list)
+                    global_norm_var_other_dtype = paddle.add_n(sum_square_list)
                     global_norm_var.append(global_norm_var_other_dtype)
 
                 global_norm_var = (
-                    layers.sums(global_norm_var)
+                    paddle.add_n(global_norm_var)
                     if len(global_norm_var) > 1
                     else global_norm_var[0]
                 )
-                global_norm_var = layers.sqrt(x=global_norm_var)
-                max_global_norm = layers.fill_constant(
-                    shape=[1], dtype=global_norm_var.dtype, value=self.clip_norm
+                global_norm_var = paddle.sqrt(x=global_norm_var)
+                max_global_norm = paddle.full(
+                    shape=[1],
+                    dtype=global_norm_var.dtype,
+                    fill_value=self.clip_norm,
                 )
-                scale_var = layers.elementwise_div(
+                scale_var = paddle.divide(
                     x=max_global_norm,
-                    y=layers.elementwise_max(
-                        x=max_global_norm, y=global_norm_var
-                    ),
+                    y=paddle.maximum(x=max_global_norm, y=global_norm_var),
                 )
             param_new_grad_name_dict = dict()
             for p, g in params_grads:
@@ -671,9 +708,8 @@ class ClipGradByGlobalNorm(ClipGradBase):
                     new_g = _cast_to_mp_type_if_enabled(g)
                     # inplace
                     scale_input = (
-                        scale_var.astype('float16')
-                        if new_g.dtype == core.VarDesc.VarType.FP16
-                        and scale_var.dtype != core.VarDesc.VarType.FP16
+                        scale_var.astype(new_g.dtype)
+                        if scale_var.dtype != new_g.dtype
                         else scale_var
                     )
                     # NOTE(Yuang Liu): For pure dp with gradient merge, the p and g
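Taken together, the rewritten lines compute the clip coefficient as clip_norm / max(clip_norm, global_norm), where global_norm is the square root of the summed per-dtype squared norms, and every gradient is then rescaled by that coefficient in its own dtype. A condensed, standalone sketch of the same arithmetic using the paddle 2.x ops this commit switches to (function name and the eager-style usage are illustrative, not from the commit):

    import paddle

    def global_norm_scale(sum_square_list, clip_norm):
        # sum_square_list: squared L2 norms of the individual gradients, each of shape [1]
        global_norm = paddle.sqrt(paddle.add_n(sum_square_list))
        max_norm = paddle.full(shape=[1], dtype=global_norm.dtype, fill_value=clip_norm)
        # scale = clip_norm / max(clip_norm, global_norm): 1.0 when already within the limit
        return paddle.divide(x=max_norm, y=paddle.maximum(x=max_norm, y=global_norm))

    grads = [paddle.randn([3, 3]) for _ in range(4)]
    scale = global_norm_scale([paddle.sum(g * g).reshape([1]) for g in grads], clip_norm=1.0)
    clipped = [g * scale for g in grads]  # every gradient is rescaled by the same factor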