Unverified · commit 6cd095fc authored by shaojie_wang, committed by GitHub

add bf16 for some ops in static mode (#51582)

Parent: 8abc5333
@@ -55,6 +55,7 @@ PD_REGISTER_KERNEL(matmul_with_flatten_grad,
                    phi::MatmulWithFlattenGradKernel,
                    float,
                    double,
+                   phi::dtype::bfloat16,
                    phi::dtype::float16) {}
 PD_REGISTER_KERNEL(matmul_with_flatten_double_grad,
@@ -36,4 +36,5 @@ PD_REGISTER_KERNEL(matmul_with_flatten,
                    phi::MatmulWithFlattenKernel,
                    float,
                    double,
+                   phi::dtype::bfloat16,
                    phi::dtype::float16) {}
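
Note: these two registrations add phi::dtype::bfloat16 to the forward and grad GPU kernels behind the legacy mul op. matmul_with_flatten flattens both inputs to 2-D matrices before the product. A minimal numpy sketch of those forward semantics, assuming the usual x_num_col_dims/y_num_col_dims contract (the function name and signature here are illustrative, not Paddle's API):

    import numpy as np

    def matmul_with_flatten(x, y, x_num_col_dims=1, y_num_col_dims=1):
        # Collapse the leading axes of each input into rows and the
        # trailing axes into columns, then take a plain matrix product.
        xm = x.reshape(int(np.prod(x.shape[:x_num_col_dims])), -1)
        ym = y.reshape(int(np.prod(y.shape[:y_num_col_dims])), -1)
        return xm @ ym

For example, an x of shape (2, 3, 4) with x_num_col_dims=2 is multiplied as a (6, 4) matrix.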
@@ -278,7 +278,7 @@ def generate_activation_fn(op_type):
         if op_type not in ["abs", "exp", "square"]:
             check_variable_and_dtype(
-                x, 'x', ['float16', 'float32', 'float64'], op_type
+                x, 'x', ['float16', 'float32', 'float64', 'uint16'], op_type
             )
         else:
             # abs exp square ops support dtype(int32, int64, float16, float32, float64)
@@ -293,6 +293,7 @@ def generate_activation_fn(op_type):
                     'float64',
                     'complex64',
                     'complex128',
+                    'uint16',
                 ],
                 op_type,
             )
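
Note: the 'uint16' entry is deliberate, not a typo. numpy has no native bfloat16, so Paddle's static graph carries bf16 tensors in uint16 storage; bf16 is exactly the upper 16 bits of an IEEE-754 float32. A self-contained numpy sketch of that round trip (truncating conversion, ignoring rounding modes; helper names are illustrative):

    import numpy as np

    def f32_to_bf16_bits(x):
        # bfloat16 keeps the sign, exponent, and top 7 mantissa bits of
        # float32, so the bit pattern fits in a uint16 buffer.
        return (np.asarray(x, np.float32).view(np.uint32) >> 16).astype(np.uint16)

    def bf16_bits_to_f32(b):
        return (b.astype(np.uint32) << 16).view(np.float32)

    assert bf16_bits_to_f32(f32_to_bf16_bits(1.0)) == 1.0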
@@ -49,7 +49,7 @@ def check_finite_and_unscale(x, scale, name=None, float_status=None):
         check_variable_and_dtype(
             e,
             "x",
-            ['float16', 'float32', 'float64'],
+            ['float16', 'float32', 'float64', 'uint16'],
             'check_finite_and_unscale',
         )
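
Note: check_finite_and_unscale is the AMP step that divides scaled gradients by the loss scale and reports whether anything overflowed; the change only lets bf16 (uint16-typed) gradients through the dtype whitelist. A rough numpy sketch of the op's contract, not the actual kernel:

    import numpy as np

    def check_finite_and_unscale(grads, scale):
        # Unscale every gradient and flag any non-finite value.
        found_inf = False
        out = []
        for g in grads:
            g = np.asarray(g, np.float32) / scale
            found_inf = found_inf or not np.all(np.isfinite(g))
            out.append(g)
        return out, found_inf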
@@ -133,9 +133,15 @@ def update_loss_scaling(
     check_type(x, 'x', (tuple, list), 'update_loss_scaling')
     for e in x:
         check_variable_and_dtype(
-            e, "x", ['float16', 'float32', 'float64'], 'update_loss_scaling'
+            e,
+            "x",
+            ['float16', 'float32', 'float64', 'uint16'],
+            'update_loss_scaling',
         )
-        if e.dtype == core.VarDesc.VarType.FP16:
+        if (
+            e.dtype == core.VarDesc.VarType.FP16
+            or e.dtype == core.VarDesc.VarType.BF16
+        ):
             assert (
                 prev_loss_scaling.dtype == core.VarDesc.VarType.FP32
             ), "The dtype of prev_loss_scaling should be float32 when the dtype of x is float16."
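
Note: update_loss_scaling implements the usual dynamic loss-scaling policy, and the diff extends its FP16-only dtype check to BF16 while still requiring the scale itself to stay float32, so the scale does not lose precision. A condensed sketch of that policy (a simplification of the real op; parameter names and the growth/backoff constants are illustrative):

    def update_loss_scaling(found_inf, scale, good_steps,
                            incr_every_n=1000, incr_ratio=2.0, decr_ratio=0.5):
        # Overflow: shrink the scale and restart the streak of clean steps.
        if found_inf:
            return max(scale * decr_ratio, 1.0), 0
        # Enough consecutive finite steps: grow the scale.
        good_steps += 1
        if good_steps >= incr_every_n:
            return scale * incr_ratio, 0
        return scale, good_steps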
@@ -294,6 +294,7 @@ def generate_activation_fn(op_type):
                     'float64',
                     'complex64',
                     'complex128',
+                    'uint16',
                 ],
                 op_type,
             )