// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#ifdef PADDLE_WITH_CUDA
#include <cuda_fp16.h>
#endif
#ifdef PADDLE_WITH_HIP
#include <hip/hip_fp16.h>
#endif

#include "paddle/phi/backends/gpu/gpu_device_function.h"
#include "paddle/phi/common/amp_type_traits.h"

namespace phi {
namespace kps {
namespace details {

#ifdef __HIPCC__
constexpr int kReduceMaxThread = 256;
constexpr int kWarpSize = 64;
#else
constexpr int kReduceMaxThread = 128;
constexpr int kWarpSize = 32;
#endif

// kGlobalMode: block reduce, each block gets an output;
// kLocalMode: thread reduce, each thread gets an output;
enum ReduceMode { kGlobalMode, kLocalMode };

/**
 * @brief Used in BlockYReduce to get the index of reduce_num in shared
 * memory.
 */
__device__ __forceinline__ int SharedMemoryIndex(int index) {
  return (threadIdx.y + index) * blockDim.x + threadIdx.x;
}

template <typename T, typename ReduceOp>
__device__ __forceinline__ T WarpReduce(T val, ReduceOp reducer) {
  unsigned mask = 0u;
  CREATE_SHFL_MASK(mask, true);
  for (int stride = details::kWarpSize / 2; stride > 0; stride >>= 1) {
    T temp = phi::backends::gpu::CudaShuffleDownSync(mask, val, stride);
    val = reducer(val, temp);
  }
  return val;
}

/* e.g.
 * |---------block---------|
 * |warp0|warp1|warp2|warp3|
 * |0~31|32~63|64~95|96~127|  ---->blockDim.x = 128
 *  \|/  \|/   \|/    \|/     ---->1. First WarpReduce in each warp
 * res0  res1  res2  res3     ---->2. Store result of each warp to shared memory
 *   \    \    /     /        ---->3. Load the result above from shared memory
 *        res                         to warp0 and process the second WarpReduce
 */

/**
 * @brief BlockXReduce reduces along blockDim.x.
 */
template <typename T, typename ReduceOp>
__device__ __forceinline__ T BlockXReduce(T val, ReduceOp reducer) {
  __syncthreads();
  using details::kWarpSize;
  __shared__ T shared[2 * kWarpSize];
  int block_dim_x = blockDim.x;
  if (blockDim.x > kWarpSize) {
    // Bit operation can be used when kWarpSize is 32 or 64 now
    constexpr int rshift_val =
        (kWarpSize != 32) ? ((kWarpSize == 64) ? 6 : 5) : 5;
    block_dim_x = blockDim.x >> rshift_val;
    int lane = threadIdx.x & (kWarpSize - 1);
    int tid = threadIdx.y * blockDim.x + threadIdx.x;
    int wid = tid >> rshift_val;
    int bid = threadIdx.y;
    val = WarpReduce(val, reducer);
    if (lane == 0) {
      shared[wid] = val;
    }
    __syncthreads();
    val = shared[bid * block_dim_x + lane];
  }

  unsigned mask = 0u;
  CREATE_SHFL_MASK(mask, true);
  for (int stride = 1; stride < block_dim_x; stride <<= 1) {
    T temp = phi::backends::gpu::CudaShuffleDownSync(mask, val, stride);
    val = reducer(val, temp);
  }
  __syncthreads();
  if (threadIdx.x == 0) {
    shared[threadIdx.y] = val;
  }
  __syncthreads();
  return shared[threadIdx.y];
}

/**
 * @brief BlockYReduce reduces along blockDim.y.
 */
template <typename T, typename ReduceOp>
__device__ __forceinline__ T BlockYReduce(T val, ReduceOp reducer) {
  __shared__ T shared_memory[1024];
  shared_memory[SharedMemoryIndex(0)] = val;
  for (int stride = blockDim.y / 2; stride > 0; stride >>= 1) {
    __syncthreads();
    if (threadIdx.y < stride && threadIdx.y + stride < blockDim.y) {
      T temp = shared_memory[SharedMemoryIndex(stride)];
      val = reducer(val, temp);
    }
    shared_memory[SharedMemoryIndex(0)] = val;
  }
  __syncthreads();
  return shared_memory[threadIdx.x];
}

/**
 * @brief Swap data
 */
template <typename T>
__device__ __forceinline__ void Swap(T* first_value, T* second_value) {
  T t_value;
  t_value = (*first_value);
  (*first_value) = (*second_value);
  (*second_value) = t_value;
}

/**
 * @brief Swap data according to monotonic_type.
 */
template <typename T>
__device__ __forceinline__ void Comparator(T* first_value,
                                           T* second_value,
                                           int monotonic_type) {
  if (((*first_value) > (*second_value)) == monotonic_type) {
    Swap<T>(first_value, second_value);
  }
}

/**
 * @brief Swap data and data index according to monotonic_type.
 */
template <typename T, typename IndexType>
__device__ __forceinline__ void ComparatorWithIndex(T* first_value,
                                                    T* second_value,
                                                    IndexType* first_index,
                                                    IndexType* second_index,
                                                    int monotonic_type) {
  if ((*first_value > (*second_value)) == monotonic_type) {
    // swap value
    Swap<T>(first_value, second_value);
    // swap index
    Swap<IndexType>(first_index, second_index);
  }
}

/**
 * @brief Get the largest power of 2 that is less than or equal to n.
 */
__device__ inline int GetLastPow2(int n) {
  n |= (n >> 1);
  n |= (n >> 2);
  n |= (n >> 4);
  n |= (n >> 8);
  n |= (n >> 16);
  return std::max(1, n - (n >> 1));
}

}  // namespace details

/**
 * @brief Perform unary calculation according to OpFunc. Shape of input and
 * output are the same.
 *
 * @template parameters
 * InT: The data type of in.
 * OutT: The data type of out.
 * NX: The number of data columns loaded by each thread.
 * NY: The number of data rows loaded by each thread.
 * threadIdx.x is used as the thread index. Currently only GPU is supported.
 * OpFunc: Compute functor which has an operator() as following:
 *     template <typename InT, typename OutT>
 *     struct XxxFunctor {
 *       HOSTDEVICE OutT operator()(const InT& a) const {
 *         return ...;
 *       }
 *     };
 *
 * @param:
 * out: The register pointer of out, the size is NX * NY.
 * in: The register pointer of in, the size is NX * NY.
 * compute: Compute function which was declared like OpFunc<InT, OutT>().
 */
template <typename InT, typename OutT, int NX, int NY, class OpFunc>
__device__ __forceinline__ void ElementwiseUnary(OutT* out,
                                                 const InT* in,
                                                 OpFunc compute) {
#pragma unroll
  for (int idx = 0; idx < NX * NY; idx++) {
    out[idx] = static_cast<OutT>(compute(in[idx]));
  }
}
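/*
 * Example (illustrative sketch only; SquareFunctor and the surrounding kernel
 * are hypothetical, not part of this header). With each thread holding
 * NX * NY = 4 floats in registers:
 *
 *   template <typename T>
 *   struct SquareFunctor {
 *     HOSTDEVICE T operator()(const T& a) const { return a * a; }
 *   };
 *
 *   // inside a __global__ kernel:
 *   float in[4];   // filled earlier from global memory
 *   float out[4];
 *   kps::ElementwiseUnary<float, float, 4, 1, SquareFunctor<float>>(
 *       out, in, SquareFunctor<float>());
 */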

/**
 * @brief Binary calculation according to OpFunc. Shape of the input and output
 * are the same.
 *
 * @template parameters
 * InT: The data type of in1 and in2.
 * OutT: The data type of out.
 * NX: The number of data columns computed by each thread.
 * NY: The number of data rows computed by each thread.
 * threadIdx.x is used as the thread index. Currently only GPU is supported.
 * OpFunc: Compute functor which has an operator() as following:
 *     template <typename InT>
 *     struct XxxFunctor {
 *       HOSTDEVICE InT operator()(const InT& a, const InT& b) const {
 *         return ...;
 *       }
 *     };
 *
 * @param:
 * out: The register pointer of out, the size is NX * NY.
 * in1: The register pointer of first input, size is NX * NY.
 * in2: The register pointer of second input, size is NX * NY.
 * compute: Compute function which was declared like OpFunc<InT>().
 */
template <typename InT, typename OutT, int NX, int NY, class OpFunc>
__device__ __forceinline__ void ElementwiseBinary(OutT* out,
                                                  const InT* in1,
                                                  const InT* in2,
                                                  OpFunc compute) {
#pragma unroll
  for (int idx = 0; idx < NX * NY; ++idx) {
    out[idx] = static_cast<OutT>(compute(in1[idx], in2[idx]));
  }
}

template <typename InT, typename OutT, int NX, int NY, class OpFunc>
__device__ __forceinline__ void ElementwiseBinary(
    OutT* out, const InT* in1, const InT* in2, OpFunc compute, int read_lens) {
#pragma unroll
  for (int idx = 0; idx < NX * NY; ++idx) {
    out[idx] = static_cast<OutT>(compute(in1[idx], in2[idx]));
  }
}
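/*
 * Example (illustrative sketch only; AddFunctor is hypothetical). Adding two
 * register buffers of NX * NY = 4 elements per thread:
 *
 *   template <typename T>
 *   struct AddFunctor {
 *     HOSTDEVICE T operator()(const T& a, const T& b) const { return a + b; }
 *   };
 *
 *   float in1[4], in2[4], out[4];
 *   kps::ElementwiseBinary<float, float, 4, 1, AddFunctor<float>>(
 *       out, in1, in2, AddFunctor<float>());
 */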

/**
 * @brief Ternary calculation according to OpFunc. Shape of input and output
 * are the same.
 *
 * @template parameters
 * InT: The data type of in1, in2 and in3.
 * OutT: The data type of out.
 * NX: The number of data columns loaded by each thread.
 * NY: The number of data rows loaded by each thread.
 * threadIdx.x is used as the thread index. Currently only GPU is supported.
 * OpFunc: Compute functor which has an operator() as following
 *     template <typename InT>
 *     struct XxxFunctor {
 *       HOSTDEVICE InT operator()(const InT& a, const InT& b, const InT& c)
 * const {
 *         return ...;
 *       }
 *     };
 *
 * @param
 * out: The register pointer of out, the size is NX * NY.
 * in1: The register pointer of first input, size is NX * NY.
 * in2: The register pointer of second input, size is NX * NY.
 * in3: The register pointer of third input, size is NX * NY.
 * compute: Compute function which was declared like OpFunc<InT>().
 */
template <typename InT, typename OutT, int NX, int NY, class OpFunc>
__device__ __forceinline__ void ElementwiseTernary(
    OutT* out, const InT* in1, const InT* in2, const InT* in3, OpFunc compute) {
#pragma unroll
  for (int idx = 0; idx < NX * NY; ++idx) {
    out[idx] = static_cast<OutT>(compute(in1[idx], in2[idx], in3[idx]));
  }
}
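/*
 * Example (illustrative sketch only; FmaFunctor is hypothetical). Computing
 * a * b + c on NX * NY = 4 register elements per thread:
 *
 *   template <typename T>
 *   struct FmaFunctor {
 *     HOSTDEVICE T operator()(const T& a, const T& b, const T& c) const {
 *       return a * b + c;
 *     }
 *   };
 *
 *   float a[4], b[4], c[4], out[4];
 *   kps::ElementwiseTernary<float, float, 4, 1, FmaFunctor<float>>(
 *       out, a, b, c, FmaFunctor<float>());
 */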

/**
 * @brief Multivariate calculation according to OpFunc. Shape of inputs and
 * output are the same.
 *
 * @template parameters
 * InT: The data type of the inputs.
 * OutT: The data type of out.
 * NX: The number of data columns loaded by each thread.
 * NY: The number of data rows loaded by each thread.
 * threadIdx.x is used as the thread index. Currently only GPU is supported.
 * Arity: The size of ins.
 * OpFunc: Compute functor which has an operator() as following:
 *     template <typename InT>
 *     struct XxxFunctor {
 *       HOSTDEVICE InT operator()(const InT* args) const {
 *         return ...;
 *       }
 *     };
 *
 * @param
 * out: The register pointer of out, the size is NX * NY.
 * ins: A pointer to an array consisting of multiple inputs.
 * compute: Compute function which was declared like OpFunc<InT>().
 */
template <typename InT, typename OutT, int NX, int NY, int Arity, class OpFunc>
__device__ __forceinline__ void ElementwiseAny(OutT* out,
                                               InT (*ins)[NX * NY],
                                               OpFunc compute) {
  InT args[Arity];
#pragma unroll
  for (int idx = 0; idx < NX * NY; ++idx) {
#pragma unroll
    for (int j = 0; j < Arity; ++j) {
      args[j] = ins[j][idx];
    }
    out[idx] = static_cast<OutT>(compute(args));
  }
}
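/*
 * Example (illustrative sketch only; SumArgsFunctor is hypothetical). Summing
 * Arity = 3 inputs, each holding NX * NY = 2 register elements per thread:
 *
 *   template <typename T>
 *   struct SumArgsFunctor {
 *     HOSTDEVICE T operator()(const T* args) const {
 *       return args[0] + args[1] + args[2];
 *     }
 *   };
 *
 *   float ins[3][2];  // ins[j] is the register buffer of the j-th input
 *   float out[2];
 *   kps::ElementwiseAny<float, float, 2, 1, 3, SumArgsFunctor<float>>(
 *       out, ins, SumArgsFunctor<float>());
 */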

/**
 * @brief Binary calculation according to OpFunc. Shapes of in1 and in2 are
 * different. Shape of in1 is [1, NX], but in2's shape is [NY, NX], the output
 * shape is [NY, NX].
 *
 * @template parameters
 * InT: The data type of in1 and in2.
 * OutT: The data type of out.
 * NX: The number of data columns loaded by each thread.
 * NY: The number of data rows loaded by each thread.
 * threadIdx.x is used as the thread index. Currently only GPU is supported.
 * OpFunc: Compute functor which has an operator() as following
 *     template <typename InT, typename OutT>
 *     struct XxxFunctor {
 *       HOSTDEVICE OutT operator()(const InT& a, const InT& b) const {
 *         return ...;
 *       }
 *     };
 *
 * @param
 * out: The register pointer of out, the size is NX * NY.
 * in1: The register pointer of first input, size is NX * 1.
 * in2: The register pointer of second input, size is NX * NY.
 * compute: Compute function which was declared like OpFunc<InT, OutT>().
 */
template <typename InT, typename OutT, int NX, int NY, class OpFunc>
__device__ __forceinline__ void CycleBinary(OutT* out,
                                            const InT* in1,
                                            const InT* in2,
                                            OpFunc compute) {
#pragma unroll
  for (int idx = 0; idx < NX; idx++) {
#pragma unroll
    for (int idy = 0; idy < NY; idy++) {
      out[idx + idy * NX] =
          static_cast<OutT>(compute(in1[idx], in2[idx + idy * NX]));
    }
  }
}
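/*
 * Example (illustrative sketch only; MulFunctor is hypothetical). in1 holds
 * one row of NX = 4 values that is reused (cycled) against each of the
 * NY = 2 rows stored in in2:
 *
 *   template <typename T>
 *   struct MulFunctor {
 *     HOSTDEVICE T operator()(const T& a, const T& b) const { return a * b; }
 *   };
 *
 *   float in1[4];  // shape [1, NX]
 *   float in2[8];  // shape [NY, NX]
 *   float out[8];  // shape [NY, NX]
 *   kps::CycleBinary<float, float, 4, 2, MulFunctor<float>>(
 *       out, in1, in2, MulFunctor<float>());
 */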

/**
 * @brief The Reduce provides collective methods for computing a parallel
 * reduction of items partitioned across a CUDA block and within a thread. When
 * ReduceMode == kGlobalMode, shared memory is used to reduce between threads.
 * When ReduceMode == kLocalMode, each thread reduces along nx.
 *
 * @template parameters
 * T: The type of data.
 * NX: The number of data continuously loaded by each thread.
 * NY: The number of data rows loaded by each thread, only NY = 1 is supported.
 * threadIdx.x is used as the thread index. Currently only GPU is supported.
 * ReduceFunctor: Compute functor which has an operator() as following
 *     template <typename InT>
 *     struct ReduceFunctor {
 *       HOSTDEVICE InT operator()(const InT& a, const InT& b) const {
 *         return ...;
 *       }
 *     };
 * ReduceMode: Reduce mode, can be kLocalMode, kGlobalMode.
 *
 * @param
 * out: The register pointer of out, the size is NX * NY.
 * in: The register pointer of in, the size is NX * NY.
 * reducer: Compute function which was declared like ReduceFunctor<InT>().
 * reduce_last_dim: if the last dim gets involved in reduction.
 */
template <typename T,
          int NX,
          int NY,
          class ReduceFunctor,
          details::ReduceMode Mode>
__device__ __forceinline__ void Reduce(T* out,
                                       const T* in,
                                       ReduceFunctor reducer,
                                       bool reduce_last_dim) {
  int block_index = blockDim.y;

  if (Mode == details::ReduceMode::kGlobalMode) {
    bool block_reduce_y = (!reduce_last_dim) && (block_index > 1);
    // when reduce is not required for the last dim, and reduce num has been
    // split into multiple threads
    if (block_reduce_y) {
#pragma unroll
      for (int i = 0; i < NY * NX; i++) {  // reduce along blockDim.y
        out[i] = details::BlockYReduce<T, ReduceFunctor>(out[i], reducer);
      }
    }

    // when the last dimension needs to be reduced
    if (reduce_last_dim) {
#pragma unroll
      for (int i = 0; i < NY * NX; i++) {  // reduce along blockDim.x
        out[i] = details::BlockXReduce<T, ReduceFunctor>(out[i], reducer);
      }
    }
  } else {  // else  kLocalMode
#pragma unroll
    for (int i = 0; i < NY; ++i) {
#pragma unroll
      for (int j = 0; j < NX; ++j) {
        out[i] = reducer(out[i], in[i * NX + j]);
      }
    }
  }
}
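/*
 * Example (illustrative sketch only; MaxFunctor and the two-stage usage are
 * hypothetical, not prescribed by this header). First each thread folds its
 * NX = 4 register values into out[0] with kLocalMode, then the per-thread
 * partials are combined across the block with kGlobalMode:
 *
 *   template <typename T>
 *   struct MaxFunctor {
 *     HOSTDEVICE T operator()(const T& a, const T& b) const {
 *       return a > b ? a : b;
 *     }
 *   };
 *
 *   float in[4];
 *   float out[1];  // initialized with the reduction identity
 *   kps::Reduce<float, 4, 1, MaxFunctor<float>,
 *               kps::details::ReduceMode::kLocalMode>(
 *       out, in, MaxFunctor<float>(), true);
 *   kps::Reduce<float, 1, 1, MaxFunctor<float>,
 *               kps::details::ReduceMode::kGlobalMode>(
 *       out, out, MaxFunctor<float>(), true);
 */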

/*
 * @brief Fill register with a constant according to OpFunc
 *
 * @template parameters
 * InT: The data type used by OpFunc.
 * OutT: The data type of out.
 * NX: The number of data columns loaded by each thread.
 * NY: The number of data rows loaded by each thread.
 * Currently only GPU is supported.
 * OpFunc: Compute functor which has an operator() as following
 *     template <typename InT>
 *     struct XxxFunctor {
 *       HOSTDEVICE InT operator()()
 * const {
 *         return a;
 *       }
 *     };
 *
 * @param
 * out: The register pointer of out, the size is NX * NY.
 * compute: Compute function which was declared like OpFunc<InT>().
 */
template <typename InT, typename OutT, int NX, int NY, class OpFunc>
__device__ __forceinline__ void ElementwiseConstant(OutT* out, OpFunc compute) {
#pragma unroll
  for (int idx = 0; idx < NX * NY; idx++) {
    out[idx] = static_cast<OutT>(compute());
  }
}
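/*
 * Example (illustrative sketch only; ZeroFunctor is hypothetical). Filling
 * NX * NY = 4 register elements with a constant:
 *
 *   template <typename T>
 *   struct ZeroFunctor {
 *     HOSTDEVICE T operator()() const { return static_cast<T>(0); }
 *   };
 *
 *   float out[4];
 *   kps::ElementwiseConstant<float, float, 4, 1, ZeroFunctor<float>>(
 *       out, ZeroFunctor<float>());
 */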

/*
 * @brief Get ReturnsCount random data from compute according to state, state
 * can be curandStatePhilox4_32_10_t or hiprandStatePhilox4_32_10_t, which has
 * been initialized.
 *
 * @template parameters
 * StateType: the type of state, can be curandStatePhilox4_32_10_t or
 * hiprandStatePhilox4_32_10_t.
 * OutT: the type of out register.
 * ReturnsCount: The number of random data generated by OpFunc.
 * Currently only GPU is supported.
 * OpFunc: Compute functor which has an operator() as following
 *     template <typename T>
 *     struct XxxFunctor {
 *       HOSTDEVICE InT operator()(StateType state)
 * const {
 *         return random(state);  // Returns ReturnsCount random numbers with
 * data type T
 *       }
 *     };
 *
 * @param
 * out: The register pointer of out, the size is ReturnsCount.
 * compute: Compute function which was declared like OpFunc<T>().
 */

template <typename StateType, typename OutT, int ReturnsCount, class OpFunc>
__device__ __forceinline__ void ElementwiseRandom(OutT* out,
                                                  OpFunc compute,
                                                  StateType* state) {
  auto random_tuple = compute(state);
#pragma unroll
  for (int i = 0; i < ReturnsCount; i++) {
    out[i] = static_cast<OutT>((&random_tuple.x)[i]);
  }
}
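/*
 * Example (illustrative sketch only; UniformFunctor is hypothetical and the
 * CUDA path is assumed). curand_uniform4 returns a float4, so ReturnsCount
 * is 4, and the state must already be initialized with curand_init:
 *
 *   struct UniformFunctor {
 *     __device__ float4 operator()(curandStatePhilox4_32_10_t* state) const {
 *       return curand_uniform4(state);
 *     }
 *   };
 *
 *   float out[4];
 *   curandStatePhilox4_32_10_t state;  // curand_init(...) called beforehand
 *   kps::ElementwiseRandom<curandStatePhilox4_32_10_t, float, 4,
 *                          UniformFunctor>(out, UniformFunctor(), &state);
 */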

/*
 * @brief Compute the prefix sum (cumsum) within the block; each thread
 * calculates 2 data, the size of out and in is 2, and blockDim.x must be less
 * than 512.
 *
 * @template parameters
 * InT: the type of input register.
 * OutT: the type of out register.
 * Currently only GPU is supported.
 * OpFunc: Compute functor which has an operator() as following
 *     template <typename T>
 *     struct XxxFunctor {
 *       HOSTDEVICE InT operator()(T a, T b)
 * const {
 *         return a + b;
 *       }
 *     };
 *
 * @param
 * out: The register pointer of out, the size is 2;
 * in: The register pointer of input, the size is 2;
 * compute: Compute function which was declared like OpFunc<T>().
 */

#define SHARED_SIZE_LIMIT 512
template <typename InT, typename OutT, class OpFunc>
__device__ __forceinline__ void Cumsum(OutT* out,
                                       const InT* in,
                                       OpFunc compute) {
  constexpr int kSize = SHARED_SIZE_LIMIT * 2 + (SHARED_SIZE_LIMIT * 2) / 32;
  __shared__ InT temp[kSize];
  int stride_size = blockDim.x;
  int tidx = threadIdx.x;
  temp[tidx + tidx / 32] = in[0];
  temp[stride_size + tidx + (stride_size + tidx) / 32] = in[1];
  for (int stride = 1; stride <= stride_size; stride *= 2) {
    __syncthreads();
    int index = (tidx + 1) * 2 * stride - 1;
    if (index < (blockDim.x * 2)) {
      temp[index + index / 32] =
          compute(temp[index + index / 32],
                  temp[index - stride + (index - stride) / 32]);
    }
  }
  for (int stride = (blockDim.x * 2) / 4; stride > 0; stride /= 2) {
    __syncthreads();
    int index = (tidx + 1) * 2 * stride - 1;
    if ((index + stride) < (blockDim.x * 2)) {
      temp[index + stride + (stride + index) / 32] =
          compute(temp[index + stride + (stride + index) / 32],
                  temp[index + (index) / 32]);
    }
  }

  __syncthreads();
  out[0] = static_cast<OutT>(temp[tidx + tidx / 32]);
  out[1] =
      static_cast<OutT>(temp[tidx + stride_size + (tidx + stride_size) / 32]);
}
#undef SHARED_SIZE_LIMIT
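/*
 * Example (illustrative sketch only; AddFunctor is hypothetical). Each thread
 * contributes two adjacent elements, so a block of blockDim.x threads scans
 * 2 * blockDim.x values cooperatively:
 *
 *   struct AddFunctor {
 *     HOSTDEVICE float operator()(float a, float b) const { return a + b; }
 *   };
 *
 *   float in[2];   // the two elements owned by this thread
 *   float out[2];  // the corresponding prefix-sum results
 *   kps::Cumsum<float, float, AddFunctor>(out, in, AddFunctor());
 */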

/*
 * @brief Sort data in this block, each thread calculates 2 data, the size of
 * out and in is 2, and blockDim.x must be less than 512.
 *
 * @template parameters
 * InT: the type of input register.
 * OutT: the type of out register.
 * Currently only GPU is supported.
 *
 * @param
 * out: The register pointer of out, the size is 2.
 * in: The register pointer of input, the size is 2.
 * num: The number of valid elements in this block.
 * monotonic_type: if monotonic_type = 1 then sorted in ascending order, else
 * sorted in descending order.
 */
#define SHARED_SIZE_LIMIT 1024
// each thread loads 2 data from global memory, so SHARED_SIZE_LIMIT must be
// larger than blockDim.x * 2
template <typename InT, typename OutT>
__device__ __forceinline__ void Sort(OutT* out,
                                     const InT* in,
                                     int num,
                                     int monotonic_type) {
  int upper_bound = blockDim.x;
  // update upper_bound
  upper_bound = std::min(details::GetLastPow2(num), upper_bound);
  // shared memory for value; num must be smaller than SHARED_SIZE_LIMIT / 2
  __shared__ InT value[SHARED_SIZE_LIMIT];
  int stride_size = blockDim.x;
  // shareMem's size must be larger than blockDim * 2
  // Copy value from in
  value[threadIdx.x] = in[0];
  value[threadIdx.x + stride_size] = in[1];
  // bitonic sort
  for (int size = 2; size < upper_bound; size <<= 1) {
    int bitonic_type = (threadIdx.x & (size / 2)) != 0;
    for (int stride = size / 2; stride > 0; stride >>= 1) {
      __syncthreads();
      int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1));
      details::Comparator<InT>(&value[pos], &value[pos + stride], bitonic_type);
    }
  }
  // last sort
  for (int stride = stride_size; stride > 0; stride >>= 1) {
    __syncthreads();
    int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1));
    // last sort when monotonic_type = 1 then increase
    details::Comparator<InT>(&value[pos], &value[pos + stride], monotonic_type);
  }
  __syncthreads();
  out[0] = static_cast<OutT>(value[threadIdx.x]);
  out[1] = static_cast<OutT>(value[threadIdx.x + stride_size]);
}
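/*
 * Example (illustrative sketch only). Each thread contributes two elements,
 * so a block of blockDim.x threads sorts 2 * blockDim.x values; num is the
 * number of valid elements and monotonic_type = 1 requests ascending order:
 *
 *   int num = ...;  // number of valid elements covered by this block
 *   float in[2];
 *   float out[2];
 *   kps::Sort<float, float>(out, in, num, 1);
 */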

/*
 * @brief Sort data with data_index in this block, each thread calculates 2
 * data, the size of out and in is 2, and blockDim.x must be less than 512.
 *
 * @template parameters
 * InT: The type of input register.
 * OutT: The type of out register.
 * IndexType: The type of index.
 * Currently only GPU is supported.
 *
 * @param
 * out: The register pointer of out, the size is 2.
 * out_index: The register pointer of out_index, the size is 2.
 * in: The register pointer of input, the size is 2.
 * in_index: The register pointer of in_index, the size is 2.
 * num: The number of valid elements in this block.
 * monotonic_type: if monotonic_type = 1 then sorted in ascending order, else
 * sorted in descending order.
 */
template <typename InT, typename OutT, typename IndexType>
__device__ __forceinline__ void Sort(OutT* out,
                                     IndexType* out_index,
                                     const InT* in,
                                     IndexType* in_index,
                                     int num,
                                     int monotonic_type) {
  int upper_bound = blockDim.x;
  // update upper_bound
  upper_bound = std::min(details::GetLastPow2(num), upper_bound);
  // shared memory for value and index; num must be smaller than
  // SHARED_SIZE_LIMIT / 2
  __shared__ InT value[SHARED_SIZE_LIMIT];
  // shareMem's size must be larger than blockDim * 2
  __shared__ IndexType index[SHARED_SIZE_LIMIT];
  // Copy value and index from in and in_index
  int stride_size = blockDim.x;
  value[threadIdx.x] = in[0];
  value[threadIdx.x + stride_size] = in[1];
  // index
  index[threadIdx.x] = in_index[0];
  index[threadIdx.x + stride_size] = in_index[1];
  // bitonic sort
  for (int size = 2; size < upper_bound; size <<= 1) {
    int bitonic_type = (threadIdx.x & (size / 2)) != 0;
    for (int stride = size / 2; stride > 0; stride >>= 1) {
      __syncthreads();
      int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1));
      details::ComparatorWithIndex<InT, IndexType>(&value[pos],
                                                   &value[pos + stride],
                                                   &index[pos],
                                                   &index[pos + stride],
                                                   bitonic_type);
    }
  }

  for (int stride = stride_size; stride > 0; stride >>= 1) {
    __syncthreads();
    int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1));
    // last sort when monotonic_type = 1 then increase
    details::ComparatorWithIndex<InT, IndexType>(&value[pos],
                                                 &value[pos + stride],
                                                 &index[pos],
                                                 &index[pos + stride],
                                                 monotonic_type);
  }

  __syncthreads();
  out[0] = static_cast<OutT>(value[threadIdx.x]);
  out[1] = static_cast<OutT>(value[threadIdx.x + stride_size]);
  out_index[0] = index[threadIdx.x];
  out_index[1] = index[threadIdx.x + stride_size];
}

template <typename T1, typename T2, typename OutT, typename OpFunc>
HOSTDEVICE __forceinline__ void OperatorTernary(
    OutT* out, const T1* in1, const T2* in2, OpFunc func, int num) {
  func(out, in1, in2, num);
}

template <typename InT, typename OutT, typename OpFunc>
HOSTDEVICE __forceinline__ void OperatorBinary(OutT* out,
                                               const InT* in,
                                               OpFunc func,
                                               int num) {
  func(out, in, num);
}

}  // namespace kps
}  // namespace phi