// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#ifdef PADDLE_WITH_CUDA
#include <cuda_fp16.h>
#endif
#ifdef PADDLE_WITH_HIP
#include <hip/hip_fp16.h>
#endif

#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/phi/common/float16.h"

namespace phi {
namespace kps {
namespace details {

#ifdef __HIPCC__
constexpr int kReduceMaxThread = 256;
constexpr int kWarpSize = 64;
#else
constexpr int kReduceMaxThread = 128;
constexpr int kWarpSize = 32;
#endif

// kGlobalMode: block reduce, each block gets an output;
// kLocalMode: thread reduce, each thread gets an output;
enum ReduceMode { kGlobalMode, kLocalMode };

template <typename T>
class MPTypeTrait {
 public:
  using Type = T;
};

template <>
class MPTypeTrait<phi::dtype::float16> {
 public:
  using Type = float;
};

/**
 * @brief Used in BlockYReduce to get the index of this thread's reduce value
 * in shared memory.
 */
__device__ __forceinline__ int SharedMemoryIndex(int index) {
  return (threadIdx.y + index) * blockDim.x + threadIdx.x;
}

/**
 * @brief Reduce val across a warp with shuffle-down operations; the full
 * result ends up in lane 0.
 */
template <typename T, typename ReduceOp>
__device__ __forceinline__ T WarpReduce(T val, ReduceOp reducer) {
  unsigned mask = 0u;
  CREATE_SHFL_MASK(mask, true);
  for (int stride = details::kWarpSize / 2; stride > 0; stride >>= 1) {
    T temp = paddle::platform::CudaShuffleDownSync(mask, val, stride);
    val = reducer(val, temp);
  }
  return val;
}

/* e.g.
 * |---------block---------|
 * |warp0|warp1|warp2|warp3|
 * |0~31|32~63|64~95|96~127|  ---->blockDim.x = 128
 *  \|/  \|/   \|/    \|/     ---->1. First WarpReduce in each warp
 * res0  res1  res2  res3     ---->2. Store result of each warp to shared memory
 *   \    \    /     /        ---->3. Load the result above from shared memory
 *        res                         to warp0 and process the second WarpReduce
 */

/**
 * @brief BlockXReduce reduces along blockDim.x.
 */
template <typename T, typename ReduceOp>
__device__ __forceinline__ T BlockXReduce(T val, ReduceOp reducer) {
  __syncthreads();
  using details::kWarpSize;
  __shared__ T shared[2 * kWarpSize];
  int block_dim_x = blockDim.x;
  if (blockDim.x > kWarpSize) {
    block_dim_x = blockDim.x / kWarpSize;
    int lane = threadIdx.x % kWarpSize;
    int tid = threadIdx.y * blockDim.x + threadIdx.x;
    int wid = tid / kWarpSize;
    int bid = threadIdx.y;
    val = WarpReduce(val, reducer);
    if (lane == 0) {
      shared[wid] = val;
    }
    __syncthreads();
    val = shared[bid * block_dim_x + lane];
  }

  unsigned mask = 0u;
  CREATE_SHFL_MASK(mask, true);
  for (int stride = 1; stride < block_dim_x; stride <<= 1) {
    T temp = paddle::platform::CudaShuffleDownSync(mask, val, stride);
    val = reducer(val, temp);
  }
  if (threadIdx.x == 0) {
    shared[threadIdx.y] = val;
  }
  __syncthreads();
  return shared[threadIdx.y];
}
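
/* Example (an illustrative sketch; AddFunctor is an assumed binary functor
 * returning a + b): inside a kernel with a 2-D block, each thread holds one
 * partial value and all threads of a row obtain that row's result as
 *
 *   float partial = ...;
 *   float row_sum = details::BlockXReduce<float, AddFunctor>(partial,
 *                                                            AddFunctor());
 */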

/**
 * @brief BlockYReduce reduces along blockDim.y.
 */
template <typename T, typename ReduceOp>
__device__ __forceinline__ T BlockYReduce(T val, ReduceOp reducer) {
  __shared__ T shared_memory[1024];
  shared_memory[SharedMemoryIndex(0)] = val;
  for (int stride = blockDim.y / 2; stride > 0; stride >>= 1) {
    __syncthreads();
    if (threadIdx.y < stride && threadIdx.y + stride < blockDim.y) {
      T temp = shared_memory[SharedMemoryIndex(stride)];
      val = reducer(val, temp);
    }
    shared_memory[SharedMemoryIndex(0)] = val;
  }
  __syncthreads();
  return shared_memory[threadIdx.x];
}

/**
 * @brief Swap data
 */
template <typename T>
__device__ __forceinline__ void Swap(T* first_value, T* second_value) {
  T t_value;
  t_value = (*first_value);
  (*first_value) = (*second_value);
  (*second_value) = t_value;
}

/**
 * @brief Swap two values so that their order matches monotonic_type
 * (1: ascending, 0: descending).
 */
template <typename T>
__device__ __forceinline__ void Comparator(T* first_value,
                                           T* second_value,
                                           int monotonic_type) {
  if (((*first_value) > (*second_value)) == monotonic_type) {
    Swap<T>(first_value, second_value);
  }
}

/**
 * @brief Swap two values and their indices so that the order matches
 * monotonic_type (1: ascending, 0: descending).
 */
template <typename T, typename IndexType>
__device__ __forceinline__ void ComparatorWithIndex(T* first_value,
                                                    T* second_value,
                                                    IndexType* first_index,
                                                    IndexType* second_index,
                                                    int monotonic_type) {
  if ((*first_value > (*second_value)) == monotonic_type) {
    // swap value
    Swap<T>(first_value, second_value);
    // swap index
    Swap<IndexType>(first_index, second_index);
  }
}

/**
 * @brief Get the largest power of 2 that is less than or equal to n.
 */
__device__ inline int GetLastPow2(int n) {
  n |= (n >> 1);
  n |= (n >> 2);
  n |= (n >> 4);
  n |= (n >> 8);
  n |= (n >> 16);
  return std::max(1, n - (n >> 1));
}

}  // namespace details

/**
 * @brief Perform unary calculation according to OpFunc. The shapes of the
 * input and output are the same.
 *
 * @template parameters
 * InT: The data type of in.
 * OutT: The data type of out.
 * NX: The number of data columns loaded by each thread.
 * NY: The number of data rows loaded by each thread.
 * threadIdx.x is used as the thread index. Currently only GPU is supported.
 * OpFunc: Compute functor which has an operator() as following:
 *     template <typename InT, typename OutT>
 *     struct XxxFunctor {
 *       HOSTDEVICE OutT operator()(const InT& a) const {
 *         return ...;
 *       }
 *     };
 *
 * @param:
 * out: The register pointer of out, the size is NX * NY.
 * in: The register pointer of in, the size is NX * NY.
 * compute: Compute function which was declared like OpFunc<InT, OutT>().
 */
template <typename InT, typename OutT, int NX, int NY, class OpFunc>
__device__ __forceinline__ void ElementwiseUnary(OutT* out,
                                                 const InT* in,
                                                 OpFunc compute) {
#pragma unroll
  for (int idx = 0; idx < NX * NY; idx++) {
    out[idx] = static_cast<OutT>(compute(in[idx]));
  }
}
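
/* Example usage (an illustrative sketch, not part of this header; the
 * SquareFunctor and the NX = 4, NY = 1 buffers are assumptions):
 *
 *   struct SquareFunctor {
 *     HOSTDEVICE float operator()(const float& x) const { return x * x; }
 *   };
 *
 *   // Inside a kernel, with per-thread register buffers:
 *   float in_reg[4], out_reg[4];
 *   // ... fill in_reg from global memory ...
 *   phi::kps::ElementwiseUnary<float, float, 4, 1, SquareFunctor>(
 *       out_reg, in_reg, SquareFunctor());
 */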

/**
 * @brief Binary calculation according to OpFunc. The shapes of the input and
 * output are the same.
 *
 * @template parameters
 * InT: The data type of in1 and in2.
 * OutT: The data type of out.
 * NX: The number of data columns computed by each thread.
 * NY: The number of data rows computed by each thread.
 * threadIdx.x is used as the thread index. Currently only GPU is supported.
 * OpFunc: Compute functor which has an operator() as following:
 *     template <typename InT>
 *     struct XxxFunctor {
 *       HOSTDEVICE InT operator()(const InT& a, const InT& b) const {
 *         return ...;
 *       }
 *     };
 *
 * @param:
 * out: The register pointer of out, the size is NX * NY.
 * in1: The register pointer of first input, size is NX * NY.
 * in2: The register pointer of second input, size is NX * NY.
 * compute: Compute function which was declared like OpFunc<InT>().
 */
template <typename InT, typename OutT, int NX, int NY, class OpFunc>
__device__ __forceinline__ void ElementwiseBinary(OutT* out,
                                                  const InT* in1,
                                                  const InT* in2,
                                                  OpFunc compute) {
#pragma unroll
  for (int idx = 0; idx < NX * NY; ++idx) {
    out[idx] = static_cast<OutT>(compute(in1[idx], in2[idx]));
  }
}

template <typename InT, typename OutT, int NX, int NY, class OpFunc>
__device__ __forceinline__ void ElementwiseBinary(
    OutT* out, const InT* in1, const InT* in2, OpFunc compute, int read_lens) {
#pragma unroll
  for (int idx = 0; idx < NX * NY; ++idx) {
    out[idx] = static_cast<OutT>(compute(in1[idx], in2[idx]));
  }
}
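
/* Example usage (an illustrative sketch; AddFunctor and the NX = 4, NY = 1
 * buffers are assumptions):
 *
 *   struct AddFunctor {
 *     HOSTDEVICE float operator()(const float& a, const float& b) const {
 *       return a + b;
 *     }
 *   };
 *
 *   float a_reg[4], b_reg[4], out_reg[4];
 *   phi::kps::ElementwiseBinary<float, float, 4, 1, AddFunctor>(
 *       out_reg, a_reg, b_reg, AddFunctor());
 */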

/**
 * @brief Ternary calculation according to OpFunc. The shapes of the input and
 * output are the same.
 *
 * @template parameters
 * InT: The data type of in1, in2 and in3.
 * OutT: The data type of out.
 * NX: The number of data columns loaded by each thread.
 * NY: The number of data rows loaded by each thread.
 * threadIdx.x is used as the thread index. Currently only GPU is supported.
 * OpFunc: Compute functor which has an operator() as following
 *     template <typename InT>
 *     struct XxxFunctor {
 *       HOSTDEVICE InT operator()(const InT& a, const InT& b,
 *                                 const InT& c) const {
 *         return ...;
 *       }
 *     };
 *
 * @param
 * out: The register pointer of out, the size is NX * NY.
 * in1: The register pointer of first input, size is NX * NY.
 * in2: The register pointer of second input, size is NX * NY.
 * in3: The register pointer of third input, size is NX * NY.
 * compute: Compute function which was declared like OpFunc<InT>().
 */
template <typename InT, typename OutT, int NX, int NY, class OpFunc>
__device__ __forceinline__ void ElementwiseTernary(
    OutT* out, const InT* in1, const InT* in2, const InT* in3, OpFunc compute) {
#pragma unroll
  for (int idx = 0; idx < NX * NY; ++idx) {
    out[idx] = static_cast<OutT>(compute(in1[idx], in2[idx], in3[idx]));
  }
}
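
/* Example usage (an illustrative sketch; FmaFunctor and the NX = 2, NY = 1
 * buffers are assumptions):
 *
 *   struct FmaFunctor {
 *     HOSTDEVICE float operator()(const float& a, const float& b,
 *                                 const float& c) const {
 *       return a * b + c;
 *     }
 *   };
 *
 *   float a_reg[2], b_reg[2], c_reg[2], out_reg[2];
 *   phi::kps::ElementwiseTernary<float, float, 2, 1, FmaFunctor>(
 *       out_reg, a_reg, b_reg, c_reg, FmaFunctor());
 */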

/**
 * @brief Multivariate calculation according to OpFunc. The shapes of the
 * inputs and output are the same.
 *
 * @template parameters
 * InT: The data type of the inputs in ins.
 * OutT: The data type of out.
 * NX: The number of data columns loaded by each thread.
 * NY: The number of data rows loaded by each thread.
 * threadIdx.x is used as the thread index. Currently only GPU is supported.
 * Arity: The number of inputs in ins.
 * OpFunc: Compute functor which has an operator() as following:
 *     template <typename InT>
 *     struct XxxFunctor {
 *       HOSTDEVICE InT operator()(const InT* args) const {
 *         return ...;
 *       }
 *     };
 *
 * @param
 * out: The register pointer of out, the size is NX * NY.
 * ins: A pointer to an array of Arity inputs, each of size NX * NY.
 * compute: Compute function which was declared like OpFunc<InT>().
 */
template <typename InT, typename OutT, int NX, int NY, int Arity, class OpFunc>
__device__ __forceinline__ void ElementwiseAny(OutT* out,
                                               InT (*ins)[NX * NY],
                                               OpFunc compute) {
  InT args[Arity];
#pragma unroll
  for (int idx = 0; idx < NX * NY; ++idx) {
#pragma unroll
    for (int j = 0; j < Arity; ++j) {
      args[j] = ins[j][idx];
    }
    out[idx] = static_cast<OutT>(compute(args));
  }
}
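
/* Example usage (an illustrative sketch; SumArgsFunctor, Arity = 3 and the
 * NX = 4, NY = 1 buffers are assumptions):
 *
 *   struct SumArgsFunctor {
 *     HOSTDEVICE float operator()(const float* args) const {
 *       return args[0] + args[1] + args[2];
 *     }
 *   };
 *
 *   float ins[3][4];  // Arity = 3 inputs, NX * NY = 4 elements each
 *   float out_reg[4];
 *   phi::kps::ElementwiseAny<float, float, 4, 1, 3, SumArgsFunctor>(
 *       out_reg, ins, SumArgsFunctor());
 */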

/**
 * @brief Binary calculation according to OpFunc. The shapes of in1 and in2
 * are different: in1's shape is [1, NX] while in2's shape is [NY, NX], and
 * the output shape is [NY, NX].
 *
 * @template parameters
 * InT: The data type of in1 and in2.
 * OutT: The data type of out.
 * NX: The number of data columns loaded by each thread.
 * NY: The number of data rows loaded by each thread.
 * threadIdx.x is used as the thread index. Currently only GPU is supported.
 * OpFunc: Compute functor which has an operator() as following
 *     template <typename InT, typename OutT>
 *     struct XxxFunctor {
 *       HOSTDEVICE OutT operator()(const InT& a, const InT& b) const {
 *         return ...;
 *       }
 *     };
 *
 * @param
 * out: The register pointer of out, the size is NX * NY.
 * in1: The register pointer of first input, size is NX * 1.
 * in2: The register pointer of second input, size is NX * NY.
 * compute: Compute function which was declared like OpFunc<InT, OutT>().
 */
template <typename InT, typename OutT, int NX, int NY, class OpFunc>
__device__ __forceinline__ void CycleBinary(OutT* out,
                                            const InT* in1,
                                            const InT* in2,
                                            OpFunc compute) {
#pragma unroll
  for (int idx = 0; idx < NX; idx++) {
#pragma unroll
    for (int idy = 0; idy < NY; idy++) {
      out[idx + idy * NX] =
          static_cast<OutT>(compute(in1[idx], in2[idx + idy * NX]));
    }
  }
}
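
/* Example usage (an illustrative sketch; DivFunctor and the NX = 4, NY = 2
 * buffers are assumptions):
 *
 *   struct DivFunctor {
 *     HOSTDEVICE float operator()(const float& a, const float& b) const {
 *       return b / a;
 *     }
 *   };
 *
 *   float in1_reg[4];      // shape [1, NX]
 *   float in2_reg[2 * 4];  // shape [NY, NX]
 *   float out_reg[2 * 4];  // shape [NY, NX]
 *   phi::kps::CycleBinary<float, float, 4, 2, DivFunctor>(
 *       out_reg, in1_reg, in2_reg, DivFunctor());
 */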

/**
 * @brief The Reduce provides collective methods for computing a parallel
 * reduction of items partitioned across a CUDA block and within each thread.
 * When ReduceMode == kLocalMode, each thread reduces along NX. When
 * ReduceMode == kGlobalMode, shared memory is used to reduce between threads.
 *
 * @template parameters
 * T: The type of data.
 * NX: The number of data continuously loaded by each thread.
 * NY: The number of data rows loaded by each thread, only NY = 1 is supported.
 * threadIdx.x is used as the thread index. Currently only GPU is supported.
 * ReduceFunctor: Compute functor which has an operator() as following
 *     template <typename InT>
 *     struct ReduceFunctor {
 *       HOSTDEVICE InT operator()(const InT& a, const InT& b) const {
 *         return ...;
 *       }
 *     };
 * ReduceMode: Reduce mode, can be kLocalMode or kGlobalMode.
 *
 * @param
 * out: The register pointer of out, the size is NX * NY.
 * in: The register pointer of in, the size is NX * NY.
 * reducer: Compute function which was declared like ReduceFunctor<InT>().
 * reduce_last_dim: whether the last dim is involved in the reduction.
 */
template <typename T,
          int NX,
          int NY,
          class ReduceFunctor,
          details::ReduceMode Mode>
__device__ __forceinline__ void Reduce(T* out,
                                       const T* in,
                                       ReduceFunctor reducer,
                                       bool reduce_last_dim) {
  int block_index = blockDim.y;

  if (Mode == details::ReduceMode::kGlobalMode) {
    bool block_reduce_y = (!reduce_last_dim) && (block_index > 1);
    // when reduce is not required for the last dim, and reduce num has been
    // split into multiple threads
    if (block_reduce_y) {
#pragma unroll
      for (int i = 0; i < NY * NX; i++) {  // reduce along blockdim.y
        out[i] = details::BlockYReduce<T, ReduceFunctor>(out[i], reducer);
      }
    }

    // when the last dimension needs to be reduced
    if (reduce_last_dim) {
#pragma unroll
      for (int i = 0; i < NY * NX; i++) {  // reduce along blockDim.x
        out[i] = details::BlockXReduce<T, ReduceFunctor>(out[i], reducer);
      }
    }
  } else {  // kLocalMode
#pragma unroll
    for (int i = 0; i < NY; ++i) {
#pragma unroll
      for (int j = 0; j < NX; ++j) {
        out[i] = reducer(out[i], in[i * NX + j]);
      }
    }
  }
}
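
/* Example usage (an illustrative sketch; MaxFunctor and the NX = 4, NY = 1
 * buffers are assumptions):
 *
 *   struct MaxFunctor {
 *     HOSTDEVICE float operator()(const float& a, const float& b) const {
 *       return a > b ? a : b;
 *     }
 *   };
 *
 *   // Thread-local reduction of NX = 4 values into out_reg[0]; out_reg must
 *   // be initialized before the call.
 *   float in_reg[4];
 *   // ... fill in_reg ...
 *   float out_reg[1] = {in_reg[0]};
 *   phi::kps::Reduce<float, 4, 1, MaxFunctor,
 *                    phi::kps::details::kLocalMode>(
 *       out_reg, in_reg, MaxFunctor(), false);
 */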

/*
 * @brief Fill registers with a constant according to OpFunc.
 *
 * @template parameters
 * InT: The data type used by OpFunc.
 * OutT: The data type of out.
 * NX: The number of data columns loaded by each thread.
 * NY: The number of data rows loaded by each thread.
 * Currently only GPU is supported.
 * OpFunc: Compute functor which has an operator() as following
 *     template <typename InT>
 *     struct XxxFunctor {
 *       HOSTDEVICE InT operator()() const {
 *         return a;
 *       }
 *     };
 *
 * @param
 * out: The register pointer of out, the size is NX * NY.
 * compute: Compute function which was declared like OpFunc<InT>().
 */
template <typename InT, typename OutT, int NX, int NY, class OpFunc>
__device__ __forceinline__ void ElementwiseConstant(OutT* out, OpFunc compute) {
#pragma unroll
  for (int idx = 0; idx < NX * NY; idx++) {
    out[idx] = static_cast<OutT>(compute());
  }
}
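
/* Example usage (an illustrative sketch; OneFunctor and the NX = 4, NY = 1
 * buffer are assumptions):
 *
 *   struct OneFunctor {
 *     HOSTDEVICE float operator()() const { return 1.0f; }
 *   };
 *
 *   float out_reg[4];
 *   phi::kps::ElementwiseConstant<float, float, 4, 1, OneFunctor>(
 *       out_reg, OneFunctor());
 */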

/*
 * @brief Get ReturnsCount random values from compute according to state;
 * state can be a curandStatePhilox4_32_10_t or hiprandStatePhilox4_32_10_t
 * that has been initialized.
 *
 * @template parameters
 * StateType: the type of state, can be curandStatePhilox4_32_10_t or
 * hiprandStatePhilox4_32_10_t.
 * OutT: the type of out register.
 * ReturnsCount: The number of random values generated by OpFunc.
 * Currently only GPU is supported.
 * OpFunc: Compute functor which has an operator() as following
 *     template <typename T>
 *     struct XxxFunctor {
 *       HOSTDEVICE T operator()(StateType* state) const {
 *         // returns a tuple of ReturnsCount random numbers with data type T
 *         return random(state);
 *       }
 *     };
 *
 * @param
 * out: The register pointer of out, the size is ReturnsCount.
 * compute: Compute function which was declared like OpFunc<T>().
 */
template <typename StateType, typename OutT, int ReturnsCount, class OpFunc>
__device__ __forceinline__ void ElementwiseRandom(OutT* out,
                                                  OpFunc compute,
                                                  StateType* state) {
  auto random_tuple = compute(state);
#pragma unroll
  for (int i = 0; i < ReturnsCount; i++) {
    out[i] = static_cast<OutT>((&random_tuple.x)[i]);
  }
}
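
/* Example usage (an illustrative sketch; UniformFunctor and the already
 * initialized curand state are assumptions):
 *
 *   struct UniformFunctor {
 *     __device__ float4 operator()(curandStatePhilox4_32_10_t* state) const {
 *       return curand_uniform4(state);  // four uniform values in (0, 1]
 *     }
 *   };
 *
 *   curandStatePhilox4_32_10_t state;  // initialized with curand_init(...)
 *   float out_reg[4];                  // ReturnsCount = 4
 *   phi::kps::ElementwiseRandom<curandStatePhilox4_32_10_t, float, 4,
 *                               UniformFunctor>(
 *       out_reg, UniformFunctor(), &state);
 */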

/*
 * @brief Compute the prefix sum within the block; each thread handles 2
 * elements, the size of out and in is 2, and blockDim.x must be less than
 * 512.
 *
 * @template parameters
 * InT: the type of input register.
 * OutT: the type of out register.
 * Currently only GPU is supported.
 * OpFunc: Compute functor which has an operator() as following
 *     template <typename T>
 *     struct XxxFunctor {
 *       HOSTDEVICE T operator()(T a, T b) const {
 *         return a + b;
 *       }
 *     };
 *
 * @param
 * out: The register pointer of out, the size is 2.
 * in: The register pointer of input, the size is 2.
 * compute: Compute function which was declared like OpFunc<T>().
 */

#define SHARED_SIZE_LIMIT 512
template <typename InT, typename OutT, class OpFunc>
__device__ __forceinline__ void Cumsum(OutT* out,
                                       const InT* in,
                                       OpFunc compute) {
  constexpr int kSize = SHARED_SIZE_LIMIT * 2 + (SHARED_SIZE_LIMIT * 2) / 32;
  __shared__ InT temp[kSize];
  int stride_size = blockDim.x;
  int tidx = threadIdx.x;
  temp[tidx + tidx / 32] = in[0];
  temp[stride_size + tidx + (stride_size + tidx) / 32] = in[1];
  for (int stride = 1; stride <= stride_size; stride *= 2) {
    __syncthreads();
    int index = (tidx + 1) * 2 * stride - 1;
    if (index < (blockDim.x * 2)) {
      temp[index + index / 32] =
          compute(temp[index + index / 32],
                  temp[index - stride + (index - stride) / 32]);
    }
  }
  for (int stride = (blockDim.x * 2) / 4; stride > 0; stride /= 2) {
    __syncthreads();
    int index = (tidx + 1) * 2 * stride - 1;
    if ((index + stride) < (blockDim.x * 2)) {
      temp[index + stride + (stride + index) / 32] =
          compute(temp[index + stride + (stride + index) / 32],
                  temp[index + (index) / 32]);
    }
  }

  __syncthreads();
  out[0] = static_cast<OutT>(temp[tidx + tidx / 32]);
  out[1] =
      static_cast<OutT>(temp[tidx + stride_size + (tidx + stride_size) / 32]);
}
#undef SHARED_SIZE_LIMIT
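
/* Example usage (an illustrative sketch; AddFunctor is an assumption):
 *
 *   struct AddFunctor {
 *     HOSTDEVICE float operator()(float a, float b) const { return a + b; }
 *   };
 *
 *   // Each thread contributes 2 elements: thread t supplies the elements at
 *   // positions t and blockDim.x + t of the block's tile and receives the
 *   // corresponding prefix-sum results back in out_reg.
 *   float in_reg[2], out_reg[2];
 *   phi::kps::Cumsum<float, float, AddFunctor>(out_reg, in_reg, AddFunctor());
 */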

/*
 * @brief Sort data in this block; each thread handles 2 elements, the size
 * of out and in is 2, and blockDim.x must be less than 512.
 *
 * @template parameters
 * InT: the type of input register.
 * OutT: the type of out register.
 * Currently only GPU is supported.
 *
 * @param
 * out: The register pointer of out, the size is 2.
 * in: The register pointer of input, the size is 2.
 * num: The number of elements to be sorted in this block.
 * monotonic_type: if monotonic_type = 1 the data is sorted in ascending
 * order, else in descending order.
 */
#define SHARED_SIZE_LIMIT 1024
// each thread loads 2 elements from global memory, so SHARED_SIZE_LIMIT must
// be larger than blockDim.x * 2
template <typename InT, typename OutT>
__device__ __forceinline__ void Sort(OutT* out,
                                     const InT* in,
                                     int num,
                                     int monotonic_type) {
  int upper_bound = blockDim.x;
  // update upper_bound
  upper_bound = std::min(details::GetLastPow2(num), upper_bound);
  // shared memory for values; num must be smaller than SHARED_SIZE_LIMIT / 2
  __shared__ InT value[SHARED_SIZE_LIMIT];
  int stride_size = blockDim.x;
  // shared memory size must be larger than blockDim.x * 2
  // Copy value from in
  value[threadIdx.x] = in[0];
  value[threadIdx.x + stride_size] = in[1];
  // make bitonicSort
  for (int size = 2; size < upper_bound; size <<= 1) {
    int bitonic_type = (threadIdx.x & (size / 2)) != 0;
    for (int stride = size / 2; stride > 0; stride >>= 1) {
      __syncthreads();
      int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1));
      details::Comparator<InT>(&value[pos], &value[pos + stride], bitonic_type);
    }
  }
  // last sort
  for (int stride = stride_size; stride > 0; stride >>= 1) {
    __syncthreads();
    int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1));
    // final merge; when monotonic_type = 1 the result is ascending
    details::Comparator<InT>(&value[pos], &value[pos + stride], monotonic_type);
  }
  __syncthreads();
  out[0] = static_cast<OutT>(value[threadIdx.x]);
  out[1] = static_cast<OutT>(value[threadIdx.x + stride_size]);
}
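
/* Example usage (an illustrative sketch; the element type and the ascending
 * order are assumptions):
 *
 *   // Thread t supplies the elements at positions t and blockDim.x + t of
 *   // the block's tile; num is the number of valid elements.
 *   float in_reg[2], out_reg[2];
 *   int num = 2 * blockDim.x;
 *   phi::kps::Sort<float, float>(out_reg, in_reg, num, 1);  // 1: ascending
 */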

/*
 * @brief Sort data together with its index in this block; each thread
 * handles 2 elements, the size of out and in is 2, and blockDim.x must be
 * less than 512.
 *
 * @template parameters
 * InT: The type of input register.
 * OutT: The type of out register.
 * IndexType: The type of index.
 * Currently only GPU is supported.
 *
 * @param
 * out: The register pointer of out, the size is 2.
 * out_index: The register pointer of out_index, the size is 2.
 * in: The register pointer of input, the size is 2.
 * in_index: The register pointer of in_index, the size is 2.
 * num: The number of elements to be sorted in this block.
 * monotonic_type: if monotonic_type = 1 the data is sorted in ascending
 * order, else in descending order.
 */
template <typename InT, typename OutT, typename IndexType>
__device__ __forceinline__ void Sort(OutT* out,
                                     IndexType* out_index,
                                     const InT* in,
                                     IndexType* in_index,
                                     int num,
                                     int monotonic_type) {
  int upper_bound = blockDim.x;
  // update upper_bound
  upper_bound = std::min(details::GetLastPow2(num), upper_bound);
  // shared memory for values and indices; num must be smaller than
  // SHARED_SIZE_LIMIT / 2
  __shared__ InT value[SHARED_SIZE_LIMIT];
  // shared memory size must be larger than blockDim.x * 2
  __shared__ IndexType index[SHARED_SIZE_LIMIT];
  // Copy value and index from in and in_index
  int stride_size = blockDim.x;
  value[threadIdx.x] = in[0];
  value[threadIdx.x + stride_size] = in[1];
  // index
  index[threadIdx.x] = in_index[0];
  index[threadIdx.x + stride_size] = in_index[1];
  // make bitonicSort
  for (int size = 2; size < upper_bound; size <<= 1) {
    int bitonic_type = (threadIdx.x & (size / 2)) != 0;
    for (int stride = size / 2; stride > 0; stride >>= 1) {
      __syncthreads();
      int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1));
      details::ComparatorWithIndex<InT, IndexType>(&value[pos],
                                                   &value[pos + stride],
                                                   &index[pos],
                                                   &index[pos + stride],
                                                   bitonic_type);
    }
  }

  for (int stride = stride_size; stride > 0; stride >>= 1) {
    __syncthreads();
    int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1));
    // final merge; when monotonic_type = 1 the result is ascending
    details::ComparatorWithIndex<InT, IndexType>(&value[pos],
                                                 &value[pos + stride],
                                                 &index[pos],
                                                 &index[pos + stride],
                                                 monotonic_type);
  }

  __syncthreads();
  out[0] = static_cast<OutT>(value[threadIdx.x]);
  out[1] = static_cast<OutT>(value[threadIdx.x + stride_size]);
  out_index[0] = index[threadIdx.x];
  out_index[1] = index[threadIdx.x + stride_size];
}
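
/* Example usage (an illustrative sketch; the element/index types and the
 * ascending order are assumptions):
 *
 *   float in_reg[2], out_reg[2];
 *   int64_t idx_in[2], idx_out[2];  // global indices of this thread's values
 *   int num = 2 * blockDim.x;
 *   phi::kps::Sort<float, float, int64_t>(
 *       out_reg, idx_out, in_reg, idx_in, num, 1);  // 1: ascending
 */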

/**
 * @brief Invoke func(out, in1, in2, num).
 */
template <typename T1, typename T2, typename OutT, typename OpFunc>
HOSTDEVICE __forceinline__ void OperatorTernary(
    OutT* out, const T1* in1, const T2* in2, OpFunc func, int num) {
  func(out, in1, in2, num);
}

/**
 * @brief Invoke func(out, in, num).
 */
template <typename InT, typename OutT, typename OpFunc>
HOSTDEVICE __forceinline__ void OperatorBinary(OutT* out,
                                               const InT* in,
                                               OpFunc func,
                                               int num) {
  func(out, in, num);
}

}  // namespace kps
}  // namespace phi