compute_primitives.h 20.6 KB
Newer Older
F
Feng Xing 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

N
niuliling123 已提交
17 18 19 20 21 22 23
#ifdef PADDLE_WITH_CUDA
#include <cuda_fp16.h>
#endif
#ifdef PADDLE_WITH_HIP
#include <hip/hip_fp16.h>
#endif

24
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
25
#include "paddle/phi/common/float16.h"
N
niuliling123 已提交
26

27
namespace phi {
28
namespace kps {
N
niuliling123 已提交
29 30
namespace details {

31
// Platform-dependent constants selected at compile time:
//   kWarpSize        - lanes per warp used by the shuffle reductions below:
//                      64 when compiled by HIP (__HIPCC__), 32 otherwise.
//   kReduceMaxThread - 256 under HIP, 128 otherwise. Not referenced in this
//                      file; presumably an upper bound on threads per block
//                      for reduce kernels elsewhere — confirm with callers.
#if defined(__HIPCC__)
constexpr int kWarpSize = 64;
constexpr int kReduceMaxThread = 256;
#else
constexpr int kWarpSize = 32;
constexpr int kReduceMaxThread = 128;
#endif

39 40
// Selects how Reduce() combines values:
//   kGlobalMode - block reduction across threads (shared memory and warp
//                 shuffles); each block gets an output.
//   kLocalMode  - reduction inside a single thread over its own values;
//                 each thread gets an output.
enum ReduceMode {
  kGlobalMode,
  kLocalMode,
};

N
niuliling123 已提交
43 44 45 46 47 48 49
// Maps a data type T to an associated computation type exposed as ::Type.
// By default the mapping is the identity.
template <typename T>
class MPTypeTrait {
 public:
  typedef T Type;
};

// phi::dtype::float16 maps to float.
template <>
class MPTypeTrait<phi::dtype::float16> {
 public:
  typedef float Type;
};

55
/**
 * @brief Used by BlockYReduce: returns the flat index, in a row-major
 * (blockDim.y x blockDim.x) shared-memory layout, of the slot that sits
 * `index` rows below the calling thread in the same column.
 */
__device__ __forceinline__ int SharedMemoryIndex(int index) {
  const int row = threadIdx.y + index;
  return row * blockDim.x + threadIdx.x;
}
N
niuliling123 已提交
62

63 64 65 66 67 68 69
/**
 * @brief Tree reduction over the lanes of one warp using shuffle-down with
 * offsets kWarpSize / 2, kWarpSize / 4, ..., 1.
 *
 * After the loop, lane 0 holds the reduction of all kWarpSize lanes; the
 * values left in the other lanes are partial and should not be used.
 */
template <typename T, typename ReduceOp>
__device__ __forceinline__ T WarpReduce(T val, ReduceOp reducer) {
  unsigned mask = 0u;
  CREATE_SHFL_MASK(mask, true);
  int offset = details::kWarpSize >> 1;
  while (offset > 0) {
    T neighbor = paddle::platform::CudaShuffleDownSync(mask, val, offset);
    val = reducer(val, neighbor);
    offset >>= 1;
  }
  return val;
}
N
niuliling123 已提交
73

74 75 76 77 78 79 80 81 82 83 84
/* Example with blockDim.x = 128 and a 32-lane warp:
 * |---------block---------|
 * |warp0|warp1|warp2|warp3|
 * |0~31|32~63|64~95|96~127|  ---->blockDim.x = 128
 *  \|/  \|/   \|/    \|/     ---->1. WarpReduce inside each warp
 * res0  res1  res2  res3     ---->2. Lane 0 of each warp stores its result
 *   \    \    /     /                to shared memory
 *        res                 ---->3. The per-warp results are loaded back
 *                                    into the lanes and combined again with
 *                                    shuffles
 */

/**
 * @brief BlockXReduce reduces `val` along blockDim.x, producing one result
 * per row (threadIdx.y) of the block.
 *
 * Only the thread with threadIdx.x == 0 in each row is guaranteed to hold the
 * full result; the other threads return partial values.
 * NOTE(review): the shuffle tree assumes blockDim.x (or blockDim.x /
 * kWarpSize when blockDim.x > kWarpSize) is a power of two — confirm with the
 * callers' launch configuration.
 * Every thread of the block must call this function (it uses
 * __syncthreads()).
 */
template <typename T, typename ReduceOp>
__device__ __forceinline__ T BlockXReduce(T val, ReduceOp reducer) {
  // Keep `shared` from being overwritten while a previous call still reads
  // it.
  __syncthreads();
  using details::kWarpSize;
  __shared__ T shared[2 * kWarpSize];

  // Number of values the final shuffle stage has to combine per row.
  int active = blockDim.x;
  if (blockDim.x > kWarpSize) {
    active = blockDim.x / kWarpSize;
    const int lane_id = threadIdx.x % kWarpSize;
    const int warp_id = (threadIdx.y * blockDim.x + threadIdx.x) / kWarpSize;
    val = WarpReduce(val, reducer);
    if (lane_id == 0) {
      shared[warp_id] = val;
    }
    __syncthreads();
    // Lanes 0 .. active - 1 of each warp in row threadIdx.y load that row's
    // per-warp results; higher lanes read slots outside the row, which the
    // shuffle tree below does not fold into lane 0 (power-of-two case).
    val = shared[threadIdx.y * active + lane_id];
  }

  unsigned mask = 0u;
  CREATE_SHFL_MASK(mask, true);
  int stride = 1;
  while (stride < active) {
    T other = paddle::platform::CudaShuffleDownSync(mask, val, stride);
    val = reducer(val, other);
    stride <<= 1;
  }
  return val;
}

/**
 * @brief BlockYReduce reduces `val` along blockDim.y, producing one result per
 * column (threadIdx.x). The result is broadcast, so every thread of a column
 * returns the same value.
 *
 * Uses a 1024-element shared buffer laid out row-major as
 * (blockDim.y x blockDim.x), one slot per thread (1024 matches the CUDA
 * per-block thread limit).
 * NOTE(review): the halving loop only folds every row into row 0 when
 * blockDim.y is a power of two — confirm with callers.
 * Every thread of the block must call this function (it uses
 * __syncthreads()).
 */
template <typename T, typename ReduceOp>
__device__ __forceinline__ T BlockYReduce(T val, ReduceOp reducer) {
  __shared__ T shared_memory[1024];
  // Wait until every thread has read the broadcast slot of a previous call
  // (Reduce() calls this once per element) before overwriting the buffer;
  // without this barrier thread (x, 0) can clobber shared_memory[x] while
  // other threads of column x are still returning it.
  __syncthreads();
  shared_memory[SharedMemoryIndex(0)] = val;
  for (int stride = blockDim.y / 2; stride > 0; stride >>= 1) {
    __syncthreads();
    if (threadIdx.y < stride && threadIdx.y + stride < blockDim.y) {
      T temp = shared_memory[SharedMemoryIndex(stride)];
      val = reducer(val, temp);
    }
    shared_memory[SharedMemoryIndex(0)] = val;
  }
  __syncthreads();
  // Row 0 now holds each column's result; broadcast it to all rows.
  return shared_memory[threadIdx.x];
}

135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168
// Exchanges the values stored at *first_value and *second_value.
template <typename T>
__device__ __forceinline__ void Swap(T* first_value, T* second_value) {
  T saved = *first_value;
  *first_value = *second_value;
  *second_value = saved;
}

// Compare-and-swap step of a sorting network. The pair is swapped when
// (*first_value > *second_value) equals monotonic_type. With
// monotonic_type == 1 the pair ends up ascending; with monotonic_type == 0
// it ends up descending. Other values of monotonic_type never swap.
template <typename T>
__device__ __forceinline__ void Comparator(T* first_value,
                                           T* second_value,
                                           int monotonic_type) {
  const bool first_is_greater = *first_value > *second_value;
  if (first_is_greater == monotonic_type) {
    Swap<T>(first_value, second_value);
  }
}

// Compare-and-swap step that carries an index along with each value. The
// values are swapped when (*first_value > *second_value) equals
// monotonic_type (1 -> ascending, 0 -> descending). In that case the indices
// are swapped too, so each index keeps following its value.
template <typename T, typename IndexType>
__device__ __forceinline__ void ComparatorWithIndex(T* first_value,
                                                    T* second_value,
                                                    IndexType* first_index,
                                                    IndexType* second_index,
                                                    int monotonic_type) {
  const bool must_swap = (*first_value > *second_value) == monotonic_type;
  if (!must_swap) {
    return;
  }
  Swap<T>(first_value, second_value);
  Swap<IndexType>(first_index, second_index);
}

169
}  // namespace details
N
niuliling123 已提交
170

171
/**
 * @brief Applies a unary functor element-wise:
 * out[i] = static_cast<OutT>(compute(in[i])) for every i in [0, NX * NY).
 * Input and output have the same shape.
 *
 * @template parameters
 * InT: The data type of in.
 * OutT: The data type of out.
 * NX: The number of data columns handled by each thread.
 * NY: The number of data rows handled by each thread.
 * BlockSize: Identifies the device thread-index method (for GPU,
 *     threadIdx.x); not referenced by this implementation.
 * OpFunc: Functor with a unary operator(), for example:
 *     template <typename InT, typename OutT>
 *     struct XxxFunctor {
 *       HOSTDEVICE OutT operator()(const InT& a) const {
 *         return ...;
 *       }
 *     };
 *
 * @param
 * out: The register pointer of out, holding NX * NY elements.
 * in: The register pointer of in, holding NX * NY elements.
 * compute: Functor instance, declared like OpFunc<InT, OutT>().
 */
template <typename InT,
          typename OutT,
          int NX,
          int NY,
          int BlockSize,
          class OpFunc>
__device__ __forceinline__ void ElementwiseUnary(OutT* out,
                                                 const InT* in,
                                                 OpFunc compute) {
  constexpr int kNumel = NX * NY;
#pragma unroll
  for (int i = 0; i < kNumel; ++i) {
    out[i] = static_cast<OutT>(compute(in[i]));
  }
}
N
niuliling123 已提交
209 210

/**
 * @brief Applies a binary functor element-wise:
 * out[i] = static_cast<OutT>(compute(in1[i], in2[i])) for every i in
 * [0, NX * NY). Both inputs and the output have the same shape.
 *
 * @template parameters
 * InT: The data type of in1 and in2.
 * OutT: The data type of out.
 * NX: The number of data columns computed by each thread.
 * NY: The number of data rows computed by each thread.
 * BlockSize: Identifies the device thread-index method (for GPU,
 *     threadIdx.x); not referenced by this implementation.
 * OpFunc: Functor with a binary operator(), for example:
 *     template <typename InT>
 *     struct XxxFunctor {
 *       HOSTDEVICE InT operator()(const InT& a, const InT& b) const {
 *         return ...;
 *       }
 *     };
 *
 * @param
 * out: The register pointer of out, holding NX * NY elements.
 * in1: The register pointer of the first input, holding NX * NY elements.
 * in2: The register pointer of the second input, holding NX * NY elements.
 * compute: Functor instance, declared like OpFunc<InT>().
 */
template <typename InT,
          typename OutT,
          int NX,
          int NY,
          int BlockSize,
          class OpFunc>
__device__ __forceinline__ void ElementwiseBinary(OutT* out,
                                                  const InT* in1,
                                                  const InT* in2,
                                                  OpFunc compute) {
  constexpr int kNumel = NX * NY;
#pragma unroll
  for (int k = 0; k < kNumel; ++k) {
    out[k] = static_cast<OutT>(compute(in1[k], in2[k]));
  }
}

/**
 * @brief Applies a ternary functor element-wise:
 * out[i] = static_cast<OutT>(compute(in1[i], in2[i], in3[i])) for every i in
 * [0, NX * NY). All inputs and the output have the same shape.
 *
 * @template parameters
 * InT: The data type of in1, in2 and in3.
 * OutT: The data type of out.
 * NX: The number of data columns handled by each thread.
 * NY: The number of data rows handled by each thread.
 * BlockSize: Identifies the device thread-index method (for GPU,
 *     threadIdx.x); not referenced by this implementation.
 * OpFunc: Functor with a ternary operator(), for example:
 *     template <typename InT>
 *     struct XxxFunctor {
 *       HOSTDEVICE InT operator()(const InT& a, const InT& b,
 *                                 const InT& c) const {
 *         return ...;
 *       }
 *     };
 *
 * @param
 * out: The register pointer of out, holding NX * NY elements.
 * in1: The register pointer of the first input, holding NX * NY elements.
 * in2: The register pointer of the second input, holding NX * NY elements.
 * in3: The register pointer of the third input, holding NX * NY elements.
 * compute: Functor instance, declared like OpFunc<InT>().
 */
template <typename InT,
          typename OutT,
          int NX,
          int NY,
          int BlockSize,
          class OpFunc>
__device__ __forceinline__ void ElementwiseTernary(
    OutT* out, const InT* in1, const InT* in2, const InT* in3, OpFunc compute) {
  constexpr int kNumel = NX * NY;
#pragma unroll
  for (int k = 0; k < kNumel; ++k) {
    out[k] = static_cast<OutT>(compute(in1[k], in2[k], in3[k]));
  }
}

/**
 * @brief Applies an Arity-ary functor element-wise. For every i in
 * [0, NX * NY), the i-th element of each of the Arity inputs is gathered
 * into a small array that is passed to compute; the result is stored as
 * static_cast<OutT>(...) in out[i]. All inputs and the output have the same
 * shape.
 *
 * @template parameters
 * InT: The data type of every input array in ins.
 * OutT: The data type of out.
 * NX: The number of data columns handled by each thread.
 * NY: The number of data rows handled by each thread.
 * BlockSize: Identifies the device thread-index method (for GPU,
 *     threadIdx.x); not referenced by this implementation.
 * Arity: The number of input arrays in ins.
 * OpFunc: Functor taking a pointer to Arity values, for example:
 *     template <typename InT>
 *     struct XxxFunctor {
 *       HOSTDEVICE InT operator()(const InT* args) const {
 *         return ...;
 *       }
 *     };
 *
 * @param
 * out: The register pointer of out, holding NX * NY elements.
 * ins: Pointer to Arity input arrays of NX * NY elements each.
 * compute: Functor instance, declared like OpFunc<InT>().
 */
template <typename InT,
          typename OutT,
          int NX,
          int NY,
          int BlockSize,
          int Arity,
          class OpFunc>
__device__ __forceinline__ void ElementwiseAny(OutT* out,
                                               InT (*ins)[NX * NY],
                                               OpFunc compute) {
  // Scratch holding the Arity operands of the current element.
  InT operands[Arity];
  constexpr int kNumel = NX * NY;
#pragma unroll
  for (int elem = 0; elem < kNumel; ++elem) {
#pragma unroll
    for (int arg = 0; arg < Arity; ++arg) {
      operands[arg] = ins[arg][elem];
    }
    out[elem] = static_cast<OutT>(compute(operands));
  }
}
N
niuliling123 已提交
337 338

/**
 * @brief Binary calculation with a broadcast first operand. in1 has shape
 * [1, NX] and is reused for each of the NY rows of in2 (shape [NY, NX]); the
 * output has shape [NY, NX]:
 *   out[row * NX + col] = static_cast<OutT>(compute(in1[col],
 *                                                   in2[row * NX + col])).
 *
 * @template parameters
 * InT: The data type of in1 and in2.
 * OutT: The data type of out.
 * NX: The number of data columns handled by each thread.
 * NY: The number of data rows handled by each thread.
 * BlockSize: Identifies the device thread-index method (for GPU,
 *     threadIdx.x); not referenced by this implementation.
 * OpFunc: Functor with a binary operator(), for example:
 *     template <typename InT, typename OutT>
 *     struct XxxFunctor {
 *       HOSTDEVICE OutT operator()(const InT& a, const InT& b) const {
 *         return ...;
 *       }
 *     };
 *
 * @param
 * out: The register pointer of out, holding NX * NY elements.
 * in1: The register pointer of the first input, holding NX elements.
 * in2: The register pointer of the second input, holding NX * NY elements.
 * compute: Functor instance, declared like OpFunc<InT, OutT>().
 */
template <typename InT,
          typename OutT,
          int NX,
          int NY,
          int BlockSize,
          class OpFunc>
__device__ __forceinline__ void CycleBinary(OutT* out,
                                            const InT* in1,
                                            const InT* in2,
                                            OpFunc compute) {
  // Columns outer, rows inner (same traversal order as before).
#pragma unroll
  for (int col = 0; col < NX; col++) {
#pragma unroll
    for (int row = 0; row < NY; row++) {
      const int offset = row * NX + col;
      out[offset] = static_cast<OutT>(compute(in1[col], in2[offset]));
    }
  }
}
N
niuliling123 已提交
383

384
/**
 * @brief Collective reduction of per-thread values, either across the
 * threads of a block (kGlobalMode) or inside a single thread (kLocalMode).
 *
 * @template parameters
 * T: The type of data.
 * NX: The number of data continuously loaded by each thread.
 * NY: The number of data rows loaded by each thread; only NY = 1 is
 *     supported.
 * BlockSize: Identifies the device thread-index method (for GPU,
 *     threadIdx.x); not referenced by this implementation.
 * ReduceFunctor: Binary functor, for example:
 *     template <typename InT>
 *     struct ReduceFunctor {
 *       HOSTDEVICE InT operator()(const InT& a, const InT& b) const {
 *         return ...;
 *       }
 *     };
 * Mode: kLocalMode or kGlobalMode.
 *
 * @param
 * out: kGlobalMode - out[0 .. NX * NY) are reduced in place across threads:
 *      along blockDim.x (BlockXReduce) when reduce_last_dim is true,
 *      otherwise along blockDim.y (BlockYReduce) when blockDim.y > 1.
 *      kLocalMode - the NX values of row r of `in` are folded into out[r]
 *      for r in [0, NY), so out[0 .. NY) must already hold starting values.
 * in: NX * NY values; only read in kLocalMode.
 * reducer: Instance of ReduceFunctor.
 * reduce_last_dim: Whether the last dimension takes part in the reduction;
 *      only consulted in kGlobalMode.
 *
 * In kGlobalMode every thread of the block must call this function, because
 * the block reductions synchronize with __syncthreads().
 */
template <typename T,
          int NX,
          int NY,
          int BlockSize,
          class ReduceFunctor,
          details::ReduceMode Mode>
__device__ __forceinline__ void Reduce(T* out,
                                       const T* in,
                                       ReduceFunctor reducer,
                                       bool reduce_last_dim) {
  if (Mode != details::ReduceMode::kGlobalMode) {
    // kLocalMode: thread-local fold of each row into its accumulator.
#pragma unroll
    for (int r = 0; r < NY; ++r) {
#pragma unroll
      for (int c = 0; c < NX; ++c) {
        out[r] = reducer(out[r], in[r * NX + c]);
      }
    }
    return;
  }

  // kGlobalMode: the two block reductions are mutually exclusive.
  if (reduce_last_dim) {
    // The last dimension is reduced: combine along blockDim.x.
#pragma unroll
    for (int k = 0; k < NY * NX; ++k) {
      out[k] = details::BlockXReduce<T, ReduceFunctor>(out[k], reducer);
    }
  } else if (blockDim.y > 1) {
    // The last dimension is kept and the reduce number has been split over
    // several thread rows: combine along blockDim.y.
#pragma unroll
    for (int k = 0; k < NY * NX; ++k) {
      out[k] = details::BlockYReduce<T, ReduceFunctor>(out[k], reducer);
    }
  }
}

452 453 454 455 456 457
/**
 * @brief Fills out[0 .. NX * NY) with values produced by the nullary functor
 * `compute`. compute() is called once per element and each result is
 * static_cast to OutT.
 *
 * InT and BlockSize are not referenced by this implementation.
 */
template <typename InT,
          typename OutT,
          int NX,
          int NY,
          int BlockSize,
          class OpFunc>
__device__ __forceinline__ void ElementwiseConstant(OutT* out, OpFunc compute) {
  constexpr int kNumel = NX * NY;
#pragma unroll
  for (int k = 0; k < kNumel; ++k) {
    out[k] = static_cast<OutT>(compute());
  }
}

465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517
/**
 * @brief Calls compute(state) once and writes the first ReturnsCount
 * components of the returned value, each static_cast to OutT, to
 * out[0 .. ReturnsCount).
 *
 * NOTE(review): the components are read as an array that starts at member
 * `.x` of the returned value. This assumes the result type stores its
 * components contiguously from `.x` (as CUDA vector types such as float4
 * do) — confirm for each OpFunc.
 * BlockSize is not referenced by this implementation.
 */
template <typename StateType,
          typename OutT,
          int ReturnsCount,
          int BlockSize,
          class OpFunc>
__device__ __forceinline__ void ElementwiseRandom(OutT* out,
                                                  OpFunc compute,
                                                  StateType* state) {
  auto random_tuple = compute(state);
  const auto* components = &random_tuple.x;
#pragma unroll
  for (int k = 0; k < ReturnsCount; ++k) {
    out[k] = static_cast<OutT>(components[k]);
  }
}

// Block-wide inclusive prefix sum over values held two per thread: in[0] and
// in[1] are read, and out[0] / out[1] receive the inclusive prefix at the
// same two positions.
//
// Scan order: in[0] of thread t is at position t, and in[1] of thread t is
// at position shared_size + t. The scan covers positions
// [0, 2 * blockDim.x), so the block must be launched with
// blockDim.x == shared_size (64) for both halves to be contiguous and
// included.
//
// NOTE(review): `compute` is never used; the body always accumulates with
// `+=`. `shared_size` is a lowercase object-like macro that stays defined for
// every file including this header — consider a scoped constant.
// Every thread of the block must call this function (it uses
// __syncthreads()).
#define shared_size 64
template <typename InT,
          typename OutT,
          int NX,
          int NY,
          int BlockSize,
          class OpFunc>
__device__ __forceinline__ void Cumsum(OutT* out,
                                       const InT* in,
                                       OpFunc compute) {
  // Scratch buffer. Logical position p is stored at p + p / 32, which leaves
  // one unused slot after every 32 elements — presumably to reduce
  // shared-memory bank conflicts.
  __shared__ InT temp[shared_size * 2 + (shared_size * 2) / 32];
  int tidx = threadIdx.x;
  temp[tidx + tidx / 32] = in[0];
  temp[shared_size + tidx + (shared_size + tidx) / 32] = in[1];
  // Up-sweep: after the pass with a given stride s, position (k + 1) * 2s - 1
  // holds the sum of the 2s elements ending at it.
  for (int stride = 1; stride <= blockDim.x; stride *= 2) {
    __syncthreads();
    int index = (tidx + 1) * 2 * stride - 1;
    if (index < (blockDim.x * 2)) {
      temp[index + index / 32] += temp[index - stride + (index - stride) / 32];
    }
  }
  // Down-sweep: push the partial sums forward so every position ends up
  // holding its inclusive prefix.
  for (int stride = (blockDim.x * 2) / 4; stride > 0; stride /= 2) {
    __syncthreads();
    int index = (tidx + 1) * 2 * stride - 1;
    if ((index + stride) < (blockDim.x * 2)) {
      temp[index + stride + (stride + index) / 32] +=
          temp[index + (index) / 32];
    }
  }

  __syncthreads();
  out[0] = static_cast<OutT>(temp[tidx + tidx / 32]);
  out[1] =
      static_cast<OutT>(temp[tidx + shared_size + (tidx + shared_size) / 32]);
}

518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606
// Number of slots in each shared-memory array used by Sort. Every thread
// stores two elements, at slots threadIdx.x and
// threadIdx.x + SHARED_SIZE_LIMIT / 2, so SHARED_SIZE_LIMIT must be at least
// blockDim.x * 2.
#define SHARED_SIZE_LIMIT \
  1024  // slots per shared array (>= blockDim.x * 2)
// Block-wide bitonic sort of values held two per thread (src_data[0] and
// src_data[1]). The sorted values are returned two per thread in
// dst[0] / dst[1], using the same slot layout.
//
// monotonic_type selects the direction of the final merge: 1 sorts
// ascending, 0 sorts descending (details::Comparator swaps when
// (a > b) == monotonic_type). When gridDim.x > 1, set
// monotonic_type = blockIdx.x & 1, so odd blocks sort ascending and even
// blocks descending.
//
// NOTE(review): the final merge always spans SHARED_SIZE_LIMIT slots,
// independent of `num`, but only slots [0, blockDim.x) and
// [SHARED_SIZE_LIMIT / 2, SHARED_SIZE_LIMIT / 2 + blockDim.x) are written.
// The result therefore appears well-defined only when
// blockDim.x == SHARED_SIZE_LIMIT / 2 and num == SHARED_SIZE_LIMIT — confirm
// with callers. `num` must be the same for every thread, because the loops
// contain __syncthreads().
template <typename T>
__device__ __forceinline__ void Sort(T* dst,
                                     const T* src_data,
                                     int num,
                                     int monotonic_type) {
  // TODO: round num up to a power of two (num = Pow2(num)).
  // Shared buffer holding every value of the block; it needs at least
  // blockDim.x * 2 slots.
  __shared__ T value[SHARED_SIZE_LIMIT];
  // Load this thread's two values into the lower and upper halves.
  value[threadIdx.x] = src_data[0];
  value[threadIdx.x + (SHARED_SIZE_LIMIT / 2)] = src_data[1];
  // Phase 1: build bitonic sequences of doubling size. Each merge's
  // direction is taken from bit (size / 2) of threadIdx.x.
  for (int size = 2; size < num; size <<= 1) {
    int bitonic_type = (threadIdx.x & (size / 2)) != 0;
    for (int stride = size / 2; stride > 0; stride >>= 1) {
      __syncthreads();
      // Lower slot of the pair (pos, pos + stride) handled by this thread.
      int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1));
      details::Comparator<T>(&value[pos], &value[pos + stride], bitonic_type);
    }
  }
  // Phase 2: final bitonic merge in the direction given by monotonic_type.
  for (int stride = SHARED_SIZE_LIMIT / 2; stride > 0; stride >>= 1) {
    __syncthreads();
    int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1));
    // monotonic_type == 1 -> ascending order.
    details::Comparator<T>(&value[pos], &value[pos + stride], monotonic_type);
  }
  __syncthreads();
  dst[0] = value[threadIdx.x];
  dst[1] = value[threadIdx.x + (SHARED_SIZE_LIMIT / 2)];
}

// Block-wide bitonic sort of (value, index) pairs held two per thread
// (src_data[0..1] with src_index[0..1]). Indices move together with their
// values, so dst_index[k] is the original index of dst[k]. Results use the
// same slot layout as the inputs.
//
// monotonic_type selects the direction of the final merge: 1 sorts
// ascending, 0 sorts descending (details::ComparatorWithIndex swaps when
// (a > b) == monotonic_type). When gridDim.x > 1, set
// monotonic_type = blockIdx.x & 1.
//
// NOTE(review): as in the value-only overload, the final merge always spans
// SHARED_SIZE_LIMIT slots regardless of `num`, and only slots
// [0, blockDim.x) and
// [SHARED_SIZE_LIMIT / 2, SHARED_SIZE_LIMIT / 2 + blockDim.x) are written.
// The result therefore appears well-defined only when
// blockDim.x == SHARED_SIZE_LIMIT / 2 and num == SHARED_SIZE_LIMIT — confirm
// with callers. `num` must be the same for every thread (the loops contain
// __syncthreads()).
template <typename T, typename IndexType>
__device__ __forceinline__ void Sort(T* dst,
                                     IndexType* dst_index,
                                     const T* src_data,
                                     IndexType* src_index,
                                     int num,
                                     int monotonic_type) {
  // TODO: round num up to a power of two (num = Pow2(num)).
  // Shared buffers for values and their indices; each needs at least
  // blockDim.x * 2 slots.
  __shared__ T value[SHARED_SIZE_LIMIT];
  __shared__ IndexType index[SHARED_SIZE_LIMIT];
  // Load this thread's two values into the lower and upper halves.
  value[threadIdx.x] = src_data[0];
  value[threadIdx.x + (SHARED_SIZE_LIMIT / 2)] = src_data[1];
  // Load the matching indices into the same slots.
  index[threadIdx.x] = src_index[0];
  index[threadIdx.x + (SHARED_SIZE_LIMIT / 2)] = src_index[1];
  // Phase 1: build bitonic sequences of doubling size. Each merge's
  // direction is taken from bit (size / 2) of threadIdx.x.
  for (int size = 2; size < num; size <<= 1) {
    int bitonic_type = (threadIdx.x & (size / 2)) != 0;
    for (int stride = size / 2; stride > 0; stride >>= 1) {
      __syncthreads();
      // Lower slot of the pair (pos, pos + stride) handled by this thread.
      int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1));
      details::ComparatorWithIndex<T, IndexType>(&value[pos],
                                                 &value[pos + stride],
                                                 &index[pos],
                                                 &index[pos + stride],
                                                 bitonic_type);
    }
  }

  // Phase 2: final bitonic merge in the direction given by monotonic_type.
  for (int stride = SHARED_SIZE_LIMIT / 2; stride > 0; stride >>= 1) {
    __syncthreads();
    int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1));
    // monotonic_type == 1 -> ascending order.
    details::ComparatorWithIndex<T, IndexType>(&value[pos],
                                               &value[pos + stride],
                                               &index[pos],
                                               &index[pos + stride],
                                               monotonic_type);
  }

  __syncthreads();
  dst[0] = value[threadIdx.x];
  dst[1] = value[threadIdx.x + (SHARED_SIZE_LIMIT / 2)];
  dst_index[0] = index[threadIdx.x];
  dst_index[1] = index[threadIdx.x + (SHARED_SIZE_LIMIT / 2)];
}

607
}  // namespace kps
608
}  // namespace phi