diff --git a/paddle/phi/kernels/selected_rows/shape_kernel.cc b/paddle/phi/kernels/selected_rows/shape_kernel.cc index 11971f24f3987cb787af1c53f8ee93d74f5d9e0e..f44a6a8dfafc508fc2c17baf553399cb1d081b19 100644 --- a/paddle/phi/kernels/selected_rows/shape_kernel.cc +++ b/paddle/phi/kernels/selected_rows/shape_kernel.cc @@ -16,6 +16,7 @@ limitations under the License. */ #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/backends/xpu/xpu_context.h" #include "paddle/phi/common/float16.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/shape_kernel.h" @@ -71,6 +72,23 @@ PD_REGISTER_KERNEL(shape_sr, } #endif +#if defined(PADDLE_WITH_XPU) +PD_REGISTER_KERNEL(shape_sr, + XPU, + ALL_LAYOUT, + phi::sr::ShapeKernel, + bool, + int, + int64_t, + float, + double, + phi::dtype::float16) { + kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND); + kernel->OutputAt(0).SetBackend(phi::Backend::CPU); + kernel->OutputAt(0).SetDataType(phi::DataType::INT32); +} +#endif + #ifdef PADDLE_WITH_CUSTOM_DEVICE PD_REGISTER_KERNEL(shape_sr, Custom,