top_k_function_cuda.h 35.9 KB
Newer Older
W
wawltor 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <stdio.h>
17

W
wawltor 已提交
18 19
#include <cstdio>
#include <vector>
20
#ifdef __NVCC__
W
wawltor 已提交
21
#include "cub/cub.cuh"
22 23 24 25
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
#endif
26
#include "paddle/fluid/operators/eigen/eigen_function.h"
27
#include "paddle/fluid/operators/kernel_primitives/functor_primitives.h"
W
wawltor 已提交
28
#include "paddle/fluid/operators/top_k_op.h"
29
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
30
#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
31
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
W
wawltor 已提交
32 33
#include "paddle/fluid/platform/float16.h"

34
// Mask selecting all 32 lanes of a warp; passed as the participation mask of
// the *_sync shuffle intrinsics (e.g. in WarpReduce), so every lane of the
// warp must reach those calls.
#define FINAL_MASK 0xffffffff
35 36 37 38 39 40 41 42 43 44
#ifdef __HIPCC__
// Register paddle::platform::float16 with rocPRIM's radix-key codec by
// deriving from the integral codec over its 16-bit storage type.
// NOTE(review): this encodes keys through radix_key_codec_integral over
// uint16_t; confirm it orders negative float16 values correctly for the
// rocPRIM version in use.
namespace rocprim {
namespace detail {
template <>
struct radix_key_codec_base<paddle::platform::float16>
    : radix_key_codec_integral<paddle::platform::float16, uint16_t> {};
}  // namespace detail
}  // namespace rocprim
namespace cub = hipcub;
#else
// Register paddle::platform::float16 with CUB as a primitive floating-point
// type whose bits are stored in a uint16_t, so CUB algorithms (e.g. radix
// sort) can handle float16 keys.
namespace cub {
template <>
struct NumericTraits<paddle::platform::float16>
    : BaseTraits<FLOATING_POINT,
                 true,
                 false,
                 uint16_t,
                 paddle::platform::float16> {};
}  // namespace cub
#endif
W
wawltor 已提交
56 57 58 59

namespace paddle {
namespace operators {

60
// Short alias for phi::DenseTensor used by the host-side helpers in this file.
typedef phi::DenseTensor Tensor;
W
wawltor 已提交
61

62 63
// Splits the extents of `dim` around `axis`:
//   *pre  = dim[0] * ... * dim[axis - 1]        (1 when axis == 0)
//   *n    = dim[axis]
//   *post = dim[axis + 1] * ... * dim[rank - 1] (1 when axis is the last dim)
// so the tensor can be treated as a [pre, n, post] block.
// NOTE(review): the products are accumulated in `int`; if DDim extents are
// 64-bit, very large tensors could overflow here — confirm with callers.
inline void GetDims(
    const phi::DDim& dim, int axis, int* pre, int* n, int* post) {
  *pre = 1;
  *post = 1;
  *n = dim[axis];
  for (int i = 0; i < axis; ++i) {
    (*pre) *= dim[i];
  }
  for (int i = axis + 1; i < dim.size(); ++i) {
    (*post) *= dim[i];
  }
}

W
wawltor 已提交
75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109
// Functor that maps a segment (row) number `idx` to idx * num_cols_, i.e. the
// flat offset at which that row starts when rows hold `num_cols_` elements.
struct SegmentOffsetIter {
  EIGEN_DEVICE_FUNC
  explicit SegmentOffsetIter(int num_cols) : num_cols_(num_cols) {}

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE int operator()(int idx) const {
    const int row_start = num_cols_ * idx;
    return row_start;
  }

  // Number of elements per segment.
  int num_cols_;
};

// Functor that maps a flat element index ix[0] to its column within a row of
// `num_cols_` elements, i.e. ix[0] % num_cols_.
struct ColumnIndexIter {
  explicit ColumnIndexIter(int num_cols) : num_cols_(num_cols) {}

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE int operator()(
      const Eigen::array<int, 1>& ix) const {
    const int flat_index = ix[0];
    return flat_index % num_cols_;
  }

  // Number of columns per row.
  int num_cols_;
};

// Returns the smallest of 32, 64, 128 that is >= `dim`, or 256 when `dim`
// exceeds 128 (values <= 32, including non-positive ones, give 32).
inline static int GetDesiredBlockDim(int dim) {
  if (dim <= 32) return 32;
  if (dim <= 64) return 64;
  if (dim <= 128) return 128;
  return 256;
}

110 111 112 113 114 115 116 117
// Returns k / 5 clamped to the range [1, 5].
// NOTE(review): presumably used to choose the MaxLength (per-thread candidate
// buffer size) for a top-k query — confirm at the call site.
inline static int getMaxLength(int k) {
  const int length = k / 5;
  if (length < 1) {
    return 1;
  }
  // Plain return (no trailing `else if`) so every path visibly returns a
  // value, and no reliance on the CUDA-only host overload of `min`.
  return length < 5 ? length : 5;
}

W
wawltor 已提交
118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164
// Fills the num_rows x num_cols buffer `indices` so that each row holds
// 0, 1, ..., num_cols - 1. Block b handles rows b, b + gridDim.x, ...; within
// a block, thread t handles columns t, t + blockDim.x, ... of those rows.
template <typename T>
__global__ void InitIndex(T* indices, T num_rows, T num_cols) {
  const int thread_col = threadIdx.x;
  const int block_row = blockIdx.x;
  const int64_t row_step = gridDim.x;
  const int64_t col_step = blockDim.x;

  for (int64_t row = block_row; row < num_rows; row += row_step) {
    T* row_ptr = indices + row * num_cols;
    for (int64_t col = thread_col; col < num_cols; col += col_step) {
      row_ptr[col] = col;
    }
  }
}

// A (value, index) candidate for top-k selection. Comparisons against another
// Pair order by value first; on equal values the pair with the larger id is
// the "smaller" one (so the smaller id wins ties when picking the maximum).
template <typename T>
struct Pair {
  __device__ __forceinline__ Pair() {}
  __device__ __forceinline__ Pair(T value, int64_t id) : v(value), id(id) {}

  // Overwrites both fields. `this->` is required: the parameter `id` shadows
  // the member, so a bare `id = id` only self-assigns the parameter and
  // leaves the member untouched (sentinel ids such as -1 were never stored).
  __device__ __forceinline__ void set(T value, int64_t id) {
    v = value;
    this->id = id;
  }

  __device__ __forceinline__ void operator=(const Pair<T>& in) {
    v = in.v;
    id = in.id;
  }

  // Value-only comparisons against a bare scalar.
  __device__ __forceinline__ bool operator<(const T value) const {
    return (v < value);
  }

  __device__ __forceinline__ bool operator>(const T value) const {
    return (v > value);
  }

  // Value-then-id comparisons against another candidate.
  __device__ __forceinline__ bool operator<(const Pair<T>& in) const {
    return (v < in.v) || ((v == in.v) && (id > in.id));
  }

  __device__ __forceinline__ bool operator>(const Pair<T>& in) const {
    return (v > in.v) || ((v == in.v) && (id < in.id));
  }

  T v;         // candidate value
  int64_t id;  // candidate index (-1 is used as a sentinel in this file)
};

// Inserts `p` into the ordered buffer topk[0 .. beam_size - 1] (best first:
// descending by Pair ordering when `largest`, ascending otherwise). Entries
// ranked after `p` shift one slot toward the end, and the previous last entry
// is dropped. The visible callers (GetTopK) only call this after checking
// that `p` beats topk[beam_size - 1].
template <typename T>
__device__ __forceinline__ void AddTo(Pair<T> topk[],
                                      const Pair<T>& p,
                                      int beam_size,
                                      const bool& largest) {
  // Walk from the second-to-last slot toward the front, shifting entries
  // that rank after `p`, and drop `p` into the first slot where it no longer
  // beats its predecessor.
  for (int k = beam_size - 2; k >= 0; k--) {
    if (largest) {
      if (topk[k] < p) {
        topk[k + 1] = topk[k];
      } else {
        topk[k + 1] = p;
        return;
      }
    } else {
      if (topk[k] > p) {
        topk[k + 1] = topk[k];
      } else {
        topk[k + 1] = p;
        return;
      }
    }
  }
  // `p` beats every entry: it becomes the new front.
  topk[0] = p;
}

// Per-thread strided scan: the thread whose first column is `idx` visits
// src[idx], src[idx + BlockSize], ... below `dim` and keeps the best
// `beam_size` of them (largest or smallest values) in the ordered buffer
// `topk`. A value is inserted only if it beats the current last entry
// topk[beam_size - 1].
template <typename T, int BlockSize>
__device__ __forceinline__ void GetTopK(Pair<T> topk[],
                                        const T* src,
                                        int idx,
                                        int dim,
                                        int beam_size,
                                        const bool& largest) {
  while (idx < dim) {
    if (largest) {
      if (topk[beam_size - 1] < src[idx]) {
        Pair<T> tmp(src[idx], idx);
        AddTo<T>(topk, tmp, beam_size, largest);
      }
    } else {
      if (topk[beam_size - 1] > src[idx]) {
        Pair<T> tmp(src[idx], idx);
        AddTo<T>(topk, tmp, beam_size, largest);
      }
    }
    idx += BlockSize;
  }
}

// Bounded variant of the strided per-thread scan: same traversal as the
// GetTopK overload without `max` (src[idx], src[idx + BlockSize], ... below
// `dim`), but a candidate is admitted only if it also ranks strictly after
// `max` in Pair ordering (tmp < max when `largest`, tmp > max otherwise).
// ThreadGetTopK uses it to refill a buffer without re-collecting candidates
// at or before the previous buffer tail `max`.
template <typename T, int BlockSize>
__device__ __forceinline__ void GetTopK(Pair<T> topk[],
                                        const T* src,
                                        int idx,
                                        int dim,
                                        const Pair<T>& max,
                                        int beam_size,
                                        const bool& largest) {
  while (idx < dim) {
    if (largest) {
      if (topk[beam_size - 1] < src[idx]) {
        Pair<T> tmp(src[idx], idx);
        if (tmp < max) {
          AddTo<T>(topk, tmp, beam_size, largest);
        }
      }
    } else {
      if (topk[beam_size - 1] > src[idx]) {
        Pair<T> tmp(src[idx], idx);
        if (tmp > max) {
          AddTo<T>(topk, tmp, beam_size, largest);
        }
      }
    }
    idx += BlockSize;
  }
}

// Refills one thread's private candidate buffer topk[0 .. MaxLength - 1].
// `*beam` is the number of leading entries consumed since the last refill;
// nothing happens when it is 0. With length = min(*beam, beam_size):
//  - first call (*firstStep): scan `src` from scratch into topk[0 .. length);
//  - later calls: move the unconsumed entries to the front, fill the freed
//    tail with +/-INFINITY sentinels carrying id -1, and, unless the thread
//    already ran out of candidates (*is_empty), rescan for values ranking
//    after the previous tail `*max`.
// Afterwards *max is the last buffer entry, *is_empty is raised when that
// entry is a sentinel (id == -1), and *beam is reset to 0.
template <typename T, int MaxLength, int BlockSize>
__device__ __forceinline__ void ThreadGetTopK(Pair<T> topk[],
                                              int* beam,
                                              int beam_size,
                                              const T* src,
                                              bool* firstStep,
                                              bool* is_empty,
                                              Pair<T>* max,
                                              int dim,
                                              const int tid,
                                              bool largest) {
  if (*beam > 0) {
    int length = (*beam) < beam_size ? *beam : beam_size;
    if (*firstStep) {
      *firstStep = false;
      GetTopK<T, BlockSize>(topk, src, tid, dim, length, largest);
    } else {
      for (int k = 0; k < MaxLength; k++) {
        if (k < MaxLength - (*beam)) {
          // Shift the still-unconsumed candidates to the front.
          topk[k] = topk[k + *beam];
        } else {
          // Freed slots become "worst possible" sentinels.
          if (largest) {
            topk[k].set(-static_cast<T>(INFINITY), -1);
          } else {
            topk[k].set(static_cast<T>(INFINITY), -1);
          }
        }
      }
      if (!(*is_empty)) {
        // Refill only the freed tail, skipping anything at or before *max.
        GetTopK<T, BlockSize>(
            topk + MaxLength - *beam, src, tid, dim, *max, length, largest);
      }
    }

    *max = topk[MaxLength - 1];
    // A sentinel at the tail means this thread has no more real candidates.
    if ((*max).id == -1) *is_empty = true;
    *beam = 0;
  }
}

280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306
// Warp-wide arg-max (`largest`) or arg-min reduction of `input` via
// shuffle-down with offsets 16, 8, 4, 2, 1. On equal values the pair with the
// smaller id is kept. The fully reduced pair is in lane 0. All 32 lanes must
// participate (FINAL_MASK).
template <typename T>
__forceinline__ __device__ Pair<T> WarpReduce(Pair<T> input,
                                              const bool& largest) {
#pragma unroll
  for (int offset = 16; offset > 0; offset >>= 1) {
    T tmp_val = platform::CudaShuffleDownSync(FINAL_MASK, input.v, offset);
    // The id is shuffled as a 64-bit value; keep it 64-bit instead of
    // narrowing it into an `int`.
    int64_t tmp_id =
        platform::CudaShuffleDownSync(FINAL_MASK, input.id, offset);
    const bool better_value =
        largest ? (input.v < tmp_val) : (input.v > tmp_val);
    if (better_value || (input.v == tmp_val && input.id > tmp_id)) {
      input.v = tmp_val;
      input.id = tmp_id;
    }
  }
  return input;
}

W
wawltor 已提交
307
// Block-wide extraction loop used by KeMatrixTopK. Each iteration:
//  1. every thread offers its current best candidate topk[0]; a per-warp
//     WarpReduce followed by a WarpReduce in warp 0 over the per-warp winners
//     leaves the block's best pair in shared_max[0];
//  2. thread 0 writes that pair to **topVal / **topIds and advances both
//     output pointers;
//  3. the thread that produced the winner (id % BlockSize == tid, matching
//     the BlockSize-strided scan in GetTopK) consumes it: ++*beam and, while
//     candidates remain, topk[0] = topk[*beam];
//  4. the loop ends after *k outputs, or once the winner's owner has used up
//     all MaxLength buffered candidates (its buffer needs a refill).
// shared_max must hold BlockSize / 32 entries; assumes blockDim.x == BlockSize.
template <typename T, int MaxLength, int BlockSize>
__device__ __forceinline__ void BlockReduce(Pair<T> shared_max[],
                                            Pair<T> topk[],
                                            T** topVal,
                                            int64_t** topIds,
                                            int* beam,
                                            int* k,
                                            const int tid,
                                            const int wid,
                                            const int lane,
                                            const bool& largest) {
  while (true) {
    __syncthreads();
    Pair<T> input_now = topk[0];
    input_now = WarpReduce(input_now, largest);

    if (lane == 0) {
      shared_max[wid] = input_now;
    }
    __syncthreads();
    // Only the first BlockSize / 32 threads carry a per-warp winner; the rest
    // contribute a sentinel that can never win.
    if (largest) {
      input_now = (tid < BlockSize / 32)
                      ? shared_max[lane]
                      : Pair<T>(-static_cast<T>(INFINITY), -1);
    } else {
      input_now = (tid < BlockSize / 32)
                      ? shared_max[lane]
                      : Pair<T>(static_cast<T>(INFINITY), -1);
    }
    if (wid == 0) {
      input_now = WarpReduce(input_now, largest);
      if (lane == 0) shared_max[0] = input_now;
    }
    __syncthreads();

    // Thread 0 is lane 0 of warp 0 and therefore holds the block winner.
    if (tid == 0) {
      **topVal = input_now.v;
      **topIds = input_now.id;
      (*topVal)++;
      (*topIds)++;
    }
    int tid_max = shared_max[0].id % BlockSize;
    if (tid == tid_max) {
      (*beam)++;
      if (*beam < MaxLength) {
        topk[0] = topk[*beam];
      }
    }
    if (--(*k) == 0) break;

    // NOTE(review): only the warp containing tid_max evaluates this test, so
    // only that warp leaves the loop when the owner's buffer is exhausted; the
    // other warps start another iteration while it returns to ThreadGetTopK
    // and re-enters BlockReduce. This relies on both paths reaching the same
    // number of __syncthreads() — confirm before restructuring.
    unsigned mask = 0u;
    CREATE_SHFL_MASK(mask, true);
    if (tid_max / 32 == wid) {
      if (platform::CudaShuffleSync(mask, *beam, tid_max % 32, 32) == MaxLength)
        break;
    }
  }
}

/**
 * Computes the top-k of `num` rows, one row at a time per block: block `bid`
 * handles rows bid, bid + grid_dim, ... < num.
 * Within a block:
 * 1. every thread collects its best MaxLength candidates (ThreadGetTopK);
 * 2. the candidates are merged by a block reduction that emits one value
 *    per round (BlockReduce);
 * 3. step 2 repeats until some thread has used up its buffered candidates;
 * 4. step 1 then refills the buffers, until all k values are produced.
 *
 * Layout: row i is read from src + i * lds (dim elements), values are
 * written to output + i * output_stride and indices to indices + i * k.
 * Expects 1-D blocks with blockDim.x == BlockSize (presumably a multiple of
 * 32, since shared_max has BlockSize / 32 warp slots).
 */
template <typename T, int MaxLength, int BlockSize>
__global__ void KeMatrixTopK(T* output,
                             int output_stride,
                             int64_t* indices,
                             const T* src,
                             int lds,
                             int dim,
                             int k,
                             int grid_dim,
                             int num,
                             bool largest = true) {
  const int tid = threadIdx.x;
  const int wid = tid / 32;
  const int lane = tid % 32;
  const int bid = blockIdx.x;
  for (int i = bid; i < num; i += grid_dim) {
    int top_num = k;
    __shared__ Pair<T> shared_max[BlockSize / 32];
    // Row offsets are formed in 64-bit arithmetic: i * stride can exceed
    // INT_MAX for large inputs even though each factor fits in an int.
    const int64_t row = static_cast<int64_t>(i);
    T* out = output + row * output_stride;
    int64_t* inds = indices + row * k;
    const T* row_src = src + row * lds;
    Pair<T> topk[MaxLength];
    int beam = MaxLength;
    Pair<T> max;
    bool is_empty = false;
    bool firststep = true;

    // Start every buffer slot as a sentinel that any real value beats.
    for (int j = 0; j < MaxLength; j++) {
      if (largest) {
        topk[j].set(-static_cast<T>(INFINITY), -1);
      } else {
        topk[j].set(static_cast<T>(INFINITY), -1);
      }
    }
    while (top_num) {
      ThreadGetTopK<T, MaxLength, BlockSize>(topk,
                                             &beam,
                                             k,
                                             row_src,
                                             &firststep,
                                             &is_empty,
                                             &max,
                                             dim,
                                             tid,
                                             largest);
      BlockReduce<T, MaxLength, BlockSize>(shared_max,
                                           topk,
                                           &out,
                                           &inds,
                                           &beam,
                                           &top_num,
                                           tid,
                                           wid,
                                           lane,
                                           largest);
    }
  }
}

433
/*---------------------------Radix TopK Begin------------------*/
434
#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 9000
435 436 437 438 439 440 441 442 443 444 445
constexpr int RADIX_BITS = 2;  // digits are base-(2 ^ RADIX_BITS)
// Derived from RADIX_BITS so the three constants cannot drift apart.
constexpr int RADIX_SIZE = 1 << RADIX_BITS;  // number of distinct digits
constexpr int RADIX_MASK = (RADIX_SIZE - 1);  // selects one digit
// RadixSearch walks digit positions from sizeof(T) * 8 - RADIX_BITS down to
// 0 in steps of RADIX_BITS; the digit width must divide every key width
// (16, 32 and 64 bits) for that walk to cover all bits.
static_assert(16 % RADIX_BITS == 0, "RADIX_BITS must divide 16");

/*---------------------------Helper Structs------------------*/
// Bit-field helpers over unsigned radix keys, implemented with the PTX
// bfe (bit-field extract) and bfi (bit-field insert) instructions.
template <typename T>
struct Bitfield {};

template <>
struct Bitfield<unsigned int> {
  // Returns the `len` bits of `val` starting at bit `pos`.
  static __device__ __forceinline__ unsigned int GetBitfield(unsigned int val,
                                                             int pos,
                                                             int len) {
    unsigned int ret;
    asm("bfe.u32 %0, %1, %2, %3;" : "=r"(ret) : "r"(val), "r"(pos), "r"(len));
    return ret;
  }

  // Returns `val` with its `len` bits at `pos` replaced by the low `len`
  // bits of `to_insert`.
  static __device__ __forceinline__ unsigned int SetBitfield(
      unsigned int val, unsigned int to_insert, int pos, int len) {
    unsigned int ret;
    asm("bfi.b32 %0, %1, %2, %3, %4;"
        : "=r"(ret)
        : "r"(to_insert), "r"(val), "r"(pos), "r"(len));
    return ret;
  }
};

template <>
struct Bitfield<uint64_t> {
  // 64-bit counterpart of Bitfield<unsigned int>::GetBitfield.
  static __device__ __forceinline__ uint64_t GetBitfield(uint64_t val,
                                                         int pos,
                                                         int len) {
    uint64_t ret;
    asm("bfe.u64 %0, %1, %2, %3;" : "=l"(ret) : "l"(val), "r"(pos), "r"(len));
    return ret;
  }

  // 64-bit counterpart of Bitfield<unsigned int>::SetBitfield.
  static __device__ __forceinline__ uint64_t SetBitfield(uint64_t val,
                                                         uint64_t to_insert,
                                                         int pos,
                                                         int len) {
    uint64_t ret;
    asm("bfi.b64 %0, %1, %2, %3, %4;"
        : "=l"(ret)
        : "l"(to_insert), "l"(val), "r"(pos), "r"(len));
    return ret;
  }
};

// Maps each supported element type to an unsigned radix key type.
// Convert(v) produces a key whose unsigned ordering matches the ordering of
// the values (NaN maps to an all-ones key, above every other value);
// Deconvert inverts Convert for non-NaN keys.
template <typename T>
struct RadixTypeConfig {};

template <>
struct RadixTypeConfig<float> {
  typedef uint32_t RadixType;

  static inline __device__ RadixType Convert(float v) {
    RadixType x = __float_as_int(v);
    // Negative: flip all bits (reverses their order); non-negative: flip
    // only the sign bit (moves them above the negatives).
    RadixType mask = (x & 0x80000000) ? 0xffffffff : 0x80000000;

    return (v == v) ? (x ^ mask) : 0xffffffff;
  }

  static inline __device__ float Deconvert(RadixType v) {
    RadixType mask = (v & 0x80000000) ? 0x80000000 : 0xffffffff;

    return __int_as_float(v ^ mask);
  }
};

template <>
struct RadixTypeConfig<double> {
  typedef uint64_t RadixType;

  static inline __device__ RadixType Convert(double v) {
    RadixType x = __double_as_longlong(v);
    // Branch-free form of the float mask: all ones when the sign bit is set,
    // otherwise just the sign bit.
    RadixType mask = -((x >> 63)) | 0x8000000000000000;
    return (v == v) ? (x ^ mask) : 0xffffffffffffffff;
  }

  static inline __device__ double Deconvert(RadixType v) {
    RadixType mask = ((v >> 63) - 1) | 0x8000000000000000;
    return __longlong_as_double(v ^ mask);
  }
};

template <>
struct RadixTypeConfig<int32_t> {
  typedef uint32_t RadixType;

  // Bias by 2^31 so INT32_MIN maps to 0 and INT32_MAX to 0xffffffff.
  static inline __device__ RadixType Convert(int32_t v) {
    static_assert(sizeof(int) == 4, "");
    return 2147483648u + v;
  }

  static inline __device__ int32_t Deconvert(RadixType v) {
    return v - 2147483648u;
  }
};

template <>
struct RadixTypeConfig<int64_t> {
  typedef uint64_t RadixType;

  // Bias by 2^63 so INT64_MIN maps to 0 and INT64_MAX to all ones.
  static inline __device__ RadixType Convert(int64_t v) {
    static_assert(sizeof(int64_t) == 8, "");
    return 9223372036854775808ull + v;
  }

  static inline __device__ int64_t Deconvert(RadixType v) {
    return v - 9223372036854775808ull;
  }
};

template <>
struct RadixTypeConfig<platform::float16> {
  // Only the low 16 bits of the key are used.
  typedef uint32_t RadixType;

  static inline __device__ RadixType Convert(platform::float16 v) {
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
    half v_h = v.to_half();
    RadixType x = __half_as_ushort(v_h);
    RadixType mask = (x & 0x00008000) ? 0x0000ffff : 0x00008000;
    return (v_h == v_h) ? (x ^ mask) : 0xffff;
#else
    // Half arithmetic is unavailable on this architecture.
    assert(false);
    return 0u;
#endif
  }

  static inline __device__ platform::float16 Deconvert(RadixType v) {
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
    RadixType mask = (v & 0x00008000) ? 0x00008000 : 0x0000ffff;
    return static_cast<platform::float16>(__ushort_as_half(v ^ mask));
#else
    assert(false);
    return static_cast<platform::float16>(0);
#endif
  }
};

/*---------------------------Helper Functions------------------*/
// Returns this thread's lane number within its warp, read from the PTX
// %laneid special register.
__device__ __forceinline__ int GetLaneId() {
  int lane;
  asm("mov.s32 %0, %%laneid;" : "=r"(lane));
  return lane;
}

// Returns the PTX %lanemask_le special register: bits set for this lane and
// every lower-numbered lane of the warp.
__device__ __forceinline__ unsigned GetLaneMaskLe() {
  unsigned int lanes_le;
  asm("mov.u32 %0, %%lanemask_le;" : "=r"(lanes_le));
  return lanes_le;
}

// Block-wide inclusive scan of the boolean `in` flags in threadIdx.x order.
// Within a warp, the rank comes from __ballot_sync + __popc over
// %lanemask_le. Thread 0 then combines the per-warp totals serially in
// `shared_mem` with `func`, and each warp adds the combined total of the
// preceding warps. With an additive `func` (as used by RadixTopK), *out is
// the number of threads at or before this one whose flag is set.
// Requirements: every thread of the block calls this (it contains
// __syncthreads()); `shared_mem` has one slot per warp. With KillDependency
// an extra barrier lets `shared_mem` be reused right after returning.
template <typename T, bool KillDependency, class Function>
__device__ void InclusiveBinaryPrefixScan(T* shared_mem,
                                          bool in,
                                          T* out,
                                          Function func) {
  T vote = __ballot_sync(__activemask(), in);
  T index = __popc(GetLaneMaskLe() & vote);
  T carry = __popc(vote);

  int warp = threadIdx.x / 32;

  if (GetLaneId() == 0) {
    shared_mem[warp] = carry;
  }

  __syncthreads();

  if (threadIdx.x == 0) {
    // Round up so a partial last warp (blockDim.x not a multiple of 32) is
    // included; its slot is the one ExclusiveBinaryPrefixScan reads as the
    // block total.
    const int num_warps = (static_cast<int>(blockDim.x) + 31) / 32;
    int current = 0;
    for (int i = 0; i < num_warps; ++i) {
      T v = shared_mem[i];
      shared_mem[i] = func(shared_mem[i], current);
      current = func(current, v);
    }
  }

  __syncthreads();

  // shared_mem[w] now holds the combined total of warps 0..w.
  if (warp >= 1) {
    index = func(index, shared_mem[warp - 1]);
  }

  *out = index;

  if (KillDependency) {
    __syncthreads();
  }
}

// Exclusive version of InclusiveBinaryPrefixScan. For an additive `func`,
// *out is the number of set flags strictly before this thread, and *carry is
// the combined total of the whole block (the last warp's slot in
// `shared_mem`). Must be called by every thread of the block.
template <typename T, bool KillDependency, class Function>
__device__ void ExclusiveBinaryPrefixScan(
    T* shared_mem, bool in, T* out, T* carry, Function func) {
  InclusiveBinaryPrefixScan<T, false, Function>(shared_mem, in, out, func);

  // Inclusive -> exclusive: drop this thread's own contribution.
  *out -= (T)in;

  *carry = shared_mem[(blockDim.x + 31) / 32 - 1];

  // Keep the read of shared_mem above from racing with the next call's
  // writes.
  if (KillDependency) {
    __syncthreads();
  }
}

// Returns an element v of input[0 .. slice_size) whose radix key satisfies
// (Convert(v) & desired_mask) == desired. shared_mem[0] serves as a found
// flag and shared_mem[1] as the value, so it must hold at least two T.
// If several elements match, concurrent writes make the winner unspecified;
// RadixSearch only calls this when exactly one element matches.
// Every thread of the block must call it: the loop bound is rounded up to a
// multiple of blockDim.x so all threads execute the same __syncthreads().
// Triggers a device assert and returns 0 if nothing matches.
template <typename T, typename RadixType>
__device__ T FindPattern(const T* input,
                         T* shared_mem,
                         int slice_size,
                         RadixType desired,
                         RadixType desired_mask) {
  if (threadIdx.x < 2) {
    shared_mem[threadIdx.x] = static_cast<T>(0);
  }
  __syncthreads();

  int block_dim = static_cast<int>(blockDim.x);
  int loop = ((slice_size + block_dim - 1) / block_dim * block_dim);
  for (int i = threadIdx.x; i < loop; i += blockDim.x) {
    bool valid = (i < slice_size);
    T v = valid ? input[i] : static_cast<T>(0);

    if (valid && ((RadixTypeConfig<T>::Convert(v) & desired_mask) == desired)) {
      shared_mem[0] = static_cast<T>(1);
      shared_mem[1] = v;
    }

    __syncthreads();

    T found = shared_mem[0];
    T val = shared_mem[1];

    // All threads must read the flag before any thread can overwrite it in a
    // later iteration.
    __syncthreads();

    if (found != static_cast<T>(0)) {
      return val;
    }
  }

  assert(false);
  return static_cast<T>(0);
}

// Histogram step of the radix select. Over input[0 .. slice_size), counts the
// elements whose key matches `desired` on the bits of `desired_mask`,
// bucketed by the RadixBits-wide digit at bit `radix_digit_pos`. Per-warp
// counts come from ballots and are summed across warps with atomics in
// shared_mem[0 .. RadixSize). On return, every thread's counts[] holds the
// block-wide totals. Must be called by every thread of the block.
template <typename T, typename RadixType, int RadixSize, int RadixBits>
__device__ void RadixCountUsingMask(const T* input,
                                    int counts[RadixSize],
                                    int* shared_mem,
                                    RadixType desired,
                                    RadixType desired_mask,
                                    int radix_digit_pos,
                                    int slice_size) {
#pragma unroll
  for (int i = 0; i < RadixSize; ++i) {
    counts[i] = 0;
  }

  if (threadIdx.x < RadixSize) {
    shared_mem[threadIdx.x] = 0;
  }
  __syncthreads();

  for (int i = threadIdx.x; i < slice_size; i += blockDim.x) {
    RadixType val = RadixTypeConfig<T>::Convert(input[i]);

    bool has_val = ((val & desired_mask) == desired);
    RadixType digit_in_radix =
        Bitfield<RadixType>::GetBitfield(val, radix_digit_pos, RadixBits);

#pragma unroll
    for (uint32_t j = 0; j < RadixSize; ++j) {
      bool vote = has_val && (digit_in_radix == j);
      counts[j] += __popc(__ballot_sync(__activemask(), vote));
    }
  }

  // Lane 0 has the smallest i in its warp, so it takes part in every loop
  // iteration any lane of its warp does and has seen all of its warp's
  // ballots.
  if (GetLaneId() == 0) {
#pragma unroll
    for (uint32_t i = 0; i < RadixSize; ++i) {
      platform::CudaAtomicAdd(&shared_mem[i], counts[i]);
    }
  }

  __syncthreads();

#pragma unroll
  for (uint32_t i = 0; i < RadixSize; ++i) {
    counts[i] = shared_mem[i];
  }

  // Protect shared_mem from being overwritten by the next call before every
  // thread has read it.
  __syncthreads();
}

// Finds the k-th largest (Largest) or k-th smallest value of
// input[0 .. slice_size) by radix select. Starting at the most significant
// digit, it fixes RADIX_BITS key bits per step: `desired` / `desired_mask`
// accumulate the fixed key prefix, and `k_left` is the rank still wanted
// among the elements sharing that prefix. If at some step exactly one element
// carries the chosen digit and it is the one wanted, its value is fetched
// directly with FindPattern. Otherwise the fully determined key is
// Deconvert-ed at the end. Must be called by every thread of the block;
// `shared_mem` must hold RADIX_SIZE ints and two T.
template <typename T, typename RadixType, bool Largest>
__device__ void RadixSearch(
    const T* input, int k, int slice_size, int* shared_mem, T* kth_value) {
  int counts[RADIX_SIZE];

  RadixType desired = 0;
  RadixType desired_mask = 0;

  int k_left = k;

#pragma unroll
  for (int digit_pos = sizeof(T) * 8 - RADIX_BITS; digit_pos >= 0;
       digit_pos -= RADIX_BITS) {
    RadixCountUsingMask<T, RadixType, RADIX_SIZE, RADIX_BITS>(input,
                                                              counts,
                                                              shared_mem,
                                                              desired,
                                                              desired_mask,
                                                              digit_pos,
                                                              slice_size);

    // Digit `i` holds exactly the element we want: fix the digit and look it
    // up directly.
    auto found_unique = [&](int i, int count) -> bool {
      if (count == 1 && k_left == 1) {
        desired =
            Bitfield<RadixType>::SetBitfield(desired, i, digit_pos, RADIX_BITS);
        desired_mask = Bitfield<RadixType>::SetBitfield(
            desired_mask, RADIX_MASK, digit_pos, RADIX_BITS);

        *kth_value = FindPattern<T, RadixType>(input,
                                               reinterpret_cast<T*>(shared_mem),
                                               slice_size,
                                               desired,
                                               desired_mask);
        return true;
      }
      return false;
    };
    // Digit `i` contains the wanted element: fix the digit and descend.
    // Otherwise skip past its `count` elements.
    auto found_non_unique = [&](int i, int count) -> bool {
      if (count >= k_left) {
        desired =
            Bitfield<RadixType>::SetBitfield(desired, i, digit_pos, RADIX_BITS);
        desired_mask = Bitfield<RadixType>::SetBitfield(
            desired_mask, RADIX_MASK, digit_pos, RADIX_BITS);

        return true;
      }
      k_left -= count;
      return false;
    };

    if (Largest) {
// Descending order
#pragma unroll
      for (int i = RADIX_SIZE - 1; i >= 0; --i) {
        int count = counts[i];
        if (found_unique(i, count)) {
          return;
        }
        if (found_non_unique(i, count)) {
          break;
        }
      }
    } else {
// Ascending order
#pragma unroll
      for (int i = 0; i < RADIX_SIZE; ++i) {
        int count = counts[i];
        if (found_unique(i, count)) {
          return;
        }
        if (found_non_unique(i, count)) {
          break;
        }
      }
    }
  }

  *kth_value = RadixTypeConfig<T>::Deconvert(desired);
}

// Unsorted top-k via radix select. Block b processes slice b: input
// elements [b * slice_size, (b + 1) * slice_size). It writes k values and
// their in-slice indices to output / indices at offset b * k. The values
// strictly better than the k-th value come first (in index order), followed
// by copies equal to the k-th value until k entries are written.
// Launch one block per slice (extra blocks return immediately) with 1-D
// blocks of at most 1024 threads (shared_mem has one slot per warp, 32
// slots).
template <typename T, bool Largest>
__global__ void RadixTopK(const T* input,
                          int k,
                          int slice_num,
                          int slice_size,
                          T* output,
                          int64_t* indices) {
  namespace kps = paddle::operators::kernel_primitives;
  __shared__ int shared_mem[32];

  // Select this block's slice. The early return depends only on blockIdx.x,
  // so it is taken by the whole block and no __syncthreads() is skipped by
  // part of it.
  const int slice_id = blockIdx.x;
  if (slice_id >= slice_num) {
    return;
  }
  const int64_t slice = static_cast<int64_t>(slice_id);
  const T* slice_input = input + slice * slice_size;
  T* slice_output = output + slice * k;
  int64_t* slice_indices = indices + slice * k;

  // 1. Find the k-th value
  T kth_value = static_cast<T>(0);
  RadixSearch<T, typename RadixTypeConfig<T>::RadixType, Largest>(
      slice_input, k, slice_size, shared_mem, &kth_value);
  const auto converted_kth_value = RadixTypeConfig<T>::Convert(kth_value);

  // 2. Select the value strictly less/greater than kth_value and their indices
  int block_dim = static_cast<int>(blockDim.x);
  // Rounded up so every thread runs the same number of scans (barriers).
  int loop = ((slice_size + block_dim - 1) / block_dim * block_dim);
  int write_start = 0;

  for (int i = threadIdx.x; i < loop; i += blockDim.x) {
    bool valid = i < slice_size;
    T v = valid ? slice_input[i] : static_cast<T>(0);
    const auto converted_v = RadixTypeConfig<T>::Convert(v);
    bool is_top_k;
    if (Largest) {
      is_top_k = valid && (converted_v > converted_kth_value);
    } else {
      is_top_k = valid && (converted_v < converted_kth_value);
    }

    int index;
    int carry;
    ExclusiveBinaryPrefixScan<int, true, kps::AddFunctor<int>>(
        shared_mem, is_top_k, &index, &carry, kps::AddFunctor<int>());
    if (is_top_k) {
      int write_index = write_start + index;
      slice_output[write_index] = v;
      slice_indices[write_index] = i;
    }
    write_start += carry;
  }

  // 3. Fill the rest with value == kth_value
  assert(k >= write_start);
  int remain = k - write_start;
  for (int i = threadIdx.x; i < loop; i += blockDim.x) {
    bool valid = i < slice_size;
    T v = valid ? slice_input[i] : static_cast<T>(0);
    const auto converted_v = RadixTypeConfig<T>::Convert(v);
    bool is_top_k = valid && (converted_v == converted_kth_value);

    int index;
    int carry;
    ExclusiveBinaryPrefixScan<int, true, kps::AddFunctor<int>>(
        shared_mem, is_top_k, &index, &carry, kps::AddFunctor<int>());
    if (is_top_k && index < remain) {
      int write_index = write_start + index;
      assert(write_index < k);
      slice_output[write_index] = v;
      slice_indices[write_index] = i;
    }

    // `carry` is block-uniform, so the whole block leaves together.
    if (carry >= remain) {
      break;
    }

    remain -= carry;
    write_start += carry;
  }
}
#endif
/*---------------------------Radix TopK End------------------*/

W
wawltor 已提交
885
// Scatters the top-k output gradient back to the input gradient. `x_grad` is
// rows x cols and `indices` / `out_grad` are rows x k. Each row of x_grad is
// zeroed, then x_grad[i][indices[i][j]] = out_grad[i][j].
// Row i is handled by blocks i = blockIdx.x, blockIdx.x + gridDim.x, ...,
// and its elements by that block's threads. A row therefore never spans two
// blocks, and __syncthreads() orders the zeroing of a row before the scatter
// into it. This makes the kernel correct for any 1-D launch configuration,
// including a single thread. Within a row the k indices are assumed to be
// distinct (true for top-k output); otherwise concurrent writes would race.
// MaxLength and BlockSize are unused and kept for interface compatibility.
template <typename T, int MaxLength, int BlockSize>
__global__ void AssignGrad(T* x_grad,
                           const int64_t* indices,
                           const T* out_grad,
                           size_t rows,
                           size_t cols,
                           size_t k) {
  for (size_t i = blockIdx.x; i < rows; i += gridDim.x) {
    T* x_grad_row = x_grad + i * cols;
    const int64_t* indices_row = indices + i * k;
    const T* out_grad_row = out_grad + i * k;

    for (size_t j = threadIdx.x; j < cols; j += blockDim.x) {
      x_grad_row[j] = static_cast<T>(0);
    }
    // The whole row must be zeroed before any thread scatters into it. The
    // trip count of the row loop depends only on blockIdx.x, so every thread
    // of the block reaches this barrier.
    __syncthreads();

    for (size_t j = threadIdx.x; j < k; j += blockDim.x) {
      const size_t idx = indices_row[j];
      x_grad_row[idx] = out_grad_row[j];
    }
  }
}

// Gradient scatter for top-k taken along an arbitrary axis. `grad_in` is laid
// out as [pre, raw_height, post] and `grad_out` / `indices` as
// [pre, k, post], where raw_height is the length of the top-k axis. For each
// outer index i (blocks stride over `pre`), the slab grad_in[i] is zeroed and
// then every entry (j, p) of grad_out[i] is written to
// grad_in[i][indices[i][j][p]][p].
template <typename T>
__global__ void AssignGradWithAxis(const T* grad_out,
                                   const int64_t* indices,
                                   T* grad_in,
                                   int pre,
                                   int post,
                                   int raw_height,
                                   int k) {
  // Slab sizes and offsets use 64-bit arithmetic: i * post * raw_height can
  // exceed INT_MAX even when every individual extent fits in an int.
  const int64_t in_slab = static_cast<int64_t>(raw_height) * post;
  const int64_t out_slab = static_cast<int64_t>(k) * post;
  for (int i = blockIdx.x; i < pre; i += gridDim.x) {
    const int64_t base_index = i * out_slab;
    const int64_t base_grad = i * in_slab;
    for (int64_t j = threadIdx.x; j < in_slab; j += blockDim.x) {
      grad_in[base_grad + j] = static_cast<T>(0);
    }
    // Zero the whole slab before scattering into it; the loop trip count
    // depends only on blockIdx.x, so all threads reach this barrier.
    __syncthreads();
    for (int64_t j = threadIdx.x; j < out_slab; j += blockDim.x) {
      int64_t idx_ij = indices[base_index + j];
      int64_t in_ij = base_grad + (idx_ij * post) + (j % post);
      grad_in[in_ij] = grad_out[base_index + j];
    }
  }
}
// use the radix sort for the topk
//
// Computes the top-k (largest when `largest` is true, smallest otherwise)
// values of each row of a row-major [num_rows, num_cols] input, using a
// cub/hipcub segmented radix sort over (value, index) pairs.
//
// Outputs: values go to out_tensor and their column indices to
// indices_tensor. Both are treated as [num_rows, k].
//   - When k == num_cols, the sort writes directly into the outputs.
//   - Otherwise it writes into [num_rows, num_cols] temporaries, and the first
//     k columns of every row are copied out with an Eigen slice.
//
// Returns false (after logging the cub/hipcub error) if the temp-storage size
// query or the sort itself fails, and true otherwise.
//
// NOTE(review): InitIndex and SegmentOffsetIter are defined elsewhere. They
// presumably fill each row with 0..num_cols-1 and yield row start offsets
// i * num_cols respectively — confirm.
// NOTE(review): out_tensor is read with data<T>(), so it is assumed to be
// allocated by the caller. indices_tensor is allocated here via mutable_data.
template <typename T>
bool SortTopk(const phi::GPUContext& ctx,
              const phi::DenseTensor* input_tensor,
              const int64_t num_cols,
              const int64_t num_rows,
              const int k,
              phi::DenseTensor* out_tensor,
              phi::DenseTensor* indices_tensor,
              bool largest = true) {
  auto cu_stream = ctx.stream();

  Tensor input_indices;
  const std::vector<int64_t> dims = {num_rows, num_cols};
  auto dim = phi::make_ddim(dims);
  input_indices.Resize(dim);
  input_indices.mutable_data<int64_t>(ctx.GetPlace());
  // Set by the cub/hipcub size query below.
  size_t temp_storage_bytes = 0;

  // Takes int64_t so that very wide rows are not truncated to int.
  auto ComputeBlockSize = [](int64_t col) {
    if (col > 512)
      return 1024;
    else if (col > 256)
      return 512;
    else if (col > 128)
      return 256;
    else if (col > 64)
      return 128;
    else
      return 64;
  };
  int block_size = ComputeBlockSize(num_cols);

  unsigned int maxGridDimX = ctx.GetCUDAMaxGridDimSize()[0];
  // actually, int num_rows < max_grid_size
  unsigned int grid_size = num_rows < maxGridDimX
                               ? static_cast<unsigned int>(num_rows)
                               : maxGridDimX;
  // Init a index array
  InitIndex<int64_t><<<grid_size, block_size, 0, cu_stream>>>(
      input_indices.data<int64_t>(), num_rows, num_cols);

  // create iter for counting input
  cub::CountingInputIterator<int64_t> counting_iter(0);
  // segment_offset is used for move to next row
  cub::TransformInputIterator<int64_t,
                              SegmentOffsetIter,
                              cub::CountingInputIterator<int64_t>>
      segment_offsets_t(counting_iter, SegmentOffsetIter(num_cols));

  T* sorted_values_ptr;
  int64_t* sorted_indices_ptr;

  Tensor temp_values;
  Tensor temp_indices;

  const T* input = input_tensor->data<T>();
  T* values = out_tensor->data<T>();
  int64_t* indices = indices_tensor->mutable_data<int64_t>(ctx.GetPlace());

  if (k == num_cols) {
    // Doing a full sort.
    sorted_values_ptr = values;
    sorted_indices_ptr = indices;
  } else {
    temp_values.Resize(dim);
    temp_indices.Resize(dim);
    sorted_values_ptr = temp_values.mutable_data<T>(ctx.GetPlace());
    sorted_indices_ptr = temp_indices.mutable_data<int64_t>(ctx.GetPlace());
  }

  // Get temp storage buffer size, maybe can allocate a fixed buffer to save
  // time.
  if (largest) {
    auto err = cub::DeviceSegmentedRadixSort::SortPairsDescending(
        nullptr,
        temp_storage_bytes,
        input,
        sorted_values_ptr,
        input_indices.data<int64_t>(),
        sorted_indices_ptr,
        num_cols * num_rows,
        num_rows,
        segment_offsets_t,
        segment_offsets_t + 1,
        0,
        sizeof(T) * 8,
        cu_stream);
#ifdef __HIPCC__
    if (err != hipSuccess) {
      LOG(ERROR) << "TopKOP failed as could not launch "
                    "hipcub::DeviceSegmentedRadixSort::SortPairsDescending to "
                    "calculate "
                    "temp_storage_bytes, status: "
                 << hipGetErrorString(err);
      return false;
    }
#else
    if (err != cudaSuccess) {
      LOG(ERROR)
          << "TopKOP failed as could not launch "
             "cub::DeviceSegmentedRadixSort::SortPairsDescending to calculate "
             "temp_storage_bytes, status: "
          << cudaGetErrorString(err);
      return false;
    }
#endif
  } else {
    auto err =
        cub::DeviceSegmentedRadixSort::SortPairs(nullptr,
                                                 temp_storage_bytes,
                                                 input,
                                                 sorted_values_ptr,
                                                 input_indices.data<int64_t>(),
                                                 sorted_indices_ptr,
                                                 num_cols * num_rows,
                                                 num_rows,
                                                 segment_offsets_t,
                                                 segment_offsets_t + 1,
                                                 0,
                                                 sizeof(T) * 8,
                                                 cu_stream);
#ifdef __HIPCC__
    if (err != hipSuccess) {
      LOG(ERROR) << "TopKOP failed as could not launch "
                    "hipcub::DeviceSegmentedRadixSort::SortPairs to calculate "
                    "temp_storage_bytes, status: "
                 << hipGetErrorString(err);
      return false;
    }
#else
    if (err != cudaSuccess) {
      LOG(ERROR) << "TopKOP failed as could not launch "
                    "cub::DeviceSegmentedRadixSort::SortPairs to calculate "
                    "temp_storage_bytes, status: "
                 << cudaGetErrorString(err);
      return false;
    }
#endif
  }
  Tensor temp_storage;
  temp_storage.mutable_data<uint8_t>(ctx.GetPlace(), temp_storage_bytes);

  if (largest) {
    auto err = cub::DeviceSegmentedRadixSort::SortPairsDescending(
        temp_storage.data<uint8_t>(),
        temp_storage_bytes,
        input,
        sorted_values_ptr,
        input_indices.data<int64_t>(),
        sorted_indices_ptr,
        num_cols * num_rows,
        num_rows,
        segment_offsets_t,
        segment_offsets_t + 1,
        0,
        sizeof(T) * 8,
        cu_stream);
#ifdef __HIPCC__
    if (err != hipSuccess) {
      LOG(ERROR) << "TopKOP failed as could not launch "
                    "hipcub::DeviceSegmentedRadixSort::SortPairsDescending to "
                    "sort input, "
                    "temp_storage_bytes: "
                 << temp_storage_bytes
                 << ", status: " << hipGetErrorString(err);
      return false;
    }
#else
    if (err != cudaSuccess) {
      LOG(ERROR) << "TopKOP failed as could not launch "
                    "cub::DeviceSegmentedRadixSort::SortPairsDescending to "
                    "sort input, "
                    "temp_storage_bytes: "
                 << temp_storage_bytes
                 << ", status: " << cudaGetErrorString(err);
      return false;
    }
#endif
  } else {
    auto err =
        cub::DeviceSegmentedRadixSort::SortPairs(temp_storage.data<uint8_t>(),
                                                 temp_storage_bytes,
                                                 input,
                                                 sorted_values_ptr,
                                                 input_indices.data<int64_t>(),
                                                 sorted_indices_ptr,
                                                 num_cols * num_rows,
                                                 num_rows,
                                                 segment_offsets_t,
                                                 segment_offsets_t + 1,
                                                 0,
                                                 sizeof(T) * 8,
                                                 cu_stream);
#ifdef __HIPCC__
    if (err != hipSuccess) {
      LOG(ERROR) << "TopKOP failed as could not launch "
                    "hipcub::DeviceSegmentedRadixSort::SortPairs to "
                    "sort input, "
                    "temp_storage_bytes: "
                 << temp_storage_bytes
                 << ", status: " << hipGetErrorString(err);
      return false;
    }
#else
    if (err != cudaSuccess) {
      LOG(ERROR) << "TopKOP failed as could not launch "
                    "cub::DeviceSegmentedRadixSort::SortPairs to "
                    "sort input, "
                    "temp_storage_bytes: "
                 << temp_storage_bytes
                 << ", status: " << cudaGetErrorString(err);
      return false;
    }
#endif
  }
  auto& dev = *ctx.eigen_device();
  if (k < num_cols) {
    // copy sliced data to output.
    const Eigen::DSizes<Eigen::DenseIndex, 2> slice_indices{0, 0};
    const Eigen::DSizes<Eigen::DenseIndex, 2> slice_sizes{num_rows, k};

    // Both outputs hold only [num_rows, k] elements, so both must be viewed
    // with the output shape, not the [num_rows, num_cols] input shape `dim`.
    const std::vector<int64_t> out_dims = {num_rows, static_cast<int64_t>(k)};
    auto out_dim = phi::make_ddim(out_dims);

    auto e_indices =
        framework::EigenMatrix<int64_t>::From(*indices_tensor, out_dim);
    auto e_tmp_indices = framework::EigenMatrix<int64_t>::From(
        static_cast<const Tensor>(temp_indices));

    auto e_values = framework::EigenMatrix<T>::From(*out_tensor, out_dim);
    auto e_tmp_values =
        framework::EigenMatrix<T>::From(static_cast<const Tensor>(temp_values));

    EigenSlice<std::decay_t<decltype(dev)>, int64_t, 2>::Eval(
        dev, e_indices, e_tmp_indices, slice_indices, slice_sizes);
    EigenSlice<std::decay_t<decltype(dev)>, T, 2>::Eval(
        dev, e_values, e_tmp_values, slice_indices, slice_sizes);
  }
  return true;
}
}  // namespace operators
}  // namespace paddle