where_index_kernel.cu 2.7 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifdef __NVCC__
#include "cub/cub.cuh"
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif

#include "paddle/phi/kernels/funcs/math_function.h"
24 25
#include "paddle/phi/kernels/funcs/select_impl.cu.h"
#include "paddle/phi/kernels/where_index_kernel.h"
26

27 28 29 30
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"

31
namespace phi {
/**
 * Functor handed to phi::funcs::SelectKernel that converts the flat positions
 * of selected elements into per-dimension coordinates.
 *
 * For an input of shape [d0, d1, ..., d(r-1)] the constructor precomputes
 *   strides[0] = 1,  strides[k] = strides[k-1] * d(r-k)   (k = 1 .. r-1),
 * i.e. strides[k] is the product of the k innermost dimensions. Decoding a
 * flat index with these strides treats it as a row-major (C-order) offset.
 *
 * Template parameters:
 *   MaskT  - mask element type; an element is selected when it is truthy.
 *   IndexT - integer type of the flat indices and of the strides.
 *   OutT   - type of each coordinate written to the output buffer.
 */
template <typename MaskT, typename IndexT, typename OutT>
struct IndexFunctor {
  // strides[k] = product of the k innermost dimensions. Only the first
  // max(rank, 1) entries are initialized; the rest are never read.
  IndexT strides[phi::DDim::kMaxRank];
  // Number of dimensions of the input (coordinates emitted per element).
  int rank;

  explicit IndexFunctor(const phi::DDim &in_dims) {
    rank = in_dims.size();
    // strides[0] is always written, so a rank-0 (scalar) input still leaves a
    // defined value behind; the loop below then does not execute.
    strides[0] = 1;
    for (int i = 1; i < rank; ++i) {
      strides[i] = strides[i - 1] * in_dims[rank - i];
    }
  }

  /**
   * Writes the coordinates of every selected element into `out`.
   *
   * @param out   destination buffer. Receives `rank` values per selected
   *              element, outermost dimension first, packed contiguously.
   *              It must have room for (number of truthy mask entries) * rank
   *              values.
   * @param mask  `num` mask values; mask[idx] decides whether index[idx] is
   *              emitted.
   * @param index `num` flat element indices paired with `mask`, decoded as
   *              row-major offsets (presumably supplied by SelectKernel —
   *              confirm in funcs/select_impl.cu.h).
   * @param num   number of (mask, index) pairs to process.
   */
  HOSTDEVICE inline void operator()(OutT *out,
                                    const MaskT *mask,
                                    const IndexT *index,
                                    const int num) {
    int store_fix = 0;  // next write position in `out`
    for (int idx = 0; idx < num; ++idx) {
      if (!mask[idx]) continue;
      IndexT remainder = index[idx];
      // Peel one coordinate per dimension, from the largest stride (outermost
      // dimension) down to stride 1 (innermost dimension).
      for (int rank_id = rank - 1; rank_id >= 0; --rank_id) {
        out[store_fix++] = static_cast<OutT>(remainder / strides[rank_id]);
        remainder %= strides[rank_id];
      }
    }
  }
};

/**
 * GPU kernel for `where_index`: collects the coordinates of every truthy
 * element of `condition`.
 *
 * Each selected element contributes condition.dims().size() int64_t
 * coordinates, outermost dimension first, produced by IndexFunctor. `out` is
 * not resized here; it is handed directly to SelectKernel, which presumably
 * allocates it as [num_selected, rank] (confirm in funcs/select_impl.cu.h).
 *
 * @param dev_ctx   device context forwarded to SelectKernel.
 * @param condition mask tensor; its element type T is also the mask type.
 * @param out       receives the int64_t coordinates.
 */
template <typename T, typename Context>
void WhereIndexKernel(const Context &dev_ctx,
                      const DenseTensor &condition,
                      DenseTensor *out) {
  // Placeholder data tensor. SelectKernel is instantiated with
  // SelectData == 0, so it presumably never reads this
  // (NOTE(review): confirm against select_impl.cu.h).
  DenseTensor in_data;
  auto dims = condition.dims();
  using Functor = IndexFunctor<T, int64_t, int64_t>;
  Functor index_functor(dims);
  phi::funcs::SelectKernel<T, T, int64_t, 0, Functor>(
      dev_ctx, condition, in_data, out, index_functor);
}
}  // namespace phi

// Registers phi::WhereIndexKernel as the GPU implementation of `where_index`.
// The listed types are presumably the accepted element types of `condition`
// (bound to the kernel's template parameter T) — confirm against the
// PD_REGISTER_KERNEL macro. Whatever T is, the emitted coordinates are
// int64_t, because the kernel instantiates SelectKernel with OutT = int64_t.
PD_REGISTER_KERNEL(where_index,
                   GPU,
                   ALL_LAYOUT,
                   phi::WhereIndexKernel,
                   int64_t,
                   int,
                   int16_t,
                   bool,
                   float,
                   double) {}