// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifdef __NVCC__
#include "cub/cub.cuh"
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif

#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/select_impl.cu.h"
#include "paddle/phi/kernels/nonzero_kernel.h"

namespace phi {
// Converts flat (linearized) element offsets into per-dimension coordinates.
//
// Template parameters:
//   MaskT  - element type of the mask; an element is selected when its mask
//            value converts to true.
//   IndexT - integer type of the flat offsets and of the stride table.
//   OutT   - element type written to the output coordinate buffer.
template <typename MaskT, typename IndexT, typename OutT>
struct IndexFunctor {
  // strides[i] is the product of the last i extents of the input shape, so
  // strides[rank - 1 - d] is the row-major stride of dimension d.
  // Only the first max(rank, 1) entries are initialized.
  IndexT strides[phi::DDim::kMaxRank];
  // Number of dimensions of the input shape.
  int rank;

  // Builds the stride table from the extents in `in_dims`.
  explicit IndexFunctor(const phi::DDim &in_dims) : rank(in_dims.size()) {
    strides[0] = 1;
    for (int i = 1; i < rank; ++i) {
      strides[i] = strides[i - 1] * in_dims[rank - i];
    }
  }

  // For every idx in [0, num) whose mask[idx] is truthy, splits index[idx]
  // into `rank` coordinates and appends them to `out`, outermost dimension
  // first. Selected elements are packed back to back, so `out` must have room
  // for rank * (number of truthy mask entries) values.
  HOSTDEVICE inline void operator()(OutT *out,
                                    const MaskT *mask,
                                    const IndexT *index,
                                    const int num) {
    int out_pos = 0;
    for (int idx = 0; idx < num; ++idx) {
      if (!mask[idx]) {
        continue;
      }
      IndexT remainder = index[idx];
      // Peel off coordinates from the largest stride down to stride 1.
      for (int rank_id = rank - 1; rank_id >= 0; --rank_id) {
        const IndexT stride = strides[rank_id];
        out[out_pos++] = static_cast<OutT>(remainder / stride);
        remainder %= stride;
      }
    }
  }
};

// GPU entry point of the `nonzero` kernel.
//
// Builds an IndexFunctor from the shape of `condition` and forwards the work
// to phi::funcs::SelectKernel.
// NOTE(review): SelectKernel lives in select_impl.cu.h and is not visible
// here; presumably it selects the non-zero elements of `condition`, sizes
// `out`, and calls the functor to write their coordinates - confirm there.
//
// Parameters:
//   dev_ctx   - device context passed through to SelectKernel.
//   condition - input tensor with element type T.
//   out       - output tensor, filled by SelectKernel with int64_t values.
template <typename T, typename Context>
void NonZeroKernel(const Context &dev_ctx,
                   const DenseTensor &condition,
                   DenseTensor *out) {
  using Functor = IndexFunctor<T, int64_t, int64_t>;

  // Default-constructed tensor handed to SelectKernel as its input-data
  // argument. NOTE(review): looks like a placeholder that is unused when the
  // fourth template argument is 0 - confirm against SelectKernel.
  DenseTensor in_data;
  Functor index_functor(condition.dims());
  phi::funcs::SelectKernel<T, T, int64_t, 0, Functor>(
      dev_ctx, condition, in_data, out, index_functor);
}
}  // namespace phi

// Registers phi::NonZeroKernel as the GPU implementation of `nonzero` for
// each listed element type of the input tensor.
PD_REGISTER_KERNEL(nonzero,
                   GPU,
                   ALL_LAYOUT,
                   phi::NonZeroKernel,
                   int64_t,
                   int,
                   int16_t,
                   bool,
                   float,
                   double) {
  // Output 0 is declared as INT64 for every registered input type.
  kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::INT64);
}