diff --git a/paddle/fluid/operators/where_index_op.cc b/paddle/fluid/operators/where_index_op.cc
index 2bffeb500ce50e3bc5a3d72a085da826d06e849d..733d0f7af92d727bd3eff31a87e7e88b3d073829 100644
--- a/paddle/fluid/operators/where_index_op.cc
+++ b/paddle/fluid/operators/where_index_op.cc
@@ -12,7 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/fluid/operators/where_index_op.h"
+#include "paddle/fluid/framework/infershape_utils.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/phi/core/infermeta_utils.h"
+#include "paddle/phi/infermeta/unary.h"
 
 namespace paddle {
 namespace operators {
@@ -21,16 +24,6 @@ class WhereIndexOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
 
-  void InferShape(framework::InferShapeContext* ctx) const override {
-    OP_INOUT_CHECK(ctx->HasInput("Condition"), "Input", "Condition", "where");
-    PADDLE_ENFORCE_GE(
-        ctx->GetInputDim("Condition").size(), 1UL,
-        platform::errors::InvalidArgument(
-            "Input(Condition) should have number of dimension at least 1"));
-    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "where");
-    ctx->SetOutputDim("Out", {-1, ctx->GetInputDim("Condition").size()});
-  }
-
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
@@ -53,11 +46,10 @@ class WhereIndexOpMaker : public framework::OpProtoAndCheckerMaker {
 }  // namespace paddle
 
 namespace ops = paddle::operators;
-REGISTER_OP_WITHOUT_GRADIENT(where_index, ops::WhereIndexOp,
-                             ops::WhereIndexOpMaker);
-REGISTER_OP_CPU_KERNEL(where_index, ops::CPUWhereIndexKernel<int64_t>,
-                       ops::CPUWhereIndexKernel<int>,
-                       ops::CPUWhereIndexKernel<int16_t>,
-                       ops::CPUWhereIndexKernel<bool>,
-                       ops::CPUWhereIndexKernel<float>,
-                       ops::CPUWhereIndexKernel<double>);
+DECLARE_INFER_SHAPE_FUNCTOR(where_index, WhereIndexInferShapeFunctor,
+                            PD_INFER_META(phi::WhereIndexInferMeta));
+REGISTER_OPERATOR(
+    where_index, ops::WhereIndexOp, ops::WhereIndexOpMaker,
+    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
+    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
+    WhereIndexInferShapeFunctor);
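
For reference: where_index returns the coordinates of the true elements of Condition, one int64 row per hit, so only the column count (the input rank) is known before the kernel runs; that is why the static output shape is {-1, rank}. Below is a minimal standalone sketch of those semantics in plain C++ with made-up data; it is not Paddle code.

// Standalone sketch (not Paddle code): where_index semantics on a 2x3 input.
#include <cstdio>
#include <vector>

int main() {
  const int dims[2] = {2, 3};
  const bool cond[6] = {false, true, false, false, false, true};
  std::vector<long long> out;  // flattened rows of {row, col} coordinates
  for (int i = 0; i < 6; ++i) {
    if (cond[i]) {
      out.push_back(i / dims[1]);  // row index
      out.push_back(i % dims[1]);  // column index
    }
  }
  // Two true elements -> runtime shape {2, 2}; statically only {-1, 2}.
  for (size_t r = 0; r < out.size(); r += 2) {
    std::printf("[%lld, %lld]\n", out[r], out[r + 1]);  // [0, 1] and [1, 2]
  }
  return 0;
}
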
diff --git a/paddle/fluid/operators/where_index_op.cu b/paddle/fluid/operators/where_index_op.cu
deleted file mode 100644
index c594e478aa0f3cb36b2bb63bdd1dc22e87613bf0..0000000000000000000000000000000000000000
--- a/paddle/fluid/operators/where_index_op.cu
+++ /dev/null
@@ -1,164 +0,0 @@
-/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#ifdef __NVCC__
-#include "cub/cub.cuh"
-#endif
-#ifdef __HIPCC__
-#include <hipcub/hipcub.hpp>
-namespace cub = hipcub;
-#endif
-
-#include <algorithm>
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/operators/where_index_op.h"
-#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
-#include "paddle/fluid/platform/for_range.h"
-#include "paddle/phi/core/ddim.h"
-
-namespace paddle {
-namespace operators {
-
-using CUDADeviceContext = paddle::platform::CUDADeviceContext;
-
-template <typename T>
-__global__ void GetTrueNum(const T *cond_data, const int64_t numel,
-                           int64_t *true_num_array) {
-  const int64_t tid = blockIdx.x * blockDim.x + threadIdx.x;
-
-  for (int64_t idx = tid; idx < numel; idx += gridDim.x * blockDim.x) {
-    true_num_array[idx] =
-        static_cast<int64_t>(static_cast<bool>(cond_data[idx]));
-  }
-}
-
-template <typename T>
-__global__ void SetTrueIndex(int64_t *out_ptr, const T *cond_data,
-                             const int64_t numel, const int64_t *stride_array,
-                             const int64_t rank,
-                             const int64_t *true_num_array) {
-  const int64_t tid = blockIdx.x * blockDim.x + threadIdx.x;
-
-  for (int64_t idx = tid; idx < numel; idx += gridDim.x * blockDim.x) {
-    // true_num_array is calculated by cub::InclusiveSum,
-    // cause the first element of true_num_array is 1,
-    // so we need substract 1 to get true index.
-    const int64_t true_index = true_num_array[idx] - 1;
-    if (static_cast<bool>(cond_data[idx])) {
-      int64_t rank_index = idx;
-      for (int j = 0; j < rank; j++) {
-        const int64_t out_index = rank_index / stride_array[j];
-        out_ptr[true_index * rank + j] = out_index;
-        rank_index -= out_index * stride_array[j];
-      }
-    }
-  }
-}
-
-template <typename T>
-class CUDAWhereIndexKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext &context) const override {
-    auto *condition = context.Input<framework::Tensor>("Condition");
-    auto *out = context.Output<framework::Tensor>("Out");
-    auto &dev_ctx = context.template device_context<CUDADeviceContext>();
-
-    const T *cond_data = condition->data<T>();
-    const int64_t numel = condition->numel();
-    auto dims = condition->dims();
-    const int rank = dims.size();
-
-    auto d_array_mem = memory::Alloc(dev_ctx, (numel + rank) * sizeof(int64_t));
-    auto h_array_mem =
-        memory::Alloc(platform::CPUPlace(), (rank + 1) * sizeof(int64_t));
-
-    // "stride_array" is an array and len(stride_array)==rank,
-    // each element is the stride of each dimension -- the length from i to i+1.
-    int64_t *h_stride_array = reinterpret_cast<int64_t *>(h_array_mem->ptr());
-    int64_t *d_stride_array = reinterpret_cast<int64_t *>(d_array_mem->ptr());
-
-    // "true_num_array" is an array and len(stride_array)==numel,
-    // at the beginning,
-    // "true_num_array" will set 1 if condition[i] == true else 0,
-    // then it will be calculated by cub::InclusiveSum,
-    // so that we can get the true number before i as the out index
-    int64_t *d_true_num_array = d_stride_array + rank;
-
-    // the total_true_num is the total number of condition[i] == true
-    int64_t *h_total_true_num = h_stride_array + rank;
-
-    // alloce cub memory
-    size_t cub_size = 0;
-    cub::DeviceScan::InclusiveSum(nullptr, cub_size, d_true_num_array,
-                                  d_true_num_array, numel, dev_ctx.stream());
-    auto cub_mem = memory::Alloc(dev_ctx, cub_size * sizeof(int64_t));
-    void *cub_data = cub_mem->ptr();
-
-    // set d_true_num_array[i]=1 if cond_data[i]==true else 0
-    const int threads = std::min(numel, static_cast<int64_t>(128));
-    const int64_t need_grids = (numel + threads - 1) / threads;
-    const int grids = std::min(need_grids, static_cast<int64_t>(256));
-    GetTrueNum<T><<<grids, threads, 0, dev_ctx.stream()>>>(cond_data, numel,
-                                                           d_true_num_array);
-
-    // calculate the inclusive prefix sum of "true_num_array"
-    // to get the index of "out" tensor,
-    // and the total number of cond_data[i]==true.
-    // Example:
-    // condition: F T T F F F T T
-    // before:    0 1 1 0 0 0 1 1
-    // after:     0 1 2 2 2 2 3 4
-    // out:       1 2 6 7
-    cub::DeviceScan::InclusiveSum(cub_data, cub_size, d_true_num_array,
-                                  d_true_num_array, numel, dev_ctx.stream());
-
-    // calculate each dimension's stride
-    h_stride_array[rank - 1] = 1;
-    for (int i = rank - 2; i >= 0; i--) {
-      h_stride_array[i] = h_stride_array[i + 1] * dims[i + 1];
-    }
-    memory::Copy(dev_ctx.GetPlace(), d_stride_array, platform::CPUPlace(),
-                 h_stride_array, rank * sizeof(int64_t), dev_ctx.stream());
-
-    // get total ture number and set output size
-    // the last element of cub::InclusiveSum is the total number
-    memory::Copy(platform::CPUPlace(), h_total_true_num, dev_ctx.GetPlace(),
-                 d_true_num_array + numel - 1, sizeof(int64_t),
-                 dev_ctx.stream());
-    dev_ctx.Wait();
-
-    int64_t true_num = *h_total_true_num;
-    out->Resize(phi::make_ddim({static_cast<int64_t>(true_num), rank}));
-    auto out_data = out->mutable_data<int64_t>(context.GetPlace());
-
-    if (true_num == 0) {
-      return;
-    }
-
-    // using true_num_array and stride_array to calculate the output index
-    SetTrueIndex<T><<<grids, threads, 0, dev_ctx.stream()>>>(
-        out_data, cond_data, numel, d_stride_array, rank, d_true_num_array);
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
-
-namespace ops = paddle::operators;
-REGISTER_OP_CUDA_KERNEL(where_index, ops::CUDAWhereIndexKernel<int64_t>,
-                        ops::CUDAWhereIndexKernel<int>,
-                        ops::CUDAWhereIndexKernel<int16_t>,
-                        ops::CUDAWhereIndexKernel<bool>,
-                        ops::CUDAWhereIndexKernel<float>,
-                        ops::CUDAWhereIndexKernel<double>);
diff --git a/paddle/fluid/operators/where_index_op.h b/paddle/fluid/operators/where_index_op.h
deleted file mode 100644
index 193a2386e6bd1eb19d30b7c9e146eb8b77b8e851..0000000000000000000000000000000000000000
--- a/paddle/fluid/operators/where_index_op.h
+++ /dev/null
@@ -1,95 +0,0 @@
-/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#pragma once
-#include <algorithm>
-#include <vector>
-#include "paddle/fluid/framework/eigen.h"
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/platform/for_range.h"
-#include "paddle/phi/kernels/funcs/math_function.h"
-
-namespace paddle {
-namespace operators {
-
-template <typename T>
-struct WhereIndexFunctor {
-  WhereIndexFunctor(const T* true_index, int true_num, const T* stride,
-                    int rank, T* out)
-      : true_index_(true_index),
-        true_num_(true_num),
-        stride_(stride),
-        rank_(rank),
-        out_ptr_(out) {}
-
-  HOSTDEVICE void operator()(size_t idx) const {
-    T index = true_index_[idx];
-    for (int j = 0; j < rank_; j++) {
-      out_ptr_[idx * rank_ + j] = index / stride_[j];
-      index -= out_ptr_[idx * rank_ + j] * stride_[j];
-    }
-  }
-
-  const T* true_index_;
-  int true_num_;
-  const T* stride_;
-  int rank_;
-  T* out_ptr_;
-};
-
-using CPUDeviceContext = paddle::platform::CPUDeviceContext;
-
-template <typename T>
-class CPUWhereIndexKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& context) const override {
-    auto* condition = context.Input<framework::Tensor>("Condition");
-    auto* out = context.Output<framework::Tensor>("Out");
-
-    const T* cond_data = condition->data<T>();
-    auto numel = condition->numel();
-    auto dims = condition->dims();
-    const int rank = dims.size();
-
-    std::vector<int64_t> true_index;
-    for (auto i = 0; i < numel; i++) {
-      if (static_cast<bool>(cond_data[i])) {
-        true_index.push_back(i);
-      }
-    }
-    auto true_num = true_index.size();
-
-    out->Resize(phi::make_ddim({static_cast<int64_t>(true_num), rank}));
-    auto out_ptr = out->mutable_data<int64_t>(context.GetPlace());
-
-    if (true_num == 0) {
-      return;
-    }
-
-    std::vector<int64_t> stride(rank);
-    stride[rank - 1] = 1;
-    for (int i = rank - 2; i >= 0; i--) {
-      stride[i] = stride[i + 1] * dims[i + 1];
-    }
-
-    auto& dev_ctx = context.template device_context<CPUDeviceContext>();
-    WhereIndexFunctor<int64_t> functor(true_index.data(), true_num,
-                                       stride.data(), rank, out_ptr);
-    platform::ForRange<CPUDeviceContext> for_range(dev_ctx, true_num);
-    for_range(functor);
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
diff --git a/paddle/fluid/operators/where_index_op_npu.cc b/paddle/fluid/operators/where_index_op_npu.cc
index 59f598d2ad6a3275158cadf32bd1bf2086a3487a..2f8744c2c0448881901656102f0ee65279f159a2 100644
--- a/paddle/fluid/operators/where_index_op_npu.cc
+++ b/paddle/fluid/operators/where_index_op_npu.cc
@@ -12,8 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/fluid/operators/where_index_op.h"
+#include <vector>
+
+#include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/platform/device/npu/npu_op_runner.h"
+#include "paddle/phi/kernels/funcs/math_function.h"
 
 namespace paddle {
 namespace operators {
diff --git a/paddle/fluid/operators/where_index_op_xpu.cc b/paddle/fluid/operators/where_index_op_xpu.cc
deleted file mode 100644
index 3322eefd887e3d4dce5363cca842305931822a23..0000000000000000000000000000000000000000
--- a/paddle/fluid/operators/where_index_op_xpu.cc
+++ /dev/null
@@ -1,73 +0,0 @@
-/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-    http://www.apache.org/licenses/LICENSE-2.0
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#ifdef PADDLE_WITH_XPU
-
-#include "paddle/fluid/operators/where_index_op.h"
-
-namespace paddle {
-namespace operators {
-
-using Tensor = framework::Tensor;
-
-template <typename T>
-class WhereIndexXPUKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& context) const override {
-    auto* condition = context.Input<Tensor>("Condition");
-    auto* out = context.Output<Tensor>("Out");
-
-    const T* cond_data = condition->data<T>();
-    auto numel = condition->numel();
-    auto dims = condition->dims();
-    const int rank = dims.size();
-
-    auto& dev_ctx =
-        context.template device_context<platform::XPUDeviceContext>();
-    xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
-    int* true_num = RAII_GUARD.alloc_l3_or_gm<int32_t>(1);
-    int true_num_cpu;
-    int ret =
-        xpu::nonzero_count(dev_ctx.x_context(), cond_data, true_num, numel);
-    PADDLE_ENFORCE_EQ(
-        ret, XPU_SUCCESS,
-        platform::errors::External(
-            "XPU nonzero_count kernel return wrong value[%d %s] in WhereIndex",
-            ret, XPUAPIErrorMsg[ret]));
-
-    memory::Copy(platform::CPUPlace(), static_cast<void*>(&true_num_cpu),
-                 context.GetPlace(), static_cast<void*>(true_num),
-                 sizeof(int32_t));
-
-    out->Resize(phi::make_ddim({static_cast<int64_t>(true_num_cpu), rank}));
-    auto out_data = out->mutable_data<int64_t>(context.GetPlace());
-    if (true_num_cpu == 0) {
-      return;
-    }
-
-    auto condition_shape = phi::vectorize<int>(dims);
-    ret = xpu::where(dev_ctx.x_context(), cond_data, out_data, condition_shape,
-                     true_num_cpu);
-    PADDLE_ENFORCE_EQ(ret, XPU_SUCCESS,
-                      platform::errors::External(
-                          "XPU masked_select kernel return wrong value[%d %s]",
-                          ret, XPUAPIErrorMsg[ret]));
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
-
-namespace ops = paddle::operators;
-REGISTER_OP_XPU_KERNEL(where_index, ops::WhereIndexXPUKernel<int>,
-                       ops::WhereIndexXPUKernel<bool>,
-                       ops::WhereIndexXPUKernel<float>);
-#endif
diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc
index f7f8612632bf7ea409e70846fb1dde00d48ad21d..af035004e4bdbcb99f269d8d62bb470422d72f36 100644
--- a/paddle/phi/infermeta/unary.cc
+++ b/paddle/phi/infermeta/unary.cc
@@ -1260,6 +1260,17 @@ void EighInferMeta(const MetaTensor& x,
   out_v->set_dims(input_dim);
 }
 
+void WhereIndexInferMeta(const MetaTensor& condition, MetaTensor* out) {
+  auto rank = condition.dims().size();
+  PADDLE_ENFORCE_GE(
+      rank,
+      1UL,
+      phi::errors::InvalidArgument(
+          "Input(Condition) should have number of dimension at least 1"));
+  out->set_dims(phi::make_ddim({-1, rank}));
+  out->set_dtype(DataType::INT64);
+}
+
 }  // namespace phi
 
 PD_REGISTER_INFER_META_FN(copy_to, phi::CopyToInferMeta);
diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h
index 08d05db1e508657091eb9772fe1d76ae39a06711..bd79bf9d6ed1daeb16eb868a5ae62fb605bf258e 100644
--- a/paddle/phi/infermeta/unary.h
+++ b/paddle/phi/infermeta/unary.h
@@ -182,4 +182,6 @@ void EighInferMeta(const MetaTensor& x,
                    MetaTensor* out_w,
                    MetaTensor* out_v);
 
+void WhereIndexInferMeta(const MetaTensor& condition, MetaTensor* out);
+
 }  // namespace phi
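
The -1 written by WhereIndexInferMeta is a placeholder for a dimension that only becomes known inside the kernel, where out->Resize() replaces it with the actual count of true elements. A trivial standalone sketch of that two-phase shape, with a hypothetical runtime count (not Paddle code):

// Standalone sketch (not Paddle code): the {-1, rank} dynamic-shape convention.
#include <cstdio>

int main() {
  const long long rank = 3;            // rank of Condition, known statically
  long long out_dims[2] = {-1, rank};  // what WhereIndexInferMeta records
  const long long true_num = 5;        // hypothetical count found at run time
  out_dims[0] = true_num;              // what the kernel's Resize establishes
  std::printf("{%lld, %lld}\n", out_dims[0], out_dims[1]);  // {5, 3}
  return 0;
}
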
diff --git a/paddle/phi/kernels/cpu/where_index_kernel.cc b/paddle/phi/kernels/cpu/where_index_kernel.cc
new file mode 100644
index 0000000000000000000000000000000000000000..da6eff74011eaaf5ac5c6d89091743be9a866a5b
--- /dev/null
+++ b/paddle/phi/kernels/cpu/where_index_kernel.cc
@@ -0,0 +1,95 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/where_index_kernel.h"
+
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/for_range.h"
+#include "paddle/phi/kernels/funcs/math_function.h"
+
+namespace phi {
+
+template <typename T>
+struct WhereIndexFunctor {
+  WhereIndexFunctor(
+      const T* true_index, int true_num, const T* stride, int rank, T* out)
+      : true_index_(true_index),
+        true_num_(true_num),
+        stride_(stride),
+        rank_(rank),
+        out_ptr_(out) {}
+
+  HOSTDEVICE void operator()(size_t idx) const {
+    T index = true_index_[idx];
+    for (int j = 0; j < rank_; j++) {
+      out_ptr_[idx * rank_ + j] = index / stride_[j];
+      index -= out_ptr_[idx * rank_ + j] * stride_[j];
+    }
+  }
+
+  const T* true_index_;
+  int true_num_;
+  const T* stride_;
+  int rank_;
+  T* out_ptr_;
+};
+
+template <typename T, typename Context>
+void WhereIndexKernel(const Context& dev_ctx,
+                      const DenseTensor& condition,
+                      DenseTensor* out) {
+  const T* cond_data = condition.data<T>();
+  auto numel = condition.numel();
+  auto dims = condition.dims();
+  const int rank = dims.size();
+
+  std::vector<int64_t> true_index;
+  for (auto i = 0; i < numel; i++) {
+    if (static_cast<bool>(cond_data[i])) {
+      true_index.push_back(i);
+    }
+  }
+  auto true_num = true_index.size();
+  out->Resize(phi::make_ddim({static_cast<int64_t>(true_num), rank}));
+  auto* out_ptr = dev_ctx.template Alloc<int64_t>(out);
+
+  if (true_num == 0) {
+    return;
+  }
+
+  std::vector<int64_t> stride(rank);
+  stride[rank - 1] = 1;
+  for (int i = rank - 2; i >= 0; i--) {
+    stride[i] = stride[i + 1] * dims[i + 1];
+  }
+
+  WhereIndexFunctor<int64_t> functor(
+      true_index.data(), true_num, stride.data(), rank, out_ptr);
+  phi::funcs::ForRange<Context> for_range(dev_ctx, true_num);
+  for_range(functor);
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(where_index,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::WhereIndexKernel,
+                   int64_t,
+                   int,
+                   int16_t,
+                   bool,
+                   float,
+                   double) {}
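
WhereIndexFunctor above turns each flat index back into per-dimension coordinates by repeated division by the row-major strides. A standalone sketch of that decomposition for a 2x3x4 shape (plain C++, not Paddle code):

// Standalone sketch (not Paddle code): flat index -> coordinates via strides.
#include <cstdio>

int main() {
  const int rank = 3;
  const long long dims[3] = {2, 3, 4};
  long long stride[3];
  stride[rank - 1] = 1;
  for (int i = rank - 2; i >= 0; --i) {
    stride[i] = stride[i + 1] * dims[i + 1];  // stride = {12, 4, 1}
  }
  long long index = 17;  // flat index into the 2x3x4 tensor
  long long coord[3];
  for (int j = 0; j < rank; ++j) {
    coord[j] = index / stride[j];
    index -= coord[j] * stride[j];
  }
  // 1*12 + 1*4 + 1*1 == 17, so this prints (1, 1, 1).
  std::printf("(%lld, %lld, %lld)\n", coord[0], coord[1], coord[2]);
  return 0;
}
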
diff --git a/paddle/phi/kernels/gpu/where_index_kernel.cu b/paddle/phi/kernels/gpu/where_index_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..535cb812a20ea90bdb3f07b731af52c2822f0ec2
--- /dev/null
+++ b/paddle/phi/kernels/gpu/where_index_kernel.cu
@@ -0,0 +1,178 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifdef __NVCC__
+#include "cub/cub.cuh"
+#endif
+#ifdef __HIPCC__
+#include <hipcub/hipcub.hpp>
+namespace cub = hipcub;
+#endif
+
+#include "paddle/phi/kernels/where_index_kernel.h"
+
+#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
+#include "paddle/phi/core/ddim.h"
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/math_function.h"
+
+namespace phi {
+
+template <typename T>
+__global__ void GetTrueNum(const T *cond_data,
+                           const int64_t numel,
+                           int64_t *true_num_array) {
+  const int64_t tid = blockIdx.x * blockDim.x + threadIdx.x;
+
+  for (int64_t idx = tid; idx < numel; idx += gridDim.x * blockDim.x) {
+    true_num_array[idx] =
+        static_cast<int64_t>(static_cast<bool>(cond_data[idx]));
+  }
+}
+
+template <typename T>
+__global__ void SetTrueIndex(int64_t *out_ptr,
+                             const T *cond_data,
+                             const int64_t numel,
+                             const int64_t *stride_array,
+                             const int64_t rank,
+                             const int64_t *true_num_array) {
+  const int64_t tid = blockIdx.x * blockDim.x + threadIdx.x;
+
+  for (int64_t idx = tid; idx < numel; idx += gridDim.x * blockDim.x) {
+    // true_num_array is calculated by cub::InclusiveSum,
+    // so its entry for the first true element is 1 and we need to
+    // subtract 1 to get the true index.
+    const int64_t true_index = true_num_array[idx] - 1;
+    if (static_cast<bool>(cond_data[idx])) {
+      int64_t rank_index = idx;
+      for (int j = 0; j < rank; j++) {
+        const int64_t out_index = rank_index / stride_array[j];
+        out_ptr[true_index * rank + j] = out_index;
+        rank_index -= out_index * stride_array[j];
+      }
+    }
+  }
+}
+
+template <typename T, typename Context>
+void WhereIndexKernel(const Context &dev_ctx,
+                      const DenseTensor &condition,
+                      DenseTensor *out) {
+  const T *cond_data = condition.data<T>();
+  const int64_t numel = condition.numel();
+  auto dims = condition.dims();
+  const int rank = dims.size();
+
+  auto d_array_mem =
+      paddle::memory::Alloc(dev_ctx, (numel + rank) * sizeof(int64_t));
+  auto h_array_mem =
+      paddle::memory::Alloc(phi::CPUPlace(), (rank + 1) * sizeof(int64_t));
+
+  // "stride_array" is an array and len(stride_array)==rank,
+  // each element is the stride of each dimension -- the length from i to i+1.
+  int64_t *h_stride_array = reinterpret_cast<int64_t *>(h_array_mem->ptr());
+  int64_t *d_stride_array = reinterpret_cast<int64_t *>(d_array_mem->ptr());
+
+  // "true_num_array" is an array and len(true_num_array)==numel;
+  // at the beginning, true_num_array[i] is set to 1 if condition[i] == true
+  // else 0, then it is scanned by cub::InclusiveSum,
+  // so that we can get the number of true values before i as the out index
+  int64_t *d_true_num_array = d_stride_array + rank;
+
+  // the total_true_num is the total number of condition[i] == true
+  int64_t *h_total_true_num = h_stride_array + rank;
+
+  // allocate cub memory
+  size_t cub_size = 0;
+  cub::DeviceScan::InclusiveSum(nullptr,
+                                cub_size,
+                                d_true_num_array,
+                                d_true_num_array,
+                                numel,
+                                dev_ctx.stream());
+  auto cub_mem = paddle::memory::Alloc(dev_ctx, cub_size * sizeof(int64_t));
+  void *cub_data = cub_mem->ptr();
+
+  // set d_true_num_array[i]=1 if cond_data[i]==true else 0
+  const int threads = std::min(numel, static_cast<int64_t>(128));
+  const int64_t need_grids = (numel + threads - 1) / threads;
+  const int grids = std::min(need_grids, static_cast<int64_t>(256));
+  GetTrueNum<T><<<grids, threads, 0, dev_ctx.stream()>>>(
+      cond_data, numel, d_true_num_array);
+
+  // calculate the inclusive prefix sum of "true_num_array"
+  // to get the index of "out" tensor,
+  // and the total number of cond_data[i]==true.
+  // Example:
+  // condition: F T T F F F T T
+  // before:    0 1 1 0 0 0 1 1
+  // after:     0 1 2 2 2 2 3 4
+  // out:       1 2 6 7
+  cub::DeviceScan::InclusiveSum(cub_data,
+                                cub_size,
+                                d_true_num_array,
+                                d_true_num_array,
+                                numel,
+                                dev_ctx.stream());
+
+  // calculate each dimension's stride
+  h_stride_array[rank - 1] = 1;
+  for (int i = rank - 2; i >= 0; i--) {
+    h_stride_array[i] = h_stride_array[i + 1] * dims[i + 1];
+  }
+  paddle::memory::Copy(dev_ctx.GetPlace(),
+                       d_stride_array,
+                       phi::CPUPlace(),
+                       h_stride_array,
+                       rank * sizeof(int64_t),
+                       dev_ctx.stream());
+
+  // get the total true number and set the output size;
+  // the last element of cub::InclusiveSum is the total number
+  paddle::memory::Copy(phi::CPUPlace(),
+                       h_total_true_num,
+                       dev_ctx.GetPlace(),
+                       d_true_num_array + numel - 1,
+                       sizeof(int64_t),
+                       dev_ctx.stream());
+  dev_ctx.Wait();
+
+  int64_t true_num = *h_total_true_num;
+  out->Resize(phi::make_ddim({static_cast<int64_t>(true_num), rank}));
+  auto *out_data = dev_ctx.template Alloc<int64_t>(out);
+
+  if (true_num == 0) {
+    return;
+  }
+
+  // use true_num_array and stride_array to calculate the output index
+  SetTrueIndex<T><<<grids, threads, 0, dev_ctx.stream()>>>(
+      out_data, cond_data, numel, d_stride_array, rank, d_true_num_array);
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(where_index,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::WhereIndexKernel,
+                   int64_t,
+                   int,
+                   int16_t,
+                   bool,
+                   float,
+                   double) {}
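
The GPU kernel above is a standard stream compaction: mark each element 0/1, take an inclusive prefix sum (cub::DeviceScan::InclusiveSum), then scatter, using scan[i] - 1 as the destination row for every true element. A serial sketch of the same logic on the example from the comments (plain C++, not Paddle code):

// Standalone sketch (not Paddle code): stream compaction via inclusive scan.
#include <cstdio>

int main() {
  const bool cond[8] = {false, true, true, false, false, false, true, true};
  long long scan[8];
  long long running = 0;
  for (int i = 0; i < 8; ++i) {  // 0/1 flags folded into an inclusive scan
    running += cond[i] ? 1 : 0;
    scan[i] = running;           // 0 1 2 2 2 2 3 4
  }
  const long long true_num = scan[7];  // last scan element = total true count
  long long out[8];
  for (int i = 0; i < 8; ++i) {  // scatter: scan[i] - 1 is the output slot
    if (cond[i]) out[scan[i] - 1] = i;
  }
  for (long long k = 0; k < true_num; ++k) {
    std::printf("%lld ", out[k]);  // 1 2 6 7, matching the comment's example
  }
  return 0;
}
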
diff --git a/paddle/phi/kernels/where_index_kernel.h b/paddle/phi/kernels/where_index_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..68b094637c8d55ba424f61f3afea642db073c13f
--- /dev/null
+++ b/paddle/phi/kernels/where_index_kernel.h
@@ -0,0 +1,26 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void WhereIndexKernel(const Context& dev_ctx,
+                      const DenseTensor& condition,
+                      DenseTensor* out);
+
+}  // namespace phi
diff --git a/paddle/phi/kernels/xpu/where_index_kernel.cc b/paddle/phi/kernels/xpu/where_index_kernel.cc
new file mode 100644
index 0000000000000000000000000000000000000000..f6653e57f6eadf3e359a9281ce5737277d83b206
--- /dev/null
+++ b/paddle/phi/kernels/xpu/where_index_kernel.cc
@@ -0,0 +1,72 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/where_index_kernel.h"
+
+#include "paddle/fluid/memory/memcpy.h"
+#include "paddle/fluid/platform/device/xpu/xpu_header.h"
+#include "paddle/phi/backends/xpu/xpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void WhereIndexKernel(const Context& dev_ctx,
+                      const DenseTensor& condition,
+                      DenseTensor* out) {
+  const T* cond_data = condition.data<T>();
+  auto numel = condition.numel();
+  auto dims = condition.dims();
+  const int rank = dims.size();
+
+  xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
+  int* true_num = RAII_GUARD.alloc_l3_or_gm<int32_t>(1);
+  int true_num_cpu;
+  int ret =
+      xpu::nonzero_count(dev_ctx.x_context(), cond_data, true_num, numel);
+  PADDLE_ENFORCE_EQ(
+      ret,
+      XPU_SUCCESS,
+      phi::errors::External(
+          "XPU nonzero_count kernel return wrong value[%d %s] in WhereIndex",
+          ret,
+          XPUAPIErrorMsg[ret]));
+
+  paddle::memory::Copy(phi::CPUPlace(),
+                       static_cast<void*>(&true_num_cpu),
+                       dev_ctx.GetPlace(),
+                       static_cast<void*>(true_num),
+                       sizeof(int32_t));
+
+  out->Resize(phi::make_ddim({static_cast<int64_t>(true_num_cpu), rank}));
+  auto* out_data = dev_ctx.template Alloc<int64_t>(out);
+
+  if (true_num_cpu == 0) {
+    return;
+  }
+
+  auto condition_shape = phi::vectorize<int>(dims);
+  ret = xpu::where(
+      dev_ctx.x_context(), cond_data, out_data, condition_shape, true_num_cpu);
+  PADDLE_ENFORCE_EQ(ret,
+                    XPU_SUCCESS,
+                    phi::errors::External(
+                        "XPU masked_select kernel return wrong value[%d %s]",
+                        ret,
+                        XPUAPIErrorMsg[ret]));
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(
+    where_index, XPU, ALL_LAYOUT, phi::WhereIndexKernel, int, bool, float) {}
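
Unlike the GPU path, the XPU kernel cannot run the scan itself: it uses a count-then-fill pattern, where one device call counts the true elements (xpu::nonzero_count), the scalar count is copied back to the host so the output can be resized, and a second device call (xpu::where) writes the coordinates. A serial sketch of that control flow, with placeholder functions standing in for the device API (plain C++, not Paddle code):

// Standalone sketch (not Paddle code): the count-then-fill pattern.
#include <cstdio>
#include <vector>

// Placeholder for a device-side count such as xpu::nonzero_count.
int CountTrue(const std::vector<bool>& cond) {
  int n = 0;
  for (bool b : cond) n += b ? 1 : 0;
  return n;
}

// Placeholder for a device-side fill such as xpu::where (rank-1 case,
// where each coordinate equals the flat index).
void FillIndices(const std::vector<bool>& cond, std::vector<long long>* out) {
  for (long long i = 0; i < static_cast<long long>(cond.size()); ++i) {
    if (cond[i]) out->push_back(i);
  }
}

int main() {
  const std::vector<bool> cond = {false, true, true, false, true};
  const int true_num = CountTrue(cond);  // 1) count on the "device"
  std::vector<long long> out;
  out.reserve(true_num);                 // 2) size the output from the count
  FillIndices(cond, &out);               // 3) fill the coordinates
  for (long long v : out) std::printf("%lld ", v);  // 1 2 4
  return 0;
}
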