diff --git a/paddle/phi/kernels/flatten_grad_kernel.cc b/paddle/phi/kernels/flatten_grad_kernel.cc index 54279fca6e429e3da9b7fc3a5726f27ab78f4cd1..73d963f606e3f1a26125dce66b0298c213c8bdff 100644 --- a/paddle/phi/kernels/flatten_grad_kernel.cc +++ b/paddle/phi/kernels/flatten_grad_kernel.cc @@ -38,6 +38,7 @@ PD_REGISTER_KERNEL(flatten_grad, CPU, ALL_LAYOUT, phi::FlattenGradKernel, + phi::dtype::bfloat16, float, double, uint8_t, @@ -52,6 +53,7 @@ PD_REGISTER_KERNEL(flatten_grad, phi::FlattenGradKernel, float, phi::dtype::float16, + phi::dtype::bfloat16, double, uint8_t, int8_t, diff --git a/paddle/phi/kernels/flatten_kernel.cc b/paddle/phi/kernels/flatten_kernel.cc index dd000896073c70fedf82a501e84200837e4af4d1..006d3438288c1e3c6fa02069bf2fec99ccdf6469 100644 --- a/paddle/phi/kernels/flatten_kernel.cc +++ b/paddle/phi/kernels/flatten_kernel.cc @@ -54,6 +54,7 @@ PD_REGISTER_KERNEL(flatten, ALL_LAYOUT, phi::FlattenKernel, float, + phi::dtype::bfloat16, double, uint8_t, int8_t, @@ -66,6 +67,7 @@ PD_REGISTER_KERNEL(flatten_with_xshape, ALL_LAYOUT, phi::FlattenWithXShape, float, + phi::dtype::bfloat16, double, uint8_t, int8_t, @@ -80,6 +82,7 @@ PD_REGISTER_KERNEL(flatten, phi::FlattenKernel, float, phi::dtype::float16, + phi::dtype::bfloat16, double, uint8_t, int8_t, @@ -93,6 +96,7 @@ PD_REGISTER_KERNEL(flatten_with_xshape, phi::FlattenWithXShape, float, phi::dtype::float16, + phi::dtype::bfloat16, double, uint8_t, int8_t,