// Unified diff (git format) touching two files. The hunk text is collapsed onto
// long lines, so '+' / '-' markers appear inline rather than at line starts.
//
// 1) paddle/fluid/operators/masked_select_op_xpu.cc
//    - adds a class-level `XPUType` alias derived from `XPUTypeTrait`;
//    - wraps the input data pointer in a `reinterpret_cast`;
//    - wraps the pointer returned by `out->mutable_data(context.GetPlace())`
//      in a `reinterpret_cast`;
//    - adds one more `ops::MaskedSelectXPUKernel` entry to
//      REGISTER_OP_XPU_KERNEL(masked_select, ...).
// 2) paddle/fluid/platform/device/xpu/xpu2_op_list.h
//    - adds `pOpKernelType(vartype::FP16, XPUPlace())` to the "masked_select"
//      kernel set (hunk context: get_kl2_ops()).
//
// NOTE(review): template argument lists (the `<...>` after `template`,
// `OpKernel`, `XPUTypeTrait`, `Input`, `Output`, `data`, `mutable_data`,
// `reinterpret_cast`, and `MaskedSelectXPUKernel`) appear to have been stripped
// from this text -- recover them from the upstream change before applying.
// NOTE(review): presumably the added kernel registration is the FP16
// instantiation that pairs with the new FP16 op-list entry; the type argument
// is not visible here, so confirm against the upstream change.
diff --git a/paddle/fluid/operators/masked_select_op_xpu.cc b/paddle/fluid/operators/masked_select_op_xpu.cc index 00248165a511de590743eb57142a8e660968e087..3845046825355051d1caf06e7070230a9281125e 100644 --- a/paddle/fluid/operators/masked_select_op_xpu.cc +++ b/paddle/fluid/operators/masked_select_op_xpu.cc @@ -19,13 +19,15 @@ namespace operators { template class MaskedSelectXPUKernel : public framework::OpKernel { + using XPUType = typename XPUTypeTrait::Type; + public: void Compute(const framework::ExecutionContext& context) const override { auto input = context.Input("X"); auto mask = context.Input("Mask"); auto out = context.Output("Y"); auto* mask_data = mask->data(); - auto* input_data = input->data(); + auto* input_data = reinterpret_cast(input->data()); auto input_dim = input->dims(); auto mask_dim = mask->dims(); PADDLE_ENFORCE_EQ( @@ -51,7 +53,8 @@ class MaskedSelectXPUKernel : public framework::OpKernel { framework::DDim out_dim{out_size_cpu}; out->Resize(out_dim); - auto out_data = out->mutable_data(context.GetPlace()); + auto out_data = + reinterpret_cast(out->mutable_data(context.GetPlace())); auto input_shape = phi::vectorize(input_dim); auto mask_shape = phi::vectorize(mask_dim); @@ -69,6 +72,7 @@ class MaskedSelectXPUKernel : public framework::OpKernel { namespace ops = paddle::operators; namespace plat = paddle::platform; REGISTER_OP_XPU_KERNEL(masked_select, ops::MaskedSelectXPUKernel, + ops::MaskedSelectXPUKernel, ops::MaskedSelectXPUKernel, ops::MaskedSelectXPUKernel); #endif diff --git a/paddle/fluid/platform/device/xpu/xpu2_op_list.h b/paddle/fluid/platform/device/xpu/xpu2_op_list.h index 750a389940c65458602de2209b4f5a2ef2b30156..6f4826bd8c39a4b4669a595352138df67c6f5149 100644 --- a/paddle/fluid/platform/device/xpu/xpu2_op_list.h +++ b/paddle/fluid/platform/device/xpu/xpu2_op_list.h @@ -243,6 +243,7 @@ XPUOpMap& get_kl2_ops() { {"masked_select", XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()), pOpKernelType(vartype::INT64, 
XPUPlace()), + pOpKernelType(vartype::FP16, XPUPlace()), pOpKernelType(vartype::FP32, XPUPlace())})}, {"matmul_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"matmul_v2_grad",