// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#ifdef PADDLE_WITH_CUDA
#include <cuda.h>
#include <cuda_fp16.h>
#endif
#ifdef PADDLE_WITH_HIP
#include <hip/hip_fp16.h>
#endif
#include "paddle/phi/core/ddim.h"
namespace phi {
namespace kps {
namespace details {

#define INT_BITS 32

// Aggregate holding VecSize elements of ElemT whose alignment equals its
// total size, so that one object can be loaded or stored with a single
// vectorized memory access. Initialize with braces, e.g. {a, b}.
template <typename ElemT, int VecSize>
struct alignas(VecSize * sizeof(ElemT)) VectorType {
  ElemT val[VecSize];  // the packed elements, lowest index first
};
/**
 * Fast division: replace integer division in CUDA kernels with a
 * multiply-high, an add and a shift to improve kernel performance.
 * 1. The constructor computes, once per divisor, `shift_val`
 *    (= ceil(log2(divisor))) and the magic `multiplier`.
 * 2. Div() / Divmod() then divide by the recorded divisor on the GPU.
 *
 * NOTE(review): Div() adds `t + n` in 32 bits and shifts by shift_val, so it
 * is only exact while that sum does not wrap and shift_val < 32, i.e. for
 * dividends < 2^31 and 1 <= divisor <= 2^31. A zero divisor divides by zero
 * in the constructor.
 */
struct FastDivMod {
  // val[0]: quotient of the input divided by the recorded divisor.
  // val[1]: remainder of the input modulo the recorded divisor.
  using DivModT = VectorType<uint32_t, 2>;

  // HOSTDEVICE so arrays of FastDivMod can be default-constructed from
  // host-device code (e.g. a HOSTDEVICE default constructor of a holder).
  HOSTDEVICE FastDivMod() {}

  HOSTDEVICE FastDivMod(uint32_t d) : divisor(d) {
    static_assert(sizeof(unsigned int) == 4,
                  "Only Support 32-bit unsigned int.");

    uint64_t long_one = 1;
    // Smallest shift_val with 2^shift_val >= d. The comparison is done in
    // 64 bits: the previous int `1 << shift_val` overflowed at shift 31, so
    // divisors above 2^30 ran to shift_val == INT_BITS and Div() then
    // shifted a 32-bit value by 32, which is undefined behavior.
    for (shift_val = 0; shift_val < INT_BITS; ++shift_val) {
      if ((long_one << shift_val) >= d) break;
    }
    // multiplier = floor(2^32 * (2^shift_val - d) / d) + 1, computed with the
    // unsigned divisor so it is not sign-extended for large values.
    uint64_t temp_div =
        ((long_one << INT_BITS) * ((long_one << shift_val) - d)) / d + 1;
    multiplier = static_cast<uint32_t>(temp_div);
  }

  // Returns n / divisor: t is the high 32 bits of n * multiplier, and the
  // quotient is (t + n) >> shift_val.
  __device__ __forceinline__ uint32_t Div(uint32_t n) const {
    uint32_t t = __umulhi(n, multiplier);
    return (t + n) >> shift_val;
  }

  // Returns {n / divisor, n % divisor}.
  __device__ __forceinline__ DivModT Divmod(uint32_t n) const {
    uint32_t q = Div(n);
    DivModT result = {q, n - q * divisor};
    return result;
  }

  int32_t divisor;
  int32_t shift_val;
  uint32_t multiplier;
};

/**
 * Configuration of broadcast. Calculate the input data index according to the
 * index of the output data. If the input or output shape is [dim0, dim1], the
 * dims passed in must be [dim1, dim0], i.e. lowest (fastest-varying)
 * dimension first.
 *
 * divmoders[i] divides by out_dims[i]. strides[i] is the input stride of
 * dimension i, or 0 when the input is broadcast along it (in_dims[i] == 1).
 */
template <int kDims>
struct BroadcastConfig {
  FastDivMod divmoders[kDims];
  uint32_t strides[phi::DDim::kMaxRank];

  HOSTDEVICE BroadcastConfig() {}

  /**
   * out_dims, in_dims: output / input shapes, lowest dimension first.
   * dim_size: number of valid entries in out_dims and in_dims.
   *
   * Only the first min(dim_size, kDims) dimensions are recorded. When
   * dim_size < kDims the remaining slots get divisor 1 and stride 0, so a
   * divmod over them leaves the index unchanged and adds 0 to the input
   * index. (Previously kDims entries were memcpy'd out of dim_size-sized
   * vectors, reading past their end and leaving garbage divisors.)
   */
  HOSTDEVICE BroadcastConfig(const std::vector<int64_t>& out_dims,
                             const std::vector<int64_t>& in_dims,
                             int dim_size) {
    static_assert(kDims <= phi::DDim::kMaxRank,
                  "BroadcastConfig: kDims exceeds phi::DDim::kMaxRank.");
    // Contiguous stride of dimension i: product of in_dims[0..i).
    int64_t contiguous_stride = 1;
    for (int i = 0; i < kDims; ++i) {
      if (i < dim_size) {
        divmoders[i] = FastDivMod(static_cast<uint32_t>(out_dims[i]));
        strides[i] =
            in_dims[i] == 1 ? 0 : static_cast<uint32_t>(contiguous_stride);
        contiguous_stride *= in_dims[i];
      } else {
        divmoders[i] = FastDivMod(1);
        strides[i] = 0;
      }
    }
  }
};

#undef INT_BITS
}  // namespace details

/**
 * @brief Read 2D data from global memory to register according to Tx type, and
 * store it as Ty type into register.
 *
 * @template paraments
 * Tx: The type of data stored in the global memory.
 * Ty: The type of data that needs to be stored in registers.
 * NX: The number of data columns loaded by each thread.
 * NY: The number of data rows loaded by each thread.
 * BlockSize: Identifies the current device thread index method. For GPU,
 * threadIdx.x is used as the thread index. Currently only GPU is supported.
 * IsBoundary: Indicates whether to perform block access storage out-of-bounds
 * judgment. When the number of data processed by the block is less than
 * NX x NY x blockDim.x, boundary judgment is required to avoid memory access
 * crossing the boundary. Elements that fail the check are not written, so
 * dst keeps its previous contents there.
 *
 * @param:
 * dst: The register pointer of the thread, the size is NX * NY. The element
 * of column idx and row idy is stored at dst[idy * NX + idx].
 * src: The data pointer of the current block. Thread threadIdx.x reads
 * src[threadIdx.x + idx * stride_nx + idy * stride_ny].
 * size_nx: The maximum offset of the current block is size_nx elements in the
 * lowest dimension. Only used when IsBoundary = true.
 * size_ny: The maximum offset of the current block is size_ny elements in the
 * first dimension. Only used when IsBoundary = true.
 * stride_nx: Distance, in elements, between consecutive columns read by one
 * thread.
 * stride_ny: Distance, in elements, between consecutive rows read by one
 * thread.
 */
template <typename Tx,
          typename Ty,
          int NX,
          int NY,
          int BlockSize,
          bool IsBoundary = false>
__device__ __forceinline__ void ReadData(Ty* dst,
                                         const Tx* __restrict__ src,
                                         int size_nx,
                                         int size_ny,
                                         int stride_nx,
                                         int stride_ny) {
  int thread_offset = threadIdx.x;
  // Columns still available to this thread in the lowest dimension.
  int left_size_nx = size_nx - thread_offset;

  // NX and NY are template parameters, so these conditions are compile-time
  // constants; each shape gets its own loop for better performance.
  if (NX == 1 && NY == 1) {  // for NX == 1 and NY == 1
    if (IsBoundary) {
      if (left_size_nx > 0) {
        dst[0] = static_cast<Ty>(src[thread_offset]);
      }
    } else {
      dst[0] = static_cast<Ty>(src[thread_offset]);
    }
  } else if (NX == 1) {  // for NX == 1 and NY != 1
    // NOTE(review): unlike the other branches this one does not test
    // left_size_nx, so a thread with threadIdx.x >= size_nx still reads.
    // Confirm callers guarantee that cannot go out of bounds.
#pragma unroll
    for (int idy = 0; idy < NY; ++idy) {
      if (IsBoundary) {
        if (idy * stride_ny >= size_ny) {
          break;
        }
      }
      dst[idy] = static_cast<Ty>(src[thread_offset + idy * stride_ny]);
    }
  } else if (NY == 1) {  // for NY == 1 and NX != 1
#pragma unroll
    for (int idx = 0; idx < NX; ++idx) {
      if (IsBoundary) {
        if (idx * stride_nx >= left_size_nx) {
          break;
        }
      }
      dst[idx] = static_cast<Ty>(src[thread_offset + idx * stride_nx]);
    }
  } else {  // for NX != 1 and NY != 1
#pragma unroll
    for (int idx = 0; idx < NX; ++idx) {
      if (IsBoundary) {
        if (idx * stride_nx >= left_size_nx) {
          break;
        }
      }
#pragma unroll
      for (int idy = 0; idy < NY; ++idy) {
        if (IsBoundary) {
          if (idy * stride_ny >= size_ny) {
            break;
          }
        }
        dst[idy * NX + idx] = static_cast<Ty>(
            src[thread_offset + idx * stride_nx + idy * stride_ny]);
      }
    }
  }
}

/**
 * @brief Initialize register with init_data.
 *
 * @template paraments
 * T: Data type of register.
 * NX: Number of data to initialize.
 *
 * @param:
 * dst: The register pointer of the thread, the size is NX.
 * init_data: Initial value written to every one of the NX elements of dst.
 */
template <typename T, int NX>
__device__ __forceinline__ void Init(T* dst, T init_data) {
#pragma unroll
  for (int i = 0; i < NX; i++) {
    dst[i] = init_data;
  }
}

/**
 * @brief Initialize one field of each register element with init_data.
 * This overload supports register arrays whose elements are tuple-like and
 * hold inputs of different data types: only std::get<Index>(dst[n]) is
 * written, the other fields are left untouched.
 *
 * @template paraments
 * T: Data type of init_data.
 * ArgsT: Tuple-like type of each register element.
 * Index: Field of ArgsT that is initialized.
 * NX: Number of register elements to initialize.
 *
 * @param:
 * dst: The register pointer of the thread, the size is NX.
 * init_data: Initial value.
 */
template <typename T, typename ArgsT, int Index, int NX>
__device__ __forceinline__ void Init(ArgsT* dst, T init_data) {
#pragma unroll
  for (int n = 0; n < NX; ++n) {
    auto& field = std::get<Index>(dst[n]);
    field = init_data;
  }
}

/**
 * @brief Read 1D data from global memory to register. When IsBoundary = false
 * and (NX % 4 == 0 or NX % 2 == 0), vectorized loads are used to improve
 * memory access efficiency.
 *
 * @template paraments
 * T: The type of data.
 * NX: Each thread loads NX data from global memory continuously.
 * NY: Each thread needs to load NY rows, only NY = 1 is supported.
 * BlockSize: Identifies the current device thread index method. For GPU,
 * threadIdx.x is used as the thread index. Currently only GPU is supported.
 * IsBoundary: Whether to make an out-of-bounds judgment on access to memory.
 * When the number of data processed by this block is less than
 * NX x NY x blockDim.x, boundary judgment is required to avoid memory access
 * crossing the boundary. Out-of-bounds elements of dst are left unchanged.
 *
 * @param:
 * dst: The register pointer of the thread, the size is NX * NY.
 * src: The data pointer of the current block. Thread threadIdx.x reads
 * src[threadIdx.x * NX, threadIdx.x * NX + NX). On the vectorized path src is
 * read through details::VectorType<T, kVectorSize>; NOTE(review): this
 * assumes src is aligned to sizeof(T) * kVectorSize bytes - confirm at the
 * call sites.
 * num: The current block needs to load num data continuously. Only used when
 * IsBoundary = true.
 */
template <typename T, int NX, int NY, int BlockSize, bool IsBoundary = false>
__device__ __forceinline__ void ReadData(T* dst,
                                         const T* __restrict__ src,
                                         int num) {
  if (IsBoundary) {  // the block may hold fewer than blockDim.x * NX elements
    int thread_offset = threadIdx.x * NX;
#pragma unroll
    for (int idx = 0; idx < NX; ++idx) {
      if (idx + thread_offset < num) {
        dst[idx] = src[thread_offset + idx];
      }
    }
  } else {  // presumably blockDim.x * NX <= num: no bounds checks
    constexpr int kVectorSize = (NX % 4 == 0) ? 4 : (NX % 2 == 0) ? 2 : 1;
    constexpr int kVectorsPerThread = NX / kVectorSize;
    int thread_offset = threadIdx.x * kVectorsPerThread;

    using VecType = details::VectorType<T, kVectorSize>;
    const VecType* vec_input = reinterpret_cast<const VecType*>(src);
    VecType vec_temp[kVectorsPerThread];

#pragma unroll
    for (int i = 0; i < kVectorsPerThread; ++i) {
      vec_temp[i] = vec_input[thread_offset + i];
    }
    // Unpack once every vector has been loaded. Previously this copy ran
    // inside the load loop and re-read not-yet-loaded vec_temp entries on
    // each iteration; the final contents of dst are the same.
#pragma unroll
    for (int idx = 0; idx < NX; ++idx) {
      dst[idx] = *(reinterpret_cast<T*>(vec_temp) + idx);
    }
  }
}

/**
 * @brief Read 1D data from global memory to register. Unlike the single-type
 * overload, element idx is stored into std::get<Index>(dst[idx]), so register
 * arrays of tuple-like elements holding inputs of different data types can
 * be filled one field at a time. When IsBoundary = false and
 * (NX % 4 == 0 or NX % 2 == 0), vectorized loads are used.
 *
 * @template paraments
 * T: The type of data stored in the global memory.
 * NX: Each thread loads NX data from global memory continuously.
 * NY: Not referenced by the body; presumably only NY = 1 is supported.
 * BlockSize: Identifies the current device thread index method. For GPU,
 * threadIdx.x is used as the thread index. Currently only GPU is supported.
 * ArgsT: Tuple-like type of each register element.
 * Index: Field of ArgsT that receives the loaded value.
 * IsBoundary: Whether to make an out-of-bounds judgment on access to memory.
 * When the number of data processed by this block is less than
 * NX x NY x blockDim.x, boundary judgment is required to avoid memory access
 * crossing the boundary. Out-of-bounds fields of dst are left unchanged.
 *
 * @param:
 * dst: The register pointer of the thread, the size is NX.
 * src: The data pointer of the current block. Thread threadIdx.x reads
 * src[threadIdx.x * NX, threadIdx.x * NX + NX). On the vectorized path src is
 * read through details::VectorType<T, kVectorSize>; NOTE(review): this
 * assumes src is aligned to sizeof(T) * kVectorSize bytes - confirm at the
 * call sites.
 * num: The current block needs to load num data continuously. Only used when
 * IsBoundary = true.
 */
template <typename T,
          int NX,
          int NY,
          int BlockSize,
          typename ArgsT,
          int Index,
          bool IsBoundary = false>
__device__ __forceinline__ void ReadData(ArgsT* dst,
                                         const T* __restrict__ src,
                                         int num) {
  if (IsBoundary) {  // the block may hold fewer than blockDim.x * NX elements
    int thread_offset = threadIdx.x * NX;
#pragma unroll
    for (int idx = 0; idx < NX; ++idx) {
      if (idx + thread_offset < num) {
        std::get<Index>(dst[idx]) = src[thread_offset + idx];
      }
    }
  } else {  // presumably blockDim.x * NX <= num: no bounds checks
    constexpr int kVectorSize = (NX % 4 == 0) ? 4 : (NX % 2 == 0) ? 2 : 1;
    constexpr int kVectorsPerThread = NX / kVectorSize;
    int thread_offset = threadIdx.x * kVectorsPerThread;

    using VecType = details::VectorType<T, kVectorSize>;
    const VecType* vec_input = reinterpret_cast<const VecType*>(src);
    VecType vec_temp[kVectorsPerThread];

#pragma unroll
    for (int i = 0; i < kVectorsPerThread; ++i) {
      vec_temp[i] = vec_input[thread_offset + i];
    }
    // Unpack once every vector has been loaded. Previously this copy ran
    // inside the load loop and re-read not-yet-loaded vec_temp entries on
    // each iteration; the final contents of dst are the same.
#pragma unroll
    for (int idx = 0; idx < NX; ++idx) {
      std::get<Index>(dst[idx]) = *(reinterpret_cast<T*>(vec_temp) + idx);
    }
  }
}

/**
 * @brief Read 2D data from global memory to registers with broadcast form.
 *
 * @template paraments
 * T: The type of data stored in the global memory.
 * NX: The number of data columns loaded by each thread.
 * NY: The number of data rows loaded by each thread.
 * BlockSize: Identifies the current device thread index method. For GPU,
 * threadIdx.x is used as the thread index. Currently only GPU is supported.
 * Rank: The shape size of out. eg in[1, 35], out[32, 35] then shape size is 2.
 * IsBoundary: Indicates whether to perform block access storage out-of-bounds
 * judgment. When the number of data processed by the block is less than
 * NX x NY x blockDim.x, boundary judgment is required to avoid memory access
 * crossing the boundary. Elements past the end are not written.
 *
 * @param:
 * dst: The register pointer of the thread, the size is NX * NY. The element
 * of column nx and row ny is stored at dst[nx + ny * NX].
 * src: The original input data pointer of this kernel.
 * block_offset: The data offset of this block, blockDim.x * blockIdx.x * NX.
 * config: Calculation configuration of broadcast. The output index is
 * decomposed with config.divmoders[0..Rank) (lowest dimension first) and each
 * remainder is weighted by config.strides[i] to form the input index.
 * total_num_output: Total number of original output. Only used when
 * IsBoundary = true.
 * stride_nx: Distance in the output index between consecutive columns read by
 * one thread.
 * stride_ny: Distance in the output index between consecutive rows read by
 * one thread.
 */
template <typename T,
          int NX,
          int NY,
          int BlockSize,
          int Rank,
          bool IsBoundary = false>
__device__ __forceinline__ void ReadDataBc(
    T* dst,
    const T* __restrict__ src,
    uint32_t block_offset,
    details::BroadcastConfig<Rank> config,
    int total_num_output,
    int stride_nx,
    int stride_ny) {
  uint32_t thread_offset = block_offset + threadIdx.x;
  uint32_t index_src = 0;

#pragma unroll
  for (int ny = 0; ny < NY; ++ny) {
#pragma unroll
    for (uint32_t nx = 0; nx < NX; ++nx) {
      uint32_t index_output = thread_offset + ny * stride_ny + nx * stride_nx;
      index_src = 0;
      if (IsBoundary) {
        // Unsigned comparison, as the implicit conversion already did.
        if (index_output >= static_cast<uint32_t>(total_num_output)) {
          break;
        }
      }
#pragma unroll
      for (int i = 0; i < Rank; ++i) {
        auto fast_divmoder = config.divmoders[i].Divmod(index_output);
        index_output = fast_divmoder.val[0];
        index_src += fast_divmoder.val[1] * config.strides[i];
      }
      dst[nx + ny * NX] = src[index_src];
    }
  }
}

/**
 * @brief Read 2D data from global memory to register with reduce form.
 *
 * @template paraments
 * Tx: The type of data stored in the global memory.
 * Ty: The type of data that needs to be stored in registers.
 * NX: The number of data columns loaded by each thread.
 * NY: The number of data rows loaded by each thread.
 * BlockSize: Identifies the current device thread index method. For GPU,
 * threadIdx.x is used as the thread index. Currently only GPU is supported.
 * Rank: The shape size of out. eg in[1, 35], out[32, 35] then shape size is 2.
 * It is not referenced by the body of this function.
 * IndexCal: Callable type of index_cal.
 * Functor: Callable type of func.
 * IsBoundary: Indicates whether to perform block access storage out-of-bounds
 * judgment. When the number of data processed by the block is less than
 * NX x NY x blockDim.x, boundary judgment is required to avoid memory access
 * crossing the boundary.
 *
 * @param:
 * dst: The register pointer of the thread, the size is NX * NY.
 * src: The input data pointer of this block.
 * block_offset: The data offset of this block, blockDim.x * blockIdx.x * NX.
 * It is added to the per-thread offset before calling index_cal.
 * index_cal: Calculation configuration of Reduce. It maps the linear offset
 * (thread offset + block_offset) to the index in src that is read.
 * size_nx: The current block needs to load size_nx columns of data. Only used
 * when IsBoundary = true.
 * size_ny: The current block needs to load size_ny rows of data. Only used
 * when IsBoundary = true.
 * stride_nx: Each read one element stride stride_nx columns.
 * stride_ny: Each read one element stride stride_ny rows.
 * func: Applied to every loaded element before it is cast to Ty.
 * reduce_last_dim: Used to indicate whether the dimension of reduce contains
 * the lowest dimension. When true, threadIdx.x walks the rows (ny) and
 * threadIdx.y the columns (nx); when false the roles are swapped.
 */
template <typename Tx,
          typename Ty,
          int NX,
          int NY,
          int BlockSize,
          int Rank,
          typename IndexCal,
          typename Functor,
          bool IsBoundary = false>
__device__ __forceinline__ void ReadDataReduce(Ty* dst,
                                               const Tx* __restrict__ src,
                                               int block_offset,
                                               const IndexCal& index_cal,
                                               int size_nx,
                                               int size_ny,
                                               int stride_nx,
                                               int stride_ny,
                                               Functor func,
                                               bool reduce_last_dim) {
  int thread_offset = 0;
  int left_idx = 0;
  if (reduce_last_dim) {
    thread_offset = threadIdx.x;
    left_idx = threadIdx.y;
  } else {
    thread_offset = threadIdx.y;
    left_idx = threadIdx.x;
  }

  if (NX == 1) {
#pragma unroll
    for (int ny = 0; ny < NY; ++ny) {
      if (IsBoundary) {
        if (thread_offset >= size_ny) {
          break;
        }
      }
      uint32_t index_src = index_cal(thread_offset + block_offset);
      dst[ny] = static_cast<Ty>(func(src[index_src]));
      thread_offset += stride_ny;
    }
  } else {
    // NOTE(review): thread_offset is not reset between nx iterations, and
    // neither nx nor left_idx contributes to the index passed to index_cal;
    // they only take part in the boundary check. Confirm this is intended
    // before relying on NX > 1.
#pragma unroll
    for (int nx = 0; nx < NX; ++nx) {
#pragma unroll
      for (int ny = 0; ny < NY; ++ny) {
        if (IsBoundary) {
          if ((thread_offset >= size_ny) ||
              (left_idx + nx * stride_nx >= size_nx)) {
            break;
          }
        }
        uint32_t index_src = index_cal(thread_offset + block_offset);
        dst[nx + ny * NX] = static_cast<Ty>(func(src[index_src]));
        thread_offset += stride_ny;
      }
    }
  }
}

/**
 * @brief Write 1D data from registers to global memory. When IsBoundary =
 * false and (NX % 4 == 0 or NX % 2 == 0), the data will be vectorized to
 * improve the data storing efficiency.
 *
 * @template paraments
 * T: The type of data.
 * NX: The number of data continuously written by each thread.
 * NY: The number of data rows written by each thread, only NY = 1 is
 * supported.
 * BlockSize: Identifies the current device thread index method. For GPU,
 * threadIdx.x is used as the thread index. Currently only GPU is supported.
 * IsBoundary: Indicates whether to perform block access storage out-of-bounds
 * judgment. When the number of data processed by the block is less than
 * NX x NY x blockDim.x, boundary judgment is required to avoid memory access
 * crossing the boundary.
 *
 * @param:
 * dst: The data pointer of the current block. Thread threadIdx.x writes
 * dst[threadIdx.x * NX, threadIdx.x * NX + NX). On the vectorized path dst is
 * written through details::VectorType<T, kVectorSize>; NOTE(review): this
 * assumes dst is aligned to sizeof(T) * kVectorSize bytes - confirm at the
 * call sites.
 * src: The register pointer, the size is NX * NY.
 * num: The current block needs to store num elements continuously. Only used
 * when IsBoundary = true.
 */
template <typename T, int NX, int NY, int BlockSize, bool IsBoundary = false>
__device__ __forceinline__ void WriteData(T* dst,
                                          T* __restrict__ src,
                                          int num) {
  if (IsBoundary) {
    int thread_offset = threadIdx.x * NX;
#pragma unroll
    for (int idx = 0; idx < NX; ++idx) {
      if ((thread_offset + idx) < num) {
        dst[thread_offset + idx] = src[idx];
      }
    }
  } else {
    // Vector type
    constexpr int kVectorSize = (NX % 4 == 0) ? 4 : (NX % 2 == 0) ? 2 : 1;
    constexpr int kVectorsPerThread = NX / kVectorSize;

    int thread_offset = threadIdx.x * kVectorsPerThread;
    using VecType = details::VectorType<T, kVectorSize>;
    VecType* vec_dst = reinterpret_cast<VecType*>(dst);
    VecType vec_temp[kVectorsPerThread];
#pragma unroll
    for (int idx = 0; idx < kVectorsPerThread; ++idx) {
      // Pack element-wise: src is a plain T array and only carries T's
      // alignment, so it must not be read through a VecType pointer.
#pragma unroll
      for (int k = 0; k < kVectorSize; ++k) {
        vec_temp[idx].val[k] = src[idx * kVectorSize + k];
      }
      vec_dst[thread_offset + idx] = vec_temp[idx];
    }
  }
}

/**
 * @brief Write 2D data from register to global memory according to Tx type, and
 * store it as Ty type.
 *
 * @template paraments
 * Tx: The type of the data held in registers.
 * Ty: The type of data that is stored in the global memory.
 * NX: The number of data columns written by each thread.
 * NY: The number of data rows written by each thread.
 * BlockSize: Identifies the current device thread index method. For GPU,
 * threadIdx.x is used as the thread index. Currently only GPU is supported.
 * IsBoundary: Indicates whether to perform block access storage out-of-bounds
 * judgment. When the number of data processed by the block is less than
 * NX x NY x blockDim.x, boundary judgment is required to avoid memory access
 * crossing the boundary. Elements that fail the check are not written.
 *
 * @param:
 * dst: The data pointer of the current block. Thread threadIdx.x writes
 * dst[threadIdx.x + idx * stride_nx + idy * stride_ny].
 * src: The register pointer of the thread, the size is NX * NY. The element
 * of column idx and row idy is taken from src[idy * NX + idx].
 * size_nx: The maximum offset of the current block is size_nx elements in the
 * lowest dimension. Only used when IsBoundary = true.
 * size_ny: The maximum offset of the current block is size_ny elements in the
 * first dimension. Only used when IsBoundary = true.
 * stride_nx: Distance, in elements of dst, between consecutive columns
 * written by one thread.
 * stride_ny: Distance, in elements of dst, between consecutive rows written
 * by one thread.
 */
template <typename Tx,
          typename Ty,
          int NX,
          int NY,
          int BlockSize,
          bool IsBoundary = false>
__device__ __forceinline__ void WriteData(Ty* dst,
                                          const Tx* __restrict__ src,
                                          int size_nx,
                                          int size_ny,
                                          int stride_nx,
                                          int stride_ny) {
  int thread_offset = threadIdx.x;
  // Columns still available to this thread in the lowest dimension.
  int left_size_nx = size_nx - thread_offset;

  // NX and NY are template parameters, so these conditions are compile-time
  // constants; each shape gets its own loop for better performance.
  if (NX == 1 && NY == 1) {  // for NX == 1 and NY == 1
    if (IsBoundary) {
      if (left_size_nx > 0) {
        dst[thread_offset] = static_cast<Ty>(src[0]);
      }
    } else {
      dst[thread_offset] = static_cast<Ty>(src[0]);
    }
  } else if (NX == 1) {  // for NX == 1 and NY != 1
    // NOTE(review): unlike the other branches this one does not test
    // left_size_nx, so a thread with threadIdx.x >= size_nx still writes.
    // Confirm callers guarantee that cannot go out of bounds.
#pragma unroll
    for (int idy = 0; idy < NY; ++idy) {
      if (IsBoundary) {
        if (idy * stride_ny >= size_ny) {
          break;
        }
      }
      dst[thread_offset + idy * stride_ny] = static_cast<Ty>(src[idy]);
    }
  } else if (NY == 1) {  // for NY == 1 and NX != 1
#pragma unroll
    for (int idx = 0; idx < NX; ++idx) {
      if (IsBoundary) {
        if (idx * stride_nx >= left_size_nx) {
          break;
        }
      }
      dst[thread_offset + idx * stride_nx] = static_cast<Ty>(src[idx]);
    }
  } else {  // for NX != 1 and NY != 1
#pragma unroll
    for (int idx = 0; idx < NX; ++idx) {
      if (IsBoundary) {
        if (idx * stride_nx >= left_size_nx) {
          break;
        }
      }
#pragma unroll
      for (int idy = 0; idy < NY; ++idy) {
        if (IsBoundary) {
          if (idy * stride_ny >= size_ny) {
            break;
          }
        }
        dst[thread_offset + idx * stride_nx + idy * stride_ny] =
            static_cast<Ty>(src[idy * NX + idx]);
      }
    }
  }
}

/**
 * @brief Initialize register with init_data.
 *
 * @template paraments
 * T: Data type of register.
 * NX: Number of data to initialize.
 * IsBoundary: When true, copying stops at element num, so only the first
 * min(NX, num) elements are written.
 *
 * @param:
 * dst: The register pointer of the thread, the size is NX.
 * init_data: The register pointer of init data, the size is NX.
 * num: Number of valid elements; only consulted when IsBoundary = true.
 */
template <typename T, int NX, bool IsBoundary = false>
__device__ __forceinline__ void Init(T* dst, T* init_data, int num) {
#pragma unroll
  for (int k = 0; k < NX; ++k) {
    const bool past_end = IsBoundary && k >= num;
    if (past_end) break;
    dst[k] = init_data[k];
  }
}

/**
 * @brief Read 1D data from global memory to register with broadcast form.
 *
 * @template paraments
 * T: The type of data stored in the global memory.
 * NX: The number of data continuously loaded by each thread.
 * NY: The number of data rows loaded by each thread, only NY = 1 is supported.
 * BlockSize: Identifies the current device thread index method. For GPU,
 * threadIdx.x is used as the thread index. Currently only GPU is supported.
 * Rank: The shape size of out. eg in[1, 35], out[32, 35] then shape size is 2.
 * IsBoundary: Indicates whether to perform block access storage out-of-bounds
 * judgment. When the number of data processed by the block is less than
 * NX x NY x blockDim.x, boundary judgment is required to avoid memory access
 * crossing the boundary. Elements past the end are not written.
 *
 * @param:
 * dst: The register pointer of the thread, the size is NX * NY.
 * src: The original input data pointer of kernel.
 * block_offset: The data offset of this block, blockDim.x * blockIdx.x * NX.
 * Thread threadIdx.x handles output indices
 * [block_offset + threadIdx.x * NX, block_offset + threadIdx.x * NX + NX).
 * config: Calculation configuration of broadcast. The output index is
 * decomposed with config.divmoders[0..Rank) (lowest dimension first) and each
 * remainder is weighted by config.strides[i] to form the input index.
 * total_num_output: Total number of original output. Only used when
 * IsBoundary = true.
 */
template <typename T,
          int NX,
          int NY,
          int BlockSize,
          int Rank,
          bool IsBoundary = false>
__device__ __forceinline__ void ReadDataBc(
    T* dst,
    const T* __restrict__ src,
    uint32_t block_offset,
    details::BroadcastConfig<Rank> config,
    int total_num_output) {
  uint32_t thread_offset = block_offset + threadIdx.x * NX;
  uint32_t index_src = 0;

#pragma unroll
  for (uint32_t nx = 0; nx < NX; ++nx) {
    uint32_t index_output = thread_offset + nx;
    index_src = 0;
    if (IsBoundary) {
      // Unsigned comparison, as the implicit conversion already did.
      if (index_output >= static_cast<uint32_t>(total_num_output)) {
        break;
      }
    }
#pragma unroll
    for (int i = 0; i < Rank; ++i) {
      auto fast_divmoder = config.divmoders[i].Divmod(index_output);
      index_output = fast_divmoder.val[0];
      index_src += fast_divmoder.val[1] * config.strides[i];
    }
    dst[nx] = src[index_src];
  }
}

}  // namespace kps
}  // namespace phi