// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/arg_min_max_kernel.h"

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"

#if defined(__NVCC__) || defined(__HIPCC__)

#ifdef __NVCC__
#include "cub/cub.cuh"
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif
#include <limits>
#include <type_traits>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/phi/core/ddim.h"

namespace phi {

namespace {  // NOLINT

// Short alias for cub's (key, value) pair. The arg-min/arg-max kernel below
// pairs a position along the reduced axis (key) with the element value there.
template <typename Key, typename Value>
using KeyValuePair = cub::KeyValuePair<Key, Value>;

}  // namespace

// Emits a single `case` label for a switch over a runtime block size. When the
// switched value equals 2^log2_block_dim, the statement(s) given in
// __VA_ARGS__ run with `kBlockDim` bound to that value as a compile-time
// constant (so it can be used as a template argument).
#define FIXED_BLOCK_DIM_CASE_BASE(log2_block_dim, ...)  \
  case (1 << (log2_block_dim)): {                       \
    constexpr auto kBlockDim = (1 << (log2_block_dim)); \
    __VA_ARGS__;                                        \
    break;                                              \
  }

// Emits the cases for every power-of-two block size from 2^3 = 8 through
// 2^10 = 1024. A switched value outside that set matches no case.
#define FIXED_BLOCK_DIM_CASE(...)               \
  FIXED_BLOCK_DIM_CASE_BASE(3, ##__VA_ARGS__);  \
  FIXED_BLOCK_DIM_CASE_BASE(4, ##__VA_ARGS__);  \
  FIXED_BLOCK_DIM_CASE_BASE(5, ##__VA_ARGS__);  \
  FIXED_BLOCK_DIM_CASE_BASE(6, ##__VA_ARGS__);  \
  FIXED_BLOCK_DIM_CASE_BASE(7, ##__VA_ARGS__);  \
  FIXED_BLOCK_DIM_CASE_BASE(8, ##__VA_ARGS__);  \
  FIXED_BLOCK_DIM_CASE_BASE(9, ##__VA_ARGS__);  \
  FIXED_BLOCK_DIM_CASE_BASE(10, ##__VA_ARGS__);

// Reduces each of `height` rows to the index picked by `reducer` (cub::ArgMax
// or cub::ArgMin) and stores it, cast to IndType, in out[row].
//
// Row `idx` is decomposed as (h, w) = (idx / post_size, idx % post_size); its
// `width` elements live at in[h * width * post_size + k * post_size + w] for
// k in [0, width).
//
// Launch layout: one block handles one row at a time and blocks step over
// rows by gridDim.x, so any grid size >= 1 covers every row. BlockDim sizes
// the cub::BlockReduce temp storage and must equal blockDim.x.
//
// Keys are `int`, so `width` is assumed to fit in int.
template <typename T, typename IndType, class Reducer, size_t BlockDim>
__global__ void ArgCUDAKernel(const int64_t height,     // n * h
                              const int64_t width,      // c
                              const int64_t post_size,  // h
                              const Reducer reducer,
                              const T init,
                              const T* in,
                              IndType* out) {
  typedef cub::BlockReduce<KeyValuePair<int, T>, BlockDim> BlockReduce;
  __shared__ typename BlockReduce::TempStorage temp_storage;

  const int tid = static_cast<int>(threadIdx.x);
  const int64_t block_threads = static_cast<int64_t>(blockDim.x);
  // Only the first `num_valid` threads own at least one element of a row, so
  // only they take part in the block reduction. It is clamped to 1 so that
  // thread 0's placeholder is still reported when width is 0.
  const int64_t owners = width < block_threads ? width : block_threads;
  const int num_valid = owners > 0 ? static_cast<int>(owners) : 1;

  // 64-bit row indices: height = pre * post can exceed INT_MAX.
  for (int64_t idx = blockIdx.x; idx < height; idx += gridDim.x) {
    const int64_t h = idx / post_size;
    const int64_t w = idx % post_size;
    const T* row = in + h * width * post_size + w;

    // Placeholder for threads that own no element. It is excluded from the
    // block reduction via num_valid.
    KeyValuePair<int, T> kv_pair = {-1, init};
    if (tid < width) {
      // Seed from this thread's first real element rather than {-1, init}.
      // cub::ArgMax/ArgMin break ties toward the smaller key, so a -1 seed
      // would be returned whenever every element equals `init` (e.g. an
      // all-zero uint8 row for ArgMax, where lowest() == 0).
      kv_pair = KeyValuePair<int, T>(tid, row[tid * post_size]);
      for (int k = tid + static_cast<int>(blockDim.x); k < width;
           k += blockDim.x) {
        kv_pair = reducer({k, row[k * post_size]}, kv_pair);
      }
    }
    kv_pair = BlockReduce(temp_storage).Reduce(kv_pair, reducer, num_valid);
    if (threadIdx.x == 0) {
      out[idx] = static_cast<IndType>(kv_pair.key);
    }
    // temp_storage is reused by the next row's reduction.
    __syncthreads();
  }
}

// Launches ArgCUDAKernel to compute, for each of the pre * post rows, the
// arg-min/arg-max over the `n` elements of the reduced axis. The results are
// written as IndType into `indices`, which is allocated here. The launch is
// made on dev_ctx's stream.
template <typename T, typename IndType, class Reducer>
void ComputeFullArg(const phi::GPUContext& dev_ctx,
                    const DenseTensor& input,
                    DenseTensor* indices,
                    const int64_t pre,
                    const int64_t post,
                    const int64_t n) {
  auto cu_stream = dev_ctx.stream();
  // Smallest power of two >= col, clamped to [8, 1024] (and to at most 256
  // under HIP). Every value returned has a matching FIXED_BLOCK_DIM_CASE.
  auto ComputeBlockSize = [](int64_t col) {
    auto block_size = 8;
    if (col > 512)
      block_size = 1024;
    else if (col > 256)
      block_size = 512;
    else if (col > 128)
      block_size = 256;
    else if (col > 64)
      block_size = 128;
    else if (col > 32)
      block_size = 64;
    else if (col > 16)
      block_size = 32;
    else if (col > 8)
      block_size = 16;
#ifdef __HIPCC__
    block_size = std::min(block_size, 256);
#endif
    return block_size;
  };

  int64_t max_grid_dimx = dev_ctx.GetCUDAMaxGridDimSize()[0];
  int64_t height = pre * post;
  int64_t width = n;
  // Rows beyond the grid size are handled by the kernel's loop over rows with
  // stride gridDim.x.
  int64_t grid_size = height < max_grid_dimx ? height : max_grid_dimx;

  const T* in_data = input.data<T>();
  IndType* out_data = dev_ctx.template Alloc<IndType>(indices);

  // No rows to reduce: a zero-sized grid is an invalid launch configuration,
  // so stop after allocating the (empty) output.
  if (height == 0) {
    return;
  }

  // Placeholder value handed to the kernel: the lowest value for ArgMax and
  // the largest value for ArgMin. Resolved at compile time from Reducer.
  const T init = std::is_same<Reducer, cub::ArgMax>::value
                     ? std::numeric_limits<T>::lowest()
                     : std::numeric_limits<T>::max();

  switch (ComputeBlockSize(width)) {
    FIXED_BLOCK_DIM_CASE(ArgCUDAKernel<T, IndType, Reducer, kBlockDim>
                         <<<grid_size, kBlockDim, 0, cu_stream>>>(
                             height,
                             width,
                             post,
                             Reducer(),
                             init,
                             in_data,
                             out_data));
  }
}

// Functor passed to VisitDataTypeTiny. apply<IndType>() is instantiated for the
// index dtype chosen at runtime. It splits x's shape around the reduced axis
// into (pre, n, post) and runs ComputeFullArg to write IndType indices to out.
template <typename Context, typename T, class Reducer>
struct VisitDataCudaArgMinMaxFunctor {
  const Context& dev_ctx;
  const DenseTensor& x;
  int64_t axis;
  // NOTE(review): not read in apply(); the output shape is presumably set by
  // the caller / shape inference — confirm.
  bool keepdims;
  bool flatten;
  DenseTensor* out;

  explicit VisitDataCudaArgMinMaxFunctor(const Context& dev_ctx,
                                         const DenseTensor& x,
                                         int64_t axis,
                                         bool keepdims,
                                         bool flatten,
                                         DenseTensor* out)
      : dev_ctx(dev_ctx),
        x(x),
        axis(axis),
        keepdims(keepdims),
        flatten(flatten),
        out(out) {}

  template <typename IndType>
  void apply() const {
    // When flattening, x is viewed as 1-D and reduced along axis 0. Otherwise
    // a negative axis counts from the last dimension.
    const phi::DDim x_dims = flatten ? phi::make_ddim({x.numel()}) : x.dims();
    int new_axis = 0;
    if (!flatten) {
      new_axis = static_cast<int>(axis < 0 ? axis + x_dims.size() : axis);
    }

    // x is treated as a [pre, n, post] array with n the reduced extent.
    const int64_t n = x_dims[new_axis];
    int64_t pre = 1;
    for (int i = 0; i < new_axis; ++i) {
      pre *= x_dims[i];
    }
    int64_t post = 1;
    for (int i = new_axis + 1; i < x_dims.size(); ++i) {
      post *= x_dims[i];
    }

    ComputeFullArg<T, IndType, Reducer>(dev_ctx, x, out, pre, post, n);
  }
};

// Shared implementation of arg_min / arg_max. It resolves the index dtype and
// dispatches VisitDataCudaArgMinMaxFunctor through VisitDataTypeTiny.
//   dtype: index data type as a proto::VarType::Type value. A negative value
//          selects INT64.
template <typename Context, typename T, class Reducer>
void ArgMinMaxOpCUDAKernel(const Context& dev_ctx,
                           const DenseTensor& x,
                           const Scalar& axis,
                           bool keepdims,
                           bool flatten,
                           int dtype,
                           DenseTensor* out) {
  using VarType = paddle::framework::proto::VarType;
  const VarType::Type index_type =
      dtype < 0 ? static_cast<VarType::Type>(VarType::INT64)
                : static_cast<VarType::Type>(dtype);
  paddle::framework::VisitDataTypeTiny(
      index_type,
      VisitDataCudaArgMinMaxFunctor<Context, T, Reducer>(
          dev_ctx, x, axis.to<int64_t>(), keepdims, flatten, out));
}

// GPU arg_min: index of the minimum element of x along `axis` (over all
// elements when `flatten` is true), computed with cub::ArgMin. A negative
// `dtype` selects int64 indices.
template <typename T, typename Context>
void ArgMinKernel(const Context& dev_ctx,
                  const DenseTensor& x,
                  const Scalar& axis,
                  bool keepdims,
                  bool flatten,
                  int dtype,
                  DenseTensor* out) {
  ArgMinMaxOpCUDAKernel<Context, T, cub::ArgMin>(
      dev_ctx, x, axis, keepdims, flatten, dtype, out);
}

// GPU arg_max: index of the maximum element of x along `axis` (over all
// elements when `flatten` is true), computed with cub::ArgMax. A negative
// `dtype` selects int64 indices.
template <typename T, typename Context>
void ArgMaxKernel(const Context& dev_ctx,
                  const DenseTensor& x,
                  const Scalar& axis,
                  bool keepdims,
                  bool flatten,
                  int dtype,
                  DenseTensor* out) {
  ArgMinMaxOpCUDAKernel<Context, T, cub::ArgMax>(
      dev_ctx, x, axis, keepdims, flatten, dtype, out);
}

#endif

}  // namespace phi

// Registers phi::ArgMinKernel as the GPU `arg_min` kernel for the listed
// element types.
PD_REGISTER_KERNEL(arg_min,
                   GPU,
                   ALL_LAYOUT,
                   phi::ArgMinKernel,
                   phi::dtype::float16,
                   phi::dtype::bfloat16,
                   float,
                   double,
                   int32_t,
                   int64_t,
                   int16_t,
                   uint8_t) {}

// Registers phi::ArgMaxKernel as the GPU `arg_max` kernel for the listed
// element types.
PD_REGISTER_KERNEL(arg_max,
                   GPU,
                   ALL_LAYOUT,
                   phi::ArgMaxKernel,
                   phi::dtype::float16,
                   phi::dtype::bfloat16,
                   float,
                   double,
                   int32_t,
                   int64_t,
                   int16_t,
                   uint8_t) {}