diff --git a/paddle/fluid/operators/lod_reset_op.cc b/paddle/fluid/operators/lod_reset_op.cc index 502afbf0c77f10b440b023680c3522ce8479a9e9..3c22660f8e4f3758cbbd595a3355d8fd0fdd3899 100644 --- a/paddle/fluid/operators/lod_reset_op.cc +++ b/paddle/fluid/operators/lod_reset_op.cc @@ -253,6 +253,16 @@ REGISTER_OP_CPU_KERNEL( ops::LoDResetKernel<paddle::platform::CPUPlace, double>, ops::LoDResetKernel<paddle::platform::CPUPlace, int>, ops::LoDResetKernel<paddle::platform::CPUPlace, int64_t>); + +#ifdef PADDLE_WITH_XPU +REGISTER_OP_XPU_KERNEL( + lod_reset, + ops::LoDResetKernel<paddle::platform::XPUDeviceContext, float>, + ops::LoDResetKernel<paddle::platform::XPUDeviceContext, double>, + ops::LoDResetKernel<paddle::platform::XPUDeviceContext, int>, + ops::LoDResetKernel<paddle::platform::XPUDeviceContext, int64_t>); +#endif + REGISTER_OP_CPU_KERNEL( lod_reset_grad, ops::LoDResetGradKernel<paddle::platform::CPUPlace, float>, diff --git a/paddle/fluid/operators/lod_reset_op.h b/paddle/fluid/operators/lod_reset_op.h index a07eecfef4d5deaa08c61b8cc602d4d269718559..3b31933174d002a36501d83337b4f1a83d7a0d2f 100644 --- a/paddle/fluid/operators/lod_reset_op.h +++ b/paddle/fluid/operators/lod_reset_op.h @@ -17,9 +17,21 @@ limitations under the License. */ #include <algorithm> #include <vector> +#ifdef PADDLE_WITH_XPU +#include "paddle/fluid/framework/convert_utils.h" +#include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/framework/data_type_transform.h" +#endif + #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" +#ifdef PADDLE_WITH_XPU +#include "paddle/fluid/framework/string_array.h" +#include "paddle/fluid/framework/tensor_util.h" +#include "paddle/fluid/platform/device_context.h" +#endif + namespace paddle { namespace operators { diff --git a/paddle/phi/backends/xpu/xpu2_op_list.cc b/paddle/phi/backends/xpu/xpu2_op_list.cc index cf0a9b656458f9dd8161faac06e8e04df6d3daf7..b21cb6b55a338719fa3aac0bf6549a62f897463f 100644 --- a/paddle/phi/backends/xpu/xpu2_op_list.cc +++ b/paddle/phi/backends/xpu/xpu2_op_list.cc @@ -870,6 +870,11 @@ XPUOpMap& get_kl2_ops() { XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"fused_feedforward_grad", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"lod_reset", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::FLOAT64, + phi::DataType::INT32, + phi::DataType::INT64})}, }; return s_xpu2_kernels;