From ff818c773792d8b367f5291457e3612acbdf118a Mon Sep 17 00:00:00 2001 From: TTerror Date: Fri, 15 Apr 2022 15:11:57 +0800 Subject: [PATCH] add fp16 for masked_select on kunlun, *test=kunlun (#41215) --- paddle/fluid/operators/masked_select_op_xpu.cc | 8 ++++++-- paddle/fluid/platform/device/xpu/xpu2_op_list.h | 1 + 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/operators/masked_select_op_xpu.cc b/paddle/fluid/operators/masked_select_op_xpu.cc index 00248165a5..3845046825 100644 --- a/paddle/fluid/operators/masked_select_op_xpu.cc +++ b/paddle/fluid/operators/masked_select_op_xpu.cc @@ -19,13 +19,15 @@ namespace operators { template class MaskedSelectXPUKernel : public framework::OpKernel { + using XPUType = typename XPUTypeTrait::Type; + public: void Compute(const framework::ExecutionContext& context) const override { auto input = context.Input("X"); auto mask = context.Input("Mask"); auto out = context.Output("Y"); auto* mask_data = mask->data(); - auto* input_data = input->data(); + auto* input_data = reinterpret_cast(input->data()); auto input_dim = input->dims(); auto mask_dim = mask->dims(); PADDLE_ENFORCE_EQ( @@ -51,7 +53,8 @@ class MaskedSelectXPUKernel : public framework::OpKernel { framework::DDim out_dim{out_size_cpu}; out->Resize(out_dim); - auto out_data = out->mutable_data(context.GetPlace()); + auto out_data = + reinterpret_cast(out->mutable_data(context.GetPlace())); auto input_shape = phi::vectorize(input_dim); auto mask_shape = phi::vectorize(mask_dim); @@ -69,6 +72,7 @@ class MaskedSelectXPUKernel : public framework::OpKernel { namespace ops = paddle::operators; namespace plat = paddle::platform; REGISTER_OP_XPU_KERNEL(masked_select, ops::MaskedSelectXPUKernel, + ops::MaskedSelectXPUKernel, ops::MaskedSelectXPUKernel, ops::MaskedSelectXPUKernel); #endif diff --git a/paddle/fluid/platform/device/xpu/xpu2_op_list.h b/paddle/fluid/platform/device/xpu/xpu2_op_list.h index 750a389940..6f4826bd8c 100644 --- a/paddle/fluid/platform/device/xpu/xpu2_op_list.h +++ b/paddle/fluid/platform/device/xpu/xpu2_op_list.h @@ -243,6 +243,7 @@ XPUOpMap& get_kl2_ops() { {"masked_select", XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()), pOpKernelType(vartype::INT64, XPUPlace()), + pOpKernelType(vartype::FP16, XPUPlace()), pOpKernelType(vartype::FP32, XPUPlace())})}, {"matmul_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"matmul_v2_grad", -- GitLab