From c491b361036dc776f74eb9cbb718561070be036d Mon Sep 17 00:00:00 2001
From: ZhouMengLei1999 <33919397+ZhouMengLei1999@users.noreply.github.com>
Date: Thu, 23 Mar 2023 14:04:33 +0800
Subject: [PATCH] [XPU] support lod_reset (#51967)

---
 paddle/fluid/operators/lod_reset_op.cc  | 10 ++++++++++
 paddle/fluid/operators/lod_reset_op.h   | 12 ++++++++++++
 paddle/phi/backends/xpu/xpu2_op_list.cc |  5 +++++
 3 files changed, 27 insertions(+)

diff --git a/paddle/fluid/operators/lod_reset_op.cc b/paddle/fluid/operators/lod_reset_op.cc
index 502afbf0c77..3c22660f8e4 100644
--- a/paddle/fluid/operators/lod_reset_op.cc
+++ b/paddle/fluid/operators/lod_reset_op.cc
@@ -253,6 +253,16 @@ REGISTER_OP_CPU_KERNEL(
     ops::LoDResetKernel<paddle::platform::CPUPlace, double>,
     ops::LoDResetKernel<paddle::platform::CPUPlace, int>,
     ops::LoDResetKernel<paddle::platform::CPUPlace, int64_t>);
+
+#ifdef PADDLE_WITH_XPU
+REGISTER_OP_XPU_KERNEL(
+    lod_reset,
+    ops::LoDResetKernel<paddle::platform::XPUDeviceContext, float>,
+    ops::LoDResetKernel<paddle::platform::XPUDeviceContext, double>,
+    ops::LoDResetKernel<paddle::platform::XPUDeviceContext, int>,
+    ops::LoDResetKernel<paddle::platform::XPUDeviceContext, int64_t>);
+#endif
+
 REGISTER_OP_CPU_KERNEL(
     lod_reset_grad,
     ops::LoDResetGradKernel<paddle::platform::CPUPlace, float>,
diff --git a/paddle/fluid/operators/lod_reset_op.h b/paddle/fluid/operators/lod_reset_op.h
index a07eecfef4d..3b31933174d 100644
--- a/paddle/fluid/operators/lod_reset_op.h
+++ b/paddle/fluid/operators/lod_reset_op.h
@@ -17,9 +17,21 @@ limitations under the License. */
 #include <algorithm>
 #include <vector>
 
+#ifdef PADDLE_WITH_XPU
+#include "paddle/fluid/framework/convert_utils.h"
+#include "paddle/fluid/framework/data_type.h"
+#include "paddle/fluid/framework/data_type_transform.h"
+#endif
+
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
 
+#ifdef PADDLE_WITH_XPU
+#include "paddle/fluid/framework/string_array.h"
+#include "paddle/fluid/framework/tensor_util.h"
+#include "paddle/fluid/platform/device_context.h"
+#endif
+
 namespace paddle {
 namespace operators {
 
diff --git a/paddle/phi/backends/xpu/xpu2_op_list.cc b/paddle/phi/backends/xpu/xpu2_op_list.cc
index cf0a9b65645..b21cb6b55a3 100644
--- a/paddle/phi/backends/xpu/xpu2_op_list.cc
+++ b/paddle/phi/backends/xpu/xpu2_op_list.cc
@@ -870,6 +870,11 @@ XPUOpMap& get_kl2_ops() {
        XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
       {"fused_feedforward_grad",
        XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
+      {"lod_reset",
+       XPUKernelSet({phi::DataType::FLOAT32,
+                     phi::DataType::FLOAT64,
+                     phi::DataType::INT32,
+                     phi::DataType::INT64})},
   };
 
   return s_xpu2_kernels;
-- 
GitLab