未验证 提交 c491b361 编写于 作者: Z ZhouMengLei1999 提交者: GitHub

[XPU] support lod_reset (#51967)

上级 5da1a27b
......@@ -253,6 +253,16 @@ REGISTER_OP_CPU_KERNEL(
ops::LoDResetKernel<paddle::platform::CPUPlace, double>,
ops::LoDResetKernel<paddle::platform::CPUPlace, int>,
ops::LoDResetKernel<paddle::platform::CPUPlace, int64_t>);
#ifdef PADDLE_WITH_XPU
REGISTER_OP_XPU_KERNEL(
lod_reset,
ops::LoDResetKernel<paddle::platform::XPUDeviceContext, float>,
ops::LoDResetKernel<paddle::platform::XPUDeviceContext, double>,
ops::LoDResetKernel<paddle::platform::XPUDeviceContext, int>,
ops::LoDResetKernel<paddle::platform::XPUDeviceContext, int64_t>);
#endif
REGISTER_OP_CPU_KERNEL(
lod_reset_grad,
ops::LoDResetGradKernel<paddle::platform::CPUPlace, float>,
......
......@@ -17,9 +17,21 @@ limitations under the License. */
#include <algorithm>
#include <vector>
#ifdef PADDLE_WITH_XPU
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/data_type_transform.h"
#endif
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#ifdef PADDLE_WITH_XPU
#include "paddle/fluid/framework/string_array.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/platform/device_context.h"
#endif
namespace paddle {
namespace operators {
......
......@@ -870,6 +870,11 @@ XPUOpMap& get_kl2_ops() {
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"fused_feedforward_grad",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"lod_reset",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::FLOAT64,
phi::DataType::INT32,
phi::DataType::INT64})},
};
return s_xpu2_kernels;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册