diff --git a/paddle/phi/backends/xpu/xpu2_op_list.cc b/paddle/phi/backends/xpu/xpu2_op_list.cc
index 2d9b06c6050b33d1b07a67609e69a4f95342b51d..8cc09d1a9be1363374848b00155d898b950b2680 100644
--- a/paddle/phi/backends/xpu/xpu2_op_list.cc
+++ b/paddle/phi/backends/xpu/xpu2_op_list.cc
@@ -248,11 +248,8 @@ XPUOpMap& get_kl2_ops() {
                    phi::DataType::INT16,
                    phi::DataType::UINT8,
                    phi::DataType::BOOL,
-                   phi::DataType::FLOAT64,
                    phi::DataType::FLOAT32,
-                   phi::DataType::FLOAT16,
-                   phi::DataType::COMPLEX64,
-                   phi::DataType::COMPLEX128})},
+                   phi::DataType::FLOAT16})},
     {"flatten2_grad",
      XPUKernelSet({phi::DataType::INT64,
                    phi::DataType::INT32,
diff --git a/paddle/phi/kernels/selected_rows/full_kernel.cc b/paddle/phi/kernels/selected_rows/full_kernel.cc
index 14987bc61b1593032ddd0c3c59c6e459656969e5..a492c1c304bd2f631f9c107c659eebdaefa82e6b 100644
--- a/paddle/phi/kernels/selected_rows/full_kernel.cc
+++ b/paddle/phi/kernels/selected_rows/full_kernel.cc
@@ -70,3 +70,17 @@ PD_REGISTER_KERNEL(full_sr,
                    phi::dtype::complex<float>,
                    phi::dtype::complex<double>) {}
 #endif
+
+#if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP)
+PD_REGISTER_KERNEL(full_sr,
+                   XPU,
+                   ALL_LAYOUT,
+                   phi::sr::FullKernel,
+                   float,
+                   uint8_t,
+                   int16_t,
+                   int,
+                   int64_t,
+                   bool,
+                   phi::dtype::float16) {}
+#endif
diff --git a/paddle/phi/kernels/xpu/full_kernel.cc b/paddle/phi/kernels/xpu/full_kernel.cc
index 44c5842210b71b5f1705c9b16a11f8800170053d..ae080d0dad07253f37adc2021c3c9606020dda3d 100644
--- a/paddle/phi/kernels/xpu/full_kernel.cc
+++ b/paddle/phi/kernels/xpu/full_kernel.cc
@@ -14,6 +14,7 @@
 
 #include "paddle/phi/kernels/full_kernel.h"
 
+#include "paddle/phi/backends/xpu/enforce_xpu.h"
 #include "paddle/phi/backends/xpu/xpu_context.h"
 #include "paddle/phi/common/bfloat16.h"
 #include "paddle/phi/common/complex.h"
@@ -59,8 +60,19 @@ void FullKernel(const Context& dev_ctx,
                 const Scalar& val,
                 DataType dtype,
                 DenseTensor* out) {
+  using XPUInTDType = typename XPUTypeTrait<T>::Type;
   out->Resize(phi::make_ddim(shape.GetData()));
-  FullValueXPU<T>(dev_ctx, out, val.to<T>());
+  int numel = out->numel();
+  dev_ctx.template Alloc<T>(out);
+  auto value = val.to<T>();
+  auto out_data = reinterpret_cast<XPUInTDType*>(out->data<T>());
+  if (numel > 0) {
+    int r = xpu::constant(dev_ctx.x_context(),
+                          out_data,
+                          out->numel(),
+                          static_cast<XPUInTDType>(value));
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
+  }
 }
 
 template <typename T, typename Context>
@@ -103,16 +115,11 @@ void FullLikeKernel(const Context& dev_ctx,
       phi::errors::InvalidArgument("The filled value is Inf."));
 
   auto out_data = reinterpret_cast<XPUInTDType*>(out->data<T>());
-  int ret = xpu::constant(dev_ctx.x_context(),
-                          out_data,
-                          out->numel(),
-                          static_cast<XPUInTDType>(value));
-  PADDLE_ENFORCE_EQ(
-      ret,
-      XPU_SUCCESS,
-      phi::errors::External("XPU CONSTANT API return wrong value[%d %s].",
-                            ret,
-                            XPUAPIErrorMsg[ret]));
+  int r = xpu::constant(dev_ctx.x_context(),
+                        out_data,
+                        out->numel(),
+                        static_cast<XPUInTDType>(value));
+  PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
 }
 
 }  // namespace phi
@@ -122,24 +129,23 @@ PD_REGISTER_KERNEL(full,
                    ALL_LAYOUT,
                    phi::FullKernel,
                    float,
-                   double,
                    uint8_t,
                    int16_t,
                    int,
                    int64_t,
                    bool,
-                   phi::dtype::float16,
-                   phi::dtype::bfloat16,
-                   phi::dtype::complex<float>,
-                   phi::dtype::complex<double>) {}
+                   phi::dtype::float16) {}
 
 PD_REGISTER_KERNEL(full_like,
                    XPU,
                    ALL_LAYOUT,
                    phi::FullLikeKernel,
                    float,
+                   uint8_t,
+                   int16_t,
                    int,
                    int64_t,
+                   bool,
                    phi::dtype::float16) {
   kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
 }