From 9cbae54c628837376f8549042cf3658b16fe0507 Mon Sep 17 00:00:00 2001
From: feifei-111
Date: Wed, 31 Aug 2022 16:32:47 +0800
Subject: [PATCH] [phi] Migrate masked_select XPU kernel to phi (#45575)

* test=kunlun

* test=kunlun
---
 .../fluid/operators/masked_select_op_xpu.cc   | 88 -------------------
 .../phi/kernels/xpu/masked_select_kernel.cc   | 86 ++++++++++++++++++
 2 files changed, 86 insertions(+), 88 deletions(-)
 delete mode 100644 paddle/fluid/operators/masked_select_op_xpu.cc
 create mode 100644 paddle/phi/kernels/xpu/masked_select_kernel.cc

diff --git a/paddle/fluid/operators/masked_select_op_xpu.cc b/paddle/fluid/operators/masked_select_op_xpu.cc
deleted file mode 100644
index 7793371bc1..0000000000
--- a/paddle/fluid/operators/masked_select_op_xpu.cc
+++ /dev/null
@@ -1,88 +0,0 @@
-/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-    http://www.apache.org/licenses/LICENSE-2.0
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#ifdef PADDLE_WITH_XPU
-
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/platform/device/device_wrapper.h"
-
-namespace paddle {
-namespace operators {
-
-template <typename DeviceContext, typename T>
-class MaskedSelectXPUKernel : public framework::OpKernel<T> {
-  using XPUType = typename XPUTypeTrait<T>::Type;
-
- public:
-  void Compute(const framework::ExecutionContext& context) const override {
-    auto input = context.Input<framework::Tensor>("X");
-    auto mask = context.Input<framework::Tensor>("Mask");
-    auto out = context.Output<framework::Tensor>("Y");
-    auto* mask_data = mask->data<bool>();
-    auto* input_data = reinterpret_cast<const XPUType*>(input->data<T>());
-    auto input_dim = input->dims();
-    auto mask_dim = mask->dims();
-    PADDLE_ENFORCE_EQ(
-        input_dim,
-        mask_dim,
-        platform::errors::InvalidArgument(
-            "The dim size of input and mask in OP(masked_selected) "
-            "must be equal, but got input dim:(%ld), mask dim: "
-            "(%ld). Please check input "
-            "value.",
-            input_dim,
-            mask_dim));
-    auto& dev_ctx =
-        context.template device_context<DeviceContext>();
-    xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
-    int* out_size = RAII_GUARD.alloc_l3_or_gm<int32_t>(1);
-    int out_size_cpu;
-
-    PADDLE_ENFORCE_XDNN_SUCCESS(
-        xpu::nonzero_count(
-            dev_ctx.x_context(), mask_data, out_size, mask->numel()),
-        "nonzero_count ");
-    memory::Copy(platform::CPUPlace(),
-                 static_cast<void*>(&out_size_cpu),
-                 mask->place(),
-                 static_cast<void*>(out_size),
-                 sizeof(int32_t));
-
-    framework::DDim out_dim{out_size_cpu};
-    out->Resize(out_dim);
-    auto out_data =
-        reinterpret_cast<XPUType*>(out->mutable_data<T>(context.GetPlace()));
-
-    auto input_shape = phi::vectorize<int>(input_dim);
-    auto mask_shape = phi::vectorize<int>(mask_dim);
-
-    PADDLE_ENFORCE_XDNN_SUCCESS(xpu::masked_select(dev_ctx.x_context(),
-                                                   input_data,
-                                                   mask_data,
-                                                   out_data,
-                                                   input_shape,
-                                                   mask_shape,
-                                                   out_size_cpu),
-                                "masked_select");
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
-
-namespace ops = paddle::operators;
-namespace plat = paddle::platform;
-REGISTER_OP_XPU_KERNEL(masked_select,
-                       ops::MaskedSelectXPUKernel<plat::XPUDeviceContext, float>,
-                       ops::MaskedSelectXPUKernel<plat::XPUDeviceContext, plat::float16>,
-                       ops::MaskedSelectXPUKernel<plat::XPUDeviceContext, int>,
-                       ops::MaskedSelectXPUKernel<plat::XPUDeviceContext, int64_t>);
-#endif
diff --git a/paddle/phi/kernels/xpu/masked_select_kernel.cc b/paddle/phi/kernels/xpu/masked_select_kernel.cc
new file mode 100644
index 0000000000..43b8d2cba2
--- /dev/null
+++ b/paddle/phi/kernels/xpu/masked_select_kernel.cc
@@ -0,0 +1,86 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/masked_select_kernel.h"
+
+#include "paddle/phi/backends/xpu/enforce_xpu.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+#include "paddle/fluid/memory/memcpy.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void MaskedSelectKernel(const Context& dev_ctx,
+                        const DenseTensor& x,
+                        const DenseTensor& mask,
+                        DenseTensor* out) {
+  using XPUType = typename XPUTypeTrait<T>::Type;
+  auto input = &x;
+  auto* mask_data = mask.data<bool>();
+  auto* input_data = reinterpret_cast<const XPUType*>(input->data<T>());
+  auto input_dim = input->dims();
+  auto mask_dim = mask.dims();
+  PADDLE_ENFORCE_EQ(input_dim,
+                    mask_dim,
+                    phi::errors::InvalidArgument(
+                        "The dim size of input and mask in OP(masked_selected) "
+                        "must be equal, but got input dim:(%ld), mask dim: "
+                        "(%ld). Please check input "
+                        "value.",
+                        input_dim,
+                        mask_dim));
+  xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
+  int* out_size = RAII_GUARD.alloc_l3_or_gm<int32_t>(1);
+  int out_size_cpu;
+
+  PADDLE_ENFORCE_XDNN_SUCCESS(
+      xpu::nonzero_count(
+          dev_ctx.x_context(), mask_data, out_size, mask.numel()),
+      "nonzero_count ");
+  paddle::memory::Copy(phi::CPUPlace(),
+                       static_cast<void*>(&out_size_cpu),
+                       mask.place(),
+                       static_cast<void*>(out_size),
+                       sizeof(int32_t));
+
+  DDim out_dim{out_size_cpu};
+  out->Resize(out_dim);
+  auto out_data = reinterpret_cast<XPUType*>(dev_ctx.template Alloc<T>(out));
+
+  auto input_shape = vectorize<int>(input_dim);
+  auto mask_shape = vectorize<int>(mask_dim);
+
+  PADDLE_ENFORCE_XDNN_SUCCESS(xpu::masked_select(dev_ctx.x_context(),
+                                                 input_data,
+                                                 mask_data,
+                                                 out_data,
+                                                 input_shape,
+                                                 mask_shape,
+                                                 out_size_cpu),
+                              "masked_select");
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(masked_select,
+                   XPU,
+                   ALL_LAYOUT,
+                   phi::MaskedSelectKernel,
+                   float,
+                   phi::dtype::float16,
+                   int,
+                   int64_t) {
+  kernel->InputAt(1).SetDataType(phi::DataType::BOOL);
+}
--
GitLab
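
Note on the kernel's structure: both the removed fluid kernel and the new phi kernel follow the same three steps — count the true entries of the boolean mask (xpu::nonzero_count), copy that count to the host and resize the 1-D output to it, then gather the selected input elements (xpu::masked_select). As a rough sketch of those semantics only (standalone host C++, not the Paddle or XPU DNN API; MaskedSelectReference is a made-up name used just for this illustration):

// Standalone sketch of the masked_select semantics implemented in the patch
// above (plain host C++, no Paddle/XPU dependencies; the helper name is
// hypothetical and exists only for this example).
#include <cstddef>
#include <iostream>
#include <vector>

template <typename T>
std::vector<T> MaskedSelectReference(const std::vector<T>& input,
                                     const std::vector<bool>& mask) {
  // Counterpart of xpu::nonzero_count: how many elements are selected.
  std::size_t out_size = 0;
  for (bool m : mask) out_size += m ? 1 : 0;

  // Counterpart of out->Resize(...) followed by the gather performed by
  // xpu::masked_select: keep input[i] wherever mask[i] is true, in order.
  std::vector<T> out;
  out.reserve(out_size);
  for (std::size_t i = 0; i < input.size(); ++i) {
    if (mask[i]) out.push_back(input[i]);
  }
  return out;
}

int main() {
  const std::vector<float> x = {1.f, 2.f, 3.f, 4.f};
  const std::vector<bool> mask = {true, false, true, false};
  for (float v : MaskedSelectReference(x, mask)) std::cout << v << ' ';
  std::cout << '\n';  // prints: 1 3
  return 0;
}

The registration in the new file additionally pins the mask input (input index 1) to DataType::BOOL via kernel->InputAt(1).SetDataType(phi::DataType::BOOL), which corresponds to the boolean mask used in this sketch.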