From e8530a353a8717d28616e442d9e2cb880bf0172e Mon Sep 17 00:00:00 2001 From: zyfncg Date: Mon, 20 Mar 2023 15:35:11 +0800 Subject: [PATCH] Register custom kernel for some all_bakcend kernel (#51639) * register some custom kernel * fix bug --- paddle/phi/kernels/cpu/numel_kernel.cc | 18 +++++++++++ paddle/phi/kernels/flatten_grad_kernel.cc | 15 +++++++++ paddle/phi/kernels/flatten_kernel.cc | 28 ++++++++++++++++ paddle/phi/kernels/reshape_grad_kernel.cc | 13 ++++++++ paddle/phi/kernels/reshape_kernel.cc | 13 ++++++++ .../phi/kernels/selected_rows/shape_kernel.cc | 32 +++++++++++++++++-- paddle/phi/kernels/shape_kernel.cc | 21 ++++++++++++ 7 files changed, 138 insertions(+), 2 deletions(-) diff --git a/paddle/phi/kernels/cpu/numel_kernel.cc b/paddle/phi/kernels/cpu/numel_kernel.cc index 047bf99d2cc..07030a72cd5 100644 --- a/paddle/phi/kernels/cpu/numel_kernel.cc +++ b/paddle/phi/kernels/cpu/numel_kernel.cc @@ -33,3 +33,21 @@ PD_REGISTER_KERNEL(numel, bool) { kernel->OutputAt(0).SetDataType(phi::DataType::INT64); } + +#ifdef PADDLE_WITH_CUSTOM_DEVICE +PD_REGISTER_KERNEL(numel, + Custom, + ALL_LAYOUT, + phi::NumelKernel, + uint8_t, + int16_t, + int, + int64_t, + phi::dtype::float16, + phi::dtype::bfloat16, + float, + double, + bool) { + kernel->OutputAt(0).SetDataType(phi::DataType::INT64); +} +#endif diff --git a/paddle/phi/kernels/flatten_grad_kernel.cc b/paddle/phi/kernels/flatten_grad_kernel.cc index 031f4afe98b..42d137ba4f4 100644 --- a/paddle/phi/kernels/flatten_grad_kernel.cc +++ b/paddle/phi/kernels/flatten_grad_kernel.cc @@ -74,3 +74,18 @@ PD_REGISTER_KERNEL(flatten_grad, int64_t) {} #endif + +#ifdef PADDLE_WITH_CUSTOM_DEVICE +PD_REGISTER_KERNEL(flatten_grad, + Custom, + ALL_LAYOUT, + phi::FlattenGradKernel, + float, + phi::dtype::float16, + double, + uint8_t, + int8_t, + int16_t, + int, + int64_t) {} +#endif diff --git a/paddle/phi/kernels/flatten_kernel.cc b/paddle/phi/kernels/flatten_kernel.cc index 1706778237f..939e2706136 100644 --- a/paddle/phi/kernels/flatten_kernel.cc +++ b/paddle/phi/kernels/flatten_kernel.cc @@ -128,3 +128,31 @@ PD_REGISTER_KERNEL(flatten, int, int64_t) {} #endif + +#ifdef PADDLE_WITH_CUSTOM_DEVICE +PD_REGISTER_KERNEL(flatten_infer, + Custom, + ALL_LAYOUT, + phi::FlattenInferKernel, + float, + phi::dtype::float16, + double, + uint8_t, + int8_t, + int16_t, + int, + int64_t) {} + +PD_REGISTER_KERNEL(flatten, + Custom, + ALL_LAYOUT, + phi::FlattenKernel, + float, + phi::dtype::float16, + double, + uint8_t, + int8_t, + int16_t, + int, + int64_t) {} +#endif diff --git a/paddle/phi/kernels/reshape_grad_kernel.cc b/paddle/phi/kernels/reshape_grad_kernel.cc index ffd616054c0..35f19c178a6 100644 --- a/paddle/phi/kernels/reshape_grad_kernel.cc +++ b/paddle/phi/kernels/reshape_grad_kernel.cc @@ -97,3 +97,16 @@ PD_REGISTER_GENERAL_KERNEL(reshape_double_grad, phi::ReshapeDoubleGradKernel, ALL_DTYPE) {} #endif + +#ifdef PADDLE_WITH_CUSTOM_DEVICE +PD_REGISTER_GENERAL_KERNEL(reshape_grad, + Custom, + ALL_LAYOUT, + phi::ReshapeGradKernel, + ALL_DTYPE) {} +PD_REGISTER_GENERAL_KERNEL(reshape_double_grad, + Custom, + ALL_LAYOUT, + phi::ReshapeDoubleGradKernel, + ALL_DTYPE) {} +#endif diff --git a/paddle/phi/kernels/reshape_kernel.cc b/paddle/phi/kernels/reshape_kernel.cc index e34847b688c..d6e6e25c053 100644 --- a/paddle/phi/kernels/reshape_kernel.cc +++ b/paddle/phi/kernels/reshape_kernel.cc @@ -114,3 +114,16 @@ PD_REGISTER_GENERAL_KERNEL(reshape_infer, PD_REGISTER_GENERAL_KERNEL( reshape, XPU, ALL_LAYOUT, phi::ReshapeKernel, ALL_DTYPE) {} #endif + +#ifdef PADDLE_WITH_CUSTOM_DEVICE +PD_REGISTER_GENERAL_KERNEL(reshape_infer, + Custom, + ALL_LAYOUT, + phi::ReshapeInferKernel, + ALL_DTYPE) {} +PD_REGISTER_GENERAL_KERNEL(reshape, + Custom, + ALL_LAYOUT, + phi::ReshapeKernel, + ALL_DTYPE) {} +#endif diff --git a/paddle/phi/kernels/selected_rows/shape_kernel.cc b/paddle/phi/kernels/selected_rows/shape_kernel.cc index 575bcc0d09f..11971f24f39 100644 --- a/paddle/phi/kernels/selected_rows/shape_kernel.cc +++ b/paddle/phi/kernels/selected_rows/shape_kernel.cc @@ -45,7 +45,11 @@ PD_REGISTER_KERNEL(shape_sr, float, double, phi::dtype::complex, - phi::dtype::complex) {} + phi::dtype::complex) { + kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND); + kernel->OutputAt(0).SetBackend(phi::Backend::CPU); + kernel->OutputAt(0).SetDataType(phi::DataType::INT32); +} #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) PD_REGISTER_KERNEL(shape_sr, @@ -60,5 +64,29 @@ PD_REGISTER_KERNEL(shape_sr, float, double, phi::dtype::complex, - phi::dtype::complex) {} + phi::dtype::complex) { + kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND); + kernel->OutputAt(0).SetBackend(phi::Backend::CPU); + kernel->OutputAt(0).SetDataType(phi::DataType::INT32); +} +#endif + +#ifdef PADDLE_WITH_CUSTOM_DEVICE +PD_REGISTER_KERNEL(shape_sr, + Custom, + ALL_LAYOUT, + phi::sr::ShapeKernel, + bool, + int, + int8_t, + uint8_t, + int64_t, + float, + double, + phi::dtype::complex, + phi::dtype::complex) { + kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND); + kernel->OutputAt(0).SetBackend(phi::Backend::CPU); + kernel->OutputAt(0).SetDataType(phi::DataType::INT32); +} #endif diff --git a/paddle/phi/kernels/shape_kernel.cc b/paddle/phi/kernels/shape_kernel.cc index 7d0778d4089..c4190a5f59b 100644 --- a/paddle/phi/kernels/shape_kernel.cc +++ b/paddle/phi/kernels/shape_kernel.cc @@ -89,3 +89,24 @@ PD_REGISTER_KERNEL(shape, kernel->OutputAt(0).SetDataType(phi::DataType::INT32); } #endif + +#ifdef PADDLE_WITH_CUSTOM_DEVICE +PD_REGISTER_KERNEL(shape, + Custom, + ALL_LAYOUT, + phi::ShapeKernel, + bool, + int, + int8_t, + uint8_t, + int64_t, + float, + double, + phi::dtype::complex, + phi::dtype::complex, + phi::dtype::float16) { + kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND); + kernel->OutputAt(0).SetBackend(phi::Backend::CPU); + kernel->OutputAt(0).SetDataType(phi::DataType::INT32); +} +#endif -- GitLab