// top_k_kernel.cu
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/top_k_kernel.h"

#include <vector>

#include "paddle/fluid/operators/top_k_function_cuda.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/funcs/gather.cu.h"
#include "paddle/phi/kernels/funcs/math_function.h"

namespace phi {

namespace ops = paddle::operators;

// One `case` label of a switch over a thread-block size. Inside the selected
// case the caller's statements (the variadic arguments) can use the matched
// size as the compile-time constant `kBlockDim`; the case ends with `break`.
#define FIXED_BLOCK_DIM_BASE(block_size, ...) \
  case (block_size): {                        \
    constexpr auto kBlockDim = (block_size);  \
    __VA_ARGS__;                              \
  } break

// One `case` label of a switch over a per-thread buffer length. Inside the
// selected case the caller's statements (the variadic arguments) can use the
// matched value as the compile-time constant `maxLength`; the case ends with
// `break`.
#define FIXED_MAXLENGTH_BASE(len_value, ...) \
  case (len_value): {                        \
    constexpr auto maxLength = (len_value);  \
    __VA_ARGS__;                             \
  } break

// Expands to the `case` labels for every thread-block size the top-k dispatch
// supports (1024, 512, 256, 128, 64, 32). Inside the matched case the caller's
// statements see the block size as the compile-time constant `kBlockDim`;
// any other value falls through to the caller's own `default:` label.
#define FIXED_BLOCK_DIM(...)                 \
  FIXED_BLOCK_DIM_BASE(1024, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_BASE(512, ##__VA_ARGS__);  \
  FIXED_BLOCK_DIM_BASE(256, ##__VA_ARGS__);  \
  FIXED_BLOCK_DIM_BASE(128, ##__VA_ARGS__);  \
  FIXED_BLOCK_DIM_BASE(64, ##__VA_ARGS__);   \
  FIXED_BLOCK_DIM_BASE(32, ##__VA_ARGS__)

// Expands to `case` labels 5 down to 1 for a switch over the per-thread buffer
// length (the dispatch in this file switches on ops::getMaxLength(k)). Inside
// the matched case the caller's statements see the length as the compile-time
// constant `maxLength`; values outside 1..5 fall through to the caller's own
// `default:` label. Label order is irrelevant because every case breaks.
#define FIXED_MAXLENGTH(...)              \
  FIXED_MAXLENGTH_BASE(5, ##__VA_ARGS__); \
  FIXED_MAXLENGTH_BASE(4, ##__VA_ARGS__); \
  FIXED_MAXLENGTH_BASE(3, ##__VA_ARGS__); \
  FIXED_MAXLENGTH_BASE(2, ##__VA_ARGS__); \
  FIXED_MAXLENGTH_BASE(1, ##__VA_ARGS__)

// Launches ops::KeMatrixTopK over a row-major matrix with `input_height` rows
// and `input_width` columns, writing `k` values and int64 indices per row.
//
// The thread-block size comes from GetGpuLaunchConfig1D(dev_ctx, input_width)
// and must be one of the sizes enumerated by FIXED_BLOCK_DIM. On non-HIP
// builds the kernel's MaxLength template argument is further selected from
// ops::getMaxLength(k) via FIXED_MAXLENGTH (HIP builds always use 20).
// Unsupported block sizes or lengths raise errors::Fatal.
// The grid is capped at 2048 blocks. The kernel receives both `gridx` and
// `input_height`; presumably it loops over rows beyond the grid (confirm in
// top_k_function_cuda.h).
template <typename T, typename Context>
static void LaunchTopkMatrixKernel(const Context& dev_ctx,
                                   const T* input_data,
                                   T* output_data,
                                   int64_t* indices_data,
                                   int64_t input_height,
                                   int64_t input_width,
                                   int k,
                                   bool largest) {
  // NOTE: pass lds and dim same to input width.
  // NOTE: old matrix implementation of stride is different to eigen.
  const int kMaxHeight = 2048;
  int gridx = input_height < kMaxHeight ? input_height : kMaxHeight;
  paddle::platform::GpuLaunchConfig config =
      paddle::platform::GetGpuLaunchConfig1D(dev_ctx, input_width);
  switch (config.thread_per_block.x) {
#ifdef PADDLE_WITH_HIP
    FIXED_BLOCK_DIM(
        ops::KeMatrixTopK<T, 20, kBlockDim>
        <<<gridx, kBlockDim, 0, dev_ctx.stream()>>>(output_data,
                                                    k,
                                                    indices_data,
                                                    input_data,
                                                    input_width,
                                                    input_width,
                                                    static_cast<int>(k),
                                                    gridx,
                                                    input_height,
                                                    largest));
#else
    FIXED_BLOCK_DIM(switch (ops::getMaxLength(k)) {
      FIXED_MAXLENGTH(
          ops::KeMatrixTopK<T, maxLength, kBlockDim>
          <<<gridx, kBlockDim, 0, dev_ctx.stream()>>>(output_data,
                                                      k,
                                                      indices_data,
                                                      input_data,
                                                      input_width,
                                                      input_width,
                                                      static_cast<int>(k),
                                                      gridx,
                                                      input_height,
                                                      largest));
      default:
        PADDLE_THROW(
            errors::Fatal("the input k has error in the topk cuda kernel."));
    });
#endif
    default:
      PADDLE_THROW(errors::Fatal(
          "the input data shape has error in the topk cuda kernel."));
  }
}

// GPU implementation of top_k.
//
// Selects `k` entries of `x` along `axis`. Values are written to `out` and
// their int64 positions along `axis` are written to `indices`. `largest` is
// forwarded to every backend and chooses between the largest and smallest
// entries.
//
// Strategy, chosen at run time:
//   * If k covers most of a wide row (width >= 128 && k >= 0.75 * width),
//     a full sort via ops::SortTopk is tried first.
//   * On CUDA >= 9.0, a 1-D input of width >= 1024 uses ops::RadixTopK. When
//     `sorted` is true, the k results are then sorted and the indices
//     gathered accordingly.
//   * Otherwise, or if a sort attempt fails, ops::KeMatrixTopK is launched.
// When `axis` is not the last dimension, the input is transposed so that
// `axis` becomes the last dimension, processed as above, and the results are
// transposed back.
//
// `k` is clamped to the extent of `axis`. If `k_scalar` comes from a tensor,
// `out`/`indices` are resized along `axis` before allocation.
// NOTE(review): that resize uses the unclamped k while the kernels use the
// clamped value; presumably callers guarantee k <= dim. Confirm.
template <typename T, typename Context>
void TopkKernel(const Context& dev_ctx,
                const DenseTensor& x,
                const Scalar& k_scalar,
                int axis,
                bool largest,
                bool sorted,
                DenseTensor* out,
                DenseTensor* indices) {
  const auto* input = &x;
  // get the input dims
  const auto& in_dims = input->dims();
  // calculate the real axis
  if (axis < 0) axis += in_dims.size();

  int k = k_scalar.to<int>();
  if (k_scalar.FromTensor()) {
    // k is only known at run time: fix the output extent along `axis`.
    phi::DDim out_dims = out->dims();
    out_dims[axis] = k;
    out->Resize(out_dims);
    indices->Resize(out_dims);
  }

  const auto& out_dims = out->dims();

  T* output_data = dev_ctx.template Alloc<T>(out);
  int64_t* indices_data = dev_ctx.template Alloc<int64_t>(indices);

  // Nothing to select (empty input or k == 0). The outputs are already
  // allocated with their final shape. Returning here also avoids launching
  // kernels with a zero-sized grid, which CUDA rejects as an invalid
  // configuration.
  if (out->numel() == 0) return;

  const T* input_data = input->data<T>();

  if (axis == in_dims.size() - 1) {
    // if get the topK from the last axis
    const int64_t input_height =
        phi::product(phi::slice_ddim(in_dims, 0, in_dims.size() - 1));
    const int64_t input_width = in_dims[in_dims.size() - 1];

    if (k > input_width) {
      k = input_width;
    }

    // The conclusion is drawn from the data through multiple sets of
    // statistics
    if (input_width >= 128 && k >= input_width * 0.75) {
      auto* ctx = reinterpret_cast<const phi::GPUContext*>(&dev_ctx);
      if (ops::SortTopk<T>(*ctx,
                           input,
                           input_width,
                           input_height,
                           k,
                           out,
                           indices,
                           largest)) {
        // Succeeded, return.
        return;
      } else {
        VLOG(4) << "TopKOP: Some errors happened when use cub sorting, use "
                   "default topk kernel.";
      }
    }

#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 9000
    if (input_width >= 1024 && in_dims.size() == 1) {
      // 1. Gather TopK, but without sorting
      constexpr int max_num_threads = 1024;
      if (largest) {
        ops::RadixTopK<T, true>
            <<<input_height, max_num_threads, 0, dev_ctx.stream()>>>(
                input_data,
                k,
                input_height,
                input_width,
                output_data,
                indices_data);
      } else {
        ops::RadixTopK<T, false>
            <<<input_height, max_num_threads, 0, dev_ctx.stream()>>>(
                input_data,
                k,
                input_height,
                input_width,
                output_data,
                indices_data);
      }
      // 2. Sort if needed
      if (sorted) {
        DenseTensor sorted_output;
        DenseTensor sorted_indices;
        DenseTensor gather_indices;
        sorted_output.Resize(out->dims());
        sorted_indices.Resize(indices->dims());
        gather_indices.Resize(indices->dims());
        dev_ctx.template Alloc<T>(&sorted_output);
        dev_ctx.template Alloc<int64_t>(&sorted_indices);
        dev_ctx.template Alloc<int64_t>(&gather_indices);
        auto* ctx = reinterpret_cast<const phi::GPUContext*>(&dev_ctx);
        if (ops::SortTopk<T>(*ctx,
                             out,
                             k,
                             input_height,
                             k,
                             &sorted_output,
                             &sorted_indices,
                             largest)) {
          // Map the positions produced by the sort (relative to the k
          // selected entries) back to positions in the original input.
          funcs::GPUGather<int64_t, int64_t>(
              dev_ctx, *indices, sorted_indices, &gather_indices);
          Copy(dev_ctx, gather_indices, indices->place(), false, indices);
          Copy(dev_ctx, sorted_output, out->place(), false, out);
          return;
        } else {
          VLOG(4) << "TopKOP: Some errors happened when use cub sorting, use "
                     "default topk kernel.";
        }
      } else {
        return;
      }
    }
#endif

    LaunchTopkMatrixKernel<T, Context>(dev_ctx,
                                       input_data,
                                       output_data,
                                       indices_data,
                                       input_height,
                                       input_width,
                                       k,
                                       largest);
  } else {
    // if get topK not from the last axis, will transpose the tensor and get
    // TopK

    // first step, prepare the trans args for the transpose: swap `axis` with
    // the last dimension, keeping every other dimension in place.
    std::vector<int> trans;
    for (int i = 0; i < axis; i++) {
      trans.emplace_back(i);
    }
    trans.emplace_back(in_dims.size() - 1);
    for (int i = axis + 1; i < in_dims.size() - 1; i++) {
      trans.emplace_back(i);
    }
    trans.emplace_back(axis);

    const int ndims = static_cast<int>(trans.size());
    phi::DDim trans_dims(in_dims);
    phi::DDim trans_out_dims(out->dims());
    for (int i = 0; i < ndims; i++) {
      trans_dims[i] = in_dims[trans[i]];
      trans_out_dims[i] = out_dims[trans[i]];
    }
    // second step, transpose the input
    DenseTensor trans_input;
    trans_input.Resize(trans_dims);
    dev_ctx.template Alloc<T>(&trans_input);
    funcs::TransCompute<phi::GPUContext, T>(
        ndims, dev_ctx, *input, &trans_input, trans);
    // third step, calculate the topk
    // allocate the tmp cuda memory for the tmp result
    DenseTensor trans_ind;
    DenseTensor trans_out;
    trans_ind.Resize(trans_out_dims);
    trans_out.Resize(trans_out_dims);
    dev_ctx.template Alloc<int64_t>(&trans_ind);
    dev_ctx.template Alloc<T>(&trans_out);

    const int64_t input_height =
        phi::product(phi::slice_ddim(trans_dims, 0, trans_dims.size() - 1));
    const int64_t input_width = trans_dims[trans_dims.size() - 1];

    if (k > input_width) k = input_width;

    // The conclusion is drawn from the data through multiple sets of
    // statistics
    if (input_width >= 128 && k >= input_width * 0.75) {
      auto* ctx = reinterpret_cast<const phi::GPUContext*>(&dev_ctx);
      if (ops::SortTopk<T>(*ctx,
                           &trans_input,
                           input_width,
                           input_height,
                           k,
                           &trans_out,
                           &trans_ind,
                           largest)) {
        // last step, transpose back the indices and output
        funcs::TransCompute<phi::GPUContext, int64_t>(
            ndims, dev_ctx, trans_ind, indices, trans);
        funcs::TransCompute<phi::GPUContext, T>(
            ndims, dev_ctx, trans_out, out, trans);
        return;
      } else {
        VLOG(4) << "TopKOP: Some errors happened when use cub sorting, use "
                   "default topk kernel.";
      }
    }

    LaunchTopkMatrixKernel<T, Context>(dev_ctx,
                                       trans_input.data<T>(),
                                       trans_out.data<T>(),
                                       trans_ind.data<int64_t>(),
                                       input_height,
                                       input_width,
                                       k,
                                       largest);

    // last step, transpose back the indices and output
    funcs::TransCompute<phi::GPUContext, int64_t>(
        ndims, dev_ctx, trans_ind, indices, trans);
    funcs::TransCompute<phi::GPUContext, T>(
        ndims, dev_ctx, trans_out, out, trans);
  }
}
// The dispatch helpers are local to this translation unit's kernel code;
// undefine all of them so they cannot leak into anything that follows.
#undef FIXED_BLOCK_DIM_BASE
#undef FIXED_BLOCK_DIM
#undef FIXED_MAXLENGTH_BASE
#undef FIXED_MAXLENGTH

}  // namespace phi

// Registers phi::TopkKernel as the GPU kernel for the `top_k` op with
// ALL_LAYOUT, instantiated for float, double, int, int64_t and float16.
PD_REGISTER_KERNEL(top_k,
                   GPU,
                   ALL_LAYOUT,
                   phi::TopkKernel,
                   float,
                   double,
                   int,
                   int64_t,
                   phi::dtype::float16) {}