未验证 提交 0fdb3ced 编写于 作者: G Guoxia Wang 提交者: GitHub

add bf16 dtype for flatten kernel (#43264)

上级 eac125f9
...@@ -38,6 +38,7 @@ PD_REGISTER_KERNEL(flatten_grad, ...@@ -38,6 +38,7 @@ PD_REGISTER_KERNEL(flatten_grad,
CPU, CPU,
ALL_LAYOUT, ALL_LAYOUT,
phi::FlattenGradKernel, phi::FlattenGradKernel,
phi::dtype::bfloat16,
float, float,
double, double,
uint8_t, uint8_t,
...@@ -52,6 +53,7 @@ PD_REGISTER_KERNEL(flatten_grad, ...@@ -52,6 +53,7 @@ PD_REGISTER_KERNEL(flatten_grad,
phi::FlattenGradKernel, phi::FlattenGradKernel,
float, float,
phi::dtype::float16, phi::dtype::float16,
phi::dtype::bfloat16,
double, double,
uint8_t, uint8_t,
int8_t, int8_t,
......
...@@ -54,6 +54,7 @@ PD_REGISTER_KERNEL(flatten, ...@@ -54,6 +54,7 @@ PD_REGISTER_KERNEL(flatten,
ALL_LAYOUT, ALL_LAYOUT,
phi::FlattenKernel, phi::FlattenKernel,
float, float,
phi::dtype::bfloat16,
double, double,
uint8_t, uint8_t,
int8_t, int8_t,
...@@ -66,6 +67,7 @@ PD_REGISTER_KERNEL(flatten_with_xshape, ...@@ -66,6 +67,7 @@ PD_REGISTER_KERNEL(flatten_with_xshape,
ALL_LAYOUT, ALL_LAYOUT,
phi::FlattenWithXShape, phi::FlattenWithXShape,
float, float,
phi::dtype::bfloat16,
double, double,
uint8_t, uint8_t,
int8_t, int8_t,
...@@ -80,6 +82,7 @@ PD_REGISTER_KERNEL(flatten, ...@@ -80,6 +82,7 @@ PD_REGISTER_KERNEL(flatten,
phi::FlattenKernel, phi::FlattenKernel,
float, float,
phi::dtype::float16, phi::dtype::float16,
phi::dtype::bfloat16,
double, double,
uint8_t, uint8_t,
int8_t, int8_t,
...@@ -93,6 +96,7 @@ PD_REGISTER_KERNEL(flatten_with_xshape, ...@@ -93,6 +96,7 @@ PD_REGISTER_KERNEL(flatten_with_xshape,
phi::FlattenWithXShape, phi::FlattenWithXShape,
float, float,
phi::dtype::float16, phi::dtype::float16,
phi::dtype::bfloat16,
double, double,
uint8_t, uint8_t,
int8_t, int8_t,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册