From 0fdb3ced4574487c3fbed7f325aa7b89f71af28b Mon Sep 17 00:00:00 2001 From: Guoxia Wang Date: Tue, 7 Jun 2022 14:23:18 +0800 Subject: [PATCH] add bf16 dtype for flatten kernel (#43264) --- paddle/phi/kernels/flatten_grad_kernel.cc | 2 ++ paddle/phi/kernels/flatten_kernel.cc | 4 ++++ 2 files changed, 6 insertions(+) diff --git a/paddle/phi/kernels/flatten_grad_kernel.cc b/paddle/phi/kernels/flatten_grad_kernel.cc index 54279fca6e4..73d963f606e 100644 --- a/paddle/phi/kernels/flatten_grad_kernel.cc +++ b/paddle/phi/kernels/flatten_grad_kernel.cc @@ -38,6 +38,7 @@ PD_REGISTER_KERNEL(flatten_grad, CPU, ALL_LAYOUT, phi::FlattenGradKernel, + phi::dtype::bfloat16, float, double, uint8_t, @@ -52,6 +53,7 @@ PD_REGISTER_KERNEL(flatten_grad, phi::FlattenGradKernel, float, phi::dtype::float16, + phi::dtype::bfloat16, double, uint8_t, int8_t, diff --git a/paddle/phi/kernels/flatten_kernel.cc b/paddle/phi/kernels/flatten_kernel.cc index dd000896073..006d3438288 100644 --- a/paddle/phi/kernels/flatten_kernel.cc +++ b/paddle/phi/kernels/flatten_kernel.cc @@ -54,6 +54,7 @@ PD_REGISTER_KERNEL(flatten, ALL_LAYOUT, phi::FlattenKernel, float, + phi::dtype::bfloat16, double, uint8_t, int8_t, @@ -66,6 +67,7 @@ PD_REGISTER_KERNEL(flatten_with_xshape, ALL_LAYOUT, phi::FlattenWithXShape, float, + phi::dtype::bfloat16, double, uint8_t, int8_t, @@ -80,6 +82,7 @@ PD_REGISTER_KERNEL(flatten, phi::FlattenKernel, float, phi::dtype::float16, + phi::dtype::bfloat16, double, uint8_t, int8_t, @@ -93,6 +96,7 @@ PD_REGISTER_KERNEL(flatten_with_xshape, phi::FlattenWithXShape, float, phi::dtype::float16, + phi::dtype::bfloat16, double, uint8_t, int8_t, -- GitLab