Unverified commit 6cd095fc, authored by shaojie_wang, committed by GitHub

add bf16 for some ops in static mode (#51582)

Parent 8abc5333
@@ -55,6 +55,7 @@ PD_REGISTER_KERNEL(matmul_with_flatten_grad,
                    phi::MatmulWithFlattenGradKernel,
                    float,
                    double,
+                   phi::dtype::bfloat16,
                    phi::dtype::float16) {}
 PD_REGISTER_KERNEL(matmul_with_flatten_double_grad,
...
@@ -36,4 +36,5 @@ PD_REGISTER_KERNEL(matmul_with_flatten,
                    phi::MatmulWithFlattenKernel,
                    float,
                    double,
+                   phi::dtype::bfloat16,
                    phi::dtype::float16) {}
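
The two C++ hunks above add phi::dtype::bfloat16 to the dtype lists passed to PD_REGISTER_KERNEL for the matmul_with_flatten forward and backward kernels, so these kernels can be dispatched on bfloat16 tensors. Below is a minimal static-mode sketch of a bf16 matmul for context only; it assumes a GPU build with bf16 kernels available and goes through paddle.matmul rather than the patched matmul_with_flatten kernel, so the exact kernel exercised is an assumption rather than something this commit demonstrates.

    import numpy as np
    import paddle

    # Sketch: cast inputs to bfloat16, multiply in static graph mode,
    # cast back to float32 for fetching (assumes a CUDA build).
    paddle.enable_static()
    main_prog, startup_prog = paddle.static.Program(), paddle.static.Program()
    with paddle.static.program_guard(main_prog, startup_prog):
        x = paddle.static.data('x', [4, 8], dtype='float32')
        y = paddle.static.data('y', [8, 2], dtype='float32')
        out_bf16 = paddle.matmul(paddle.cast(x, 'bfloat16'),
                                 paddle.cast(y, 'bfloat16'))
        out = paddle.cast(out_bf16, 'float32')

    exe = paddle.static.Executor(paddle.CUDAPlace(0))
    exe.run(startup_prog)
    (res,) = exe.run(
        main_prog,
        feed={'x': np.random.rand(4, 8).astype('float32'),
              'y': np.random.rand(8, 2).astype('float32')},
        fetch_list=[out],
    )
    print(res.shape)  # (4, 2)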
@@ -278,7 +278,7 @@ def generate_activation_fn(op_type):
         if op_type not in ["abs", "exp", "square"]:
             check_variable_and_dtype(
-                x, 'x', ['float16', 'float32', 'float64'], op_type
+                x, 'x', ['float16', 'float32', 'float64', 'uint16'], op_type
             )
         else:
             # abs exp square ops support dtype(int32, int64, float16, float32, float64)
@@ -293,6 +293,7 @@ def generate_activation_fn(op_type):
                     'float64',
                     'complex64',
                     'complex128',
+                    'uint16',
                 ],
                 op_type,
             )
...
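
In the Python-side checks, bfloat16 appears as the string 'uint16': numpy has no native bfloat16 type, so Paddle surfaces bf16 data as uint16 at the Python layer, and dtype whitelists admit bf16 variables by listing 'uint16'. A minimal illustration (assuming this Paddle version accepts 'bfloat16' as a dtype string for paddle.static.data):

    import paddle

    # A bf16 static variable reports the BF16 VarType, while host-side
    # numpy buffers use uint16 storage; hence the 'uint16' whitelist entry.
    paddle.enable_static()
    x = paddle.static.data('x', [2, 3], dtype='bfloat16')
    print(x.dtype)  # expected: VarDesc.VarType.BF16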
@@ -49,7 +49,7 @@ def check_finite_and_unscale(x, scale, name=None, float_status=None):
         check_variable_and_dtype(
             e,
             "x",
-            ['float16', 'float32', 'float64'],
+            ['float16', 'float32', 'float64', 'uint16'],
             'check_finite_and_unscale',
         )
@@ -133,9 +133,15 @@ def update_loss_scaling(
     check_type(x, 'x', (tuple, list), 'update_loss_scaling')
     for e in x:
         check_variable_and_dtype(
-            e, "x", ['float16', 'float32', 'float64'], 'update_loss_scaling'
+            e,
+            "x",
+            ['float16', 'float32', 'float64', 'uint16'],
+            'update_loss_scaling',
         )
-        if e.dtype == core.VarDesc.VarType.FP16:
+        if (
+            e.dtype == core.VarDesc.VarType.FP16
+            or e.dtype == core.VarDesc.VarType.BF16
+        ):
             assert (
                 prev_loss_scaling.dtype == core.VarDesc.VarType.FP32
             ), "The dtype of prev_loss_scaling should be float32 when the dtype of x is float16."
...
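
These two hunks extend the static-mode AMP helpers: check_finite_and_unscale and update_loss_scaling now accept bf16 ('uint16') gradients, while the loss-scaling state itself must remain float32, which is what the FP32 assertion enforces. For intuition, here is a toy, framework-free sketch of the dynamic loss-scaling rule these utilities implement; the function name and parameters below are illustrative, not Paddle's implementation.

    import numpy as np

    # Toy dynamic loss scaling: gradients may be fp16/bf16, but the scale
    # and its counter stay in float32.
    def update_loss_scaling(found_inf, scale, good_steps,
                            incr_every_n=1000, incr_ratio=2.0, decr_ratio=0.5):
        if found_inf:                    # overflow seen: shrink scale, reset streak
            return np.float32(max(scale * decr_ratio, 1.0)), 0
        good_steps += 1
        if good_steps >= incr_every_n:   # long stable streak: grow scale
            return np.float32(scale * incr_ratio), 0
        return scale, good_steps

    scale, steps = np.float32(2.0 ** 15), 0
    scale, steps = update_loss_scaling(found_inf=False, scale=scale, good_steps=steps)
    print(scale, steps)  # 32768.0 1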
@@ -294,6 +294,7 @@ def generate_activation_fn(op_type):
                     'float64',
                     'complex64',
                     'complex128',
+                    'uint16',
                 ],
                 op_type,
             )
...