From f55eb06f1f326f3928019dbe369839397441733f Mon Sep 17 00:00:00 2001 From: houj04 <35131887+houj04@users.noreply.github.com> Date: Mon, 5 Jun 2023 14:17:54 +0800 Subject: [PATCH] [XPU] fix unittest of shape op. (#54323) --- .../phi/kernels/selected_rows/shape_kernel.cc | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/paddle/phi/kernels/selected_rows/shape_kernel.cc b/paddle/phi/kernels/selected_rows/shape_kernel.cc index 11971f24f39..f44a6a8dfaf 100644 --- a/paddle/phi/kernels/selected_rows/shape_kernel.cc +++ b/paddle/phi/kernels/selected_rows/shape_kernel.cc @@ -16,6 +16,7 @@ limitations under the License. */ #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/backends/xpu/xpu_context.h" #include "paddle/phi/common/float16.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/shape_kernel.h" @@ -71,6 +72,23 @@ PD_REGISTER_KERNEL(shape_sr, } #endif +#if defined(PADDLE_WITH_XPU) +PD_REGISTER_KERNEL(shape_sr, + XPU, + ALL_LAYOUT, + phi::sr::ShapeKernel, + bool, + int, + int64_t, + float, + double, + phi::dtype::float16) { + kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND); + kernel->OutputAt(0).SetBackend(phi::Backend::CPU); + kernel->OutputAt(0).SetDataType(phi::DataType::INT32); +} +#endif + #ifdef PADDLE_WITH_CUSTOM_DEVICE PD_REGISTER_KERNEL(shape_sr, Custom, -- GitLab