From c491b361036dc776f74eb9cbb718561070be036d Mon Sep 17 00:00:00 2001 From: ZhouMengLei1999 <33919397+ZhouMengLei1999@users.noreply.github.com> Date: Thu, 23 Mar 2023 14:04:33 +0800 Subject: [PATCH] [XPU] support lod_reset (#51967) --- paddle/fluid/operators/lod_reset_op.cc | 10 ++++++++++ paddle/fluid/operators/lod_reset_op.h | 12 ++++++++++++ paddle/phi/backends/xpu/xpu2_op_list.cc | 5 +++++ 3 files changed, 27 insertions(+) diff --git a/paddle/fluid/operators/lod_reset_op.cc b/paddle/fluid/operators/lod_reset_op.cc index 502afbf0c77..3c22660f8e4 100644 --- a/paddle/fluid/operators/lod_reset_op.cc +++ b/paddle/fluid/operators/lod_reset_op.cc @@ -253,6 +253,16 @@ REGISTER_OP_CPU_KERNEL( ops::LoDResetKernel<paddle::platform::CPUPlace, double>, ops::LoDResetKernel<paddle::platform::CPUPlace, int>, ops::LoDResetKernel<paddle::platform::CPUPlace, int64_t>); + +#ifdef PADDLE_WITH_XPU +REGISTER_OP_XPU_KERNEL( + lod_reset, + ops::LoDResetKernel<paddle::platform::XPUDeviceContext, float>, + ops::LoDResetKernel<paddle::platform::XPUDeviceContext, double>, + ops::LoDResetKernel<paddle::platform::XPUDeviceContext, int>, + ops::LoDResetKernel<paddle::platform::XPUDeviceContext, int64_t>); +#endif + REGISTER_OP_CPU_KERNEL( lod_reset_grad, ops::LoDResetGradKernel<paddle::platform::CPUPlace, float>, diff --git a/paddle/fluid/operators/lod_reset_op.h b/paddle/fluid/operators/lod_reset_op.h index a07eecfef4d..3b31933174d 100644 --- a/paddle/fluid/operators/lod_reset_op.h +++ b/paddle/fluid/operators/lod_reset_op.h @@ -17,9 +17,21 @@ limitations under the License. */ #include <algorithm> #include <vector> +#ifdef PADDLE_WITH_XPU +#include "paddle/fluid/framework/convert_utils.h" +#include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/framework/data_type_transform.h" +#endif + #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" +#ifdef PADDLE_WITH_XPU +#include "paddle/fluid/framework/string_array.h" +#include "paddle/fluid/framework/tensor_util.h" +#include "paddle/fluid/platform/device_context.h" +#endif + namespace paddle { namespace operators { diff --git a/paddle/phi/backends/xpu/xpu2_op_list.cc b/paddle/phi/backends/xpu/xpu2_op_list.cc index cf0a9b65645..b21cb6b55a3 100644 --- a/paddle/phi/backends/xpu/xpu2_op_list.cc +++ b/paddle/phi/backends/xpu/xpu2_op_list.cc @@ -870,6 +870,11 @@ XPUOpMap& get_kl2_ops() { XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"fused_feedforward_grad", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"lod_reset", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::FLOAT64, + phi::DataType::INT32, + phi::DataType::INT64})}, }; return s_xpu2_kernels; -- GitLab