datamover_primitives.h 28.5 KB
Newer Older
F
Feng Xing 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
16
#ifdef PADDLE_WITH_CUDA
N
niuliling123 已提交
17 18
#include <cuda.h>
#include <cuda_fp16.h>
19 20 21 22
#endif
#ifdef PADDLE_WITH_HIP
#include <hip/hip_fp16.h>
#endif
23
#include "paddle/phi/core/ddim.h"
F
Feng Xing 已提交
24

25
namespace phi {
26
namespace kps {
N
niuliling123 已提交
27 28 29 30 31 32 33 34
namespace details {

#define INT_BITS 32

/**
 * Fixed-size aggregate of kLanes elements of type ElemT.
 *
 * The struct is aligned to its full size (sizeof(ElemT) * kLanes). The 1D
 * ReadData/WriteData helpers in this file reinterpret contiguous data as arrays
 * of this type to perform vectorized memory access.
 */
template <typename ElemT, int kLanes>
struct alignas(sizeof(ElemT) * kLanes) VectorType {
  ElemT val[kLanes];
};
/**
 * Fast division: replace integer division in CUDA kernels with a multiply-high
 * and a shift to improve kernel performance.
 * 1. The constructor precomputes `multiplier` and `shift_val` for a fixed
 *    divisor (it is HOSTDEVICE, so this can be done on the CPU).
 * 2. Div() / Divmod() then compute n / divisor (and n % divisor) on the GPU.
 *
 * Preconditions (not checked here):
 *  - 1 <= divisor < 2^31: the divisor is stored as int32_t, and a divisor of 0
 *    divides by zero inside the constructor.
 *  - In Div(), `t + n` must fit in 32 bits; this holds whenever n < 2^31
 *    because t = umulhi(n, multiplier) < n.
 */
struct FastDivMod {
  // 1st value represents the result of input number divides by recorded divisor
  // 2nd value represents the result of input number modulo by recorded divisor
  using DivModT = VectorType<uint32_t, 2>;

  // Leaves all members uninitialized. HOSTDEVICE so that arrays of FastDivMod
  // (e.g. BroadcastConfig::divmoders) can be default-constructed from
  // host/device code without calling a host-only constructor.
  HOSTDEVICE FastDivMod() {}

  HOSTDEVICE FastDivMod(uint32_t d) : divisor(d) {
    static_assert(sizeof(unsigned int) == 4,
                  "Only Support 32-bit unsigned int.");

    // Smallest shift_val such that 2^shift_val >= d. The power of two is
    // formed in 64 bits: a 32-bit `int` expression `1 << 31` overflows (and
    // compares as negative), which would push every divisor above 2^30 to
    // shift_val == 32 and truncate the multiplier below.
    const uint64_t long_one = 1;
    for (shift_val = 0; shift_val < INT_BITS; ++shift_val) {
      if ((long_one << shift_val) >= d) break;
    }
    // multiplier = floor(2^32 * (2^shift_val - d) / d) + 1
    uint64_t temp_div =
        ((long_one << INT_BITS) * ((long_one << shift_val) - d)) / d + 1;
    multiplier = static_cast<uint32_t>(temp_div);
  }

  // Returns n / divisor: high 32 bits of n * multiplier, plus n, shifted right.
  __device__ __forceinline__ uint32_t Div(uint32_t n) const {
    uint32_t t = __umulhi(n, multiplier);
    return (t + n) >> shift_val;
  }

  // Returns {n / divisor, n % divisor}; the remainder is derived from the
  // quotient so only one multiply-high is needed.
  __device__ __forceinline__ DivModT Divmod(uint32_t n) const {
    uint32_t q = Div(n);
    DivModT result = {q, n - q * divisor};
    return result;
  }

  int32_t divisor;
  int32_t shift_val;
  uint32_t multiplier;
};

/**
 * Configuration of broadcast. Calculate the input data index according to the
 * index of the output data. If the input or output shape is [dim0, dim1], then
 * dims must be passed as [dim1, dim0], i.e. the lowest (fastest-varying)
 * dimension first: dimension 0 gets input stride 1.
 */
struct BroadcastConfig {
  // divmoders[i] divides by out_dims[i]; ReadDataBc uses it to peel one output
  // coordinate per dimension off a linear output index.
  FastDivMod divmoders[phi::DDim::kMaxRank];
  // strides[i] is the input stride of dimension i, or 0 when in_dims[i] == 1
  // (a broadcast dimension contributes nothing to the input offset).
  uint32_t strides[phi::DDim::kMaxRank];
  // Number of valid entries in divmoders / strides.
  int kDims;

  HOSTDEVICE BroadcastConfig() {}

  /**
   * @param out_dims: Output shape, lowest dimension first.
   * @param in_dims: Input shape, lowest dimension first. A dimension of size 1
   *     is broadcast along the matching output dimension.
   * @param dim_size: Number of dimensions taken from both vectors. Must not
   *     exceed phi::DDim::kMaxRank (not checked).
   */
  HOSTDEVICE BroadcastConfig(const std::vector<int64_t>& out_dims,
                             const std::vector<int64_t>& in_dims,
                             int dim_size) {
    kDims = dim_size;
    // Running product of in_dims[0..i), i.e. the contiguous input stride of
    // dimension i. This replaces a per-dimension std::accumulate (quadratic in
    // the rank, and accumulating in `int`). The product is kept in 64 bits and
    // only truncated to 32 bits when stored.
    int64_t stride = 1;
    for (int i = 0; i < dim_size; ++i) {
      divmoders[i] = FastDivMod(out_dims[i]);
      strides[i] = in_dims[i] == 1 ? 0 : static_cast<uint32_t>(stride);
      stride *= in_dims[i];
    }
  }
};

/**
 * Copies src[0, num) into dst[0, num), one element at a time.
 * Nothing is copied when num <= 0. The trip count is a runtime value, so the
 * loop is not marked for unrolling.
 */
template <typename T>
__device__ __forceinline__ void WriteData(T* dst,
                                          T* __restrict__ src,
                                          int num) {
  for (int i = 0; i < num; i++) {
    dst[i] = src[i];
  }
}
/**
 * Copies src[0, num) into dst[0, num), one element at a time, reading through
 * a const source pointer. Nothing is copied when num <= 0.
 */
template <typename T>
__device__ __forceinline__ void ReadData(T* dst,
                                         const T* __restrict__ src,
                                         int num) {
  for (int i = 0; i < num; i++) {
    dst[i] = src[i];
  }
}

#undef INT_BITS
}  // namespace details

/**
 * @brief Read 2D data from global memory to register according to Tx type, and
 * store it as Ty type into register.
 *
 * @template parameters
 * Tx: The type of data stored in the global memory.
 * Ty: The type of data that needs to be stored in registers.
 * NX: The number of data columns loaded by each thread.
 * NY: The number of data rows loaded by each thread.
 * threadIdx.x is used as the thread index. Currently only GPU was supported.
 * IsBoundary: Indicates whether to perform block access storage out-of-bounds
 * judgment. When the number of data processed by the block is less than
 * NX x NY x blockDim.x, boundary judgment is required to avoid memory access
 * crossing the boundary.
 *
 * @param:
 * dst: The register pointer of the thread, the size is NX * NY. The element in
 * row idy and column idx is stored at dst[idy * NX + idx].
 * src: The data pointer of the current block.
 * size_nx: The maximum offset of the current block is size_nx elements in the
 * lowest dimension. Only used when IsBoundary = true.
 * size_ny: The maximum offset of the current block is size_ny elements in the
 * first dimension. Only used when IsBoundary = true.
 * stride_nx: Each read one element stride stride_nx elements in the last dim.
 * stride_ny: Each read one element stride stride_ny elements in the first dim.
 *
 * With IsBoundary = true, the bounds checked depend on the branch taken:
 * NX == 1 && NY == 1 checks only size_nx, NX == 1 && NY != 1 checks only
 * size_ny, NY == 1 && NX != 1 checks only size_nx, and the general case checks
 * both.
 */
template <typename Tx, typename Ty, int NX, int NY, bool IsBoundary = false>
__device__ __forceinline__ void ReadData(Ty* dst,
                                         const Tx* __restrict__ src,
                                         int size_nx,
                                         int size_ny,
                                         int stride_nx,
                                         int stride_ny) {
  int thread_offset = threadIdx.x;
  // Number of columns remaining for this thread in the lowest dimension.
  int left_size_nx = size_nx - thread_offset;

  // Each branch is added for better performance
  if (NX == 1 && NY == 1) {  // for NX == 1 and NY == 1
    if (IsBoundary) {
      if (left_size_nx > 0) {
        dst[0] = static_cast<Ty>(src[thread_offset]);
      }
    } else {
      dst[0] = static_cast<Ty>(src[thread_offset]);
    }
  } else if (NX == 1) {  // for NX == 1 and NY != 1
    // NOTE(review): this branch does not check left_size_nx, so it relies on
    // callers guaranteeing threadIdx.x < size_nx when IsBoundary = true —
    // confirm against call sites.
#pragma unroll
    for (int idy = 0; idy < NY; ++idy) {
      if (IsBoundary) {
        if (idy * stride_ny >= size_ny) {
          break;
        }
      }
      dst[idy] = static_cast<Ty>(src[thread_offset + idy * stride_ny]);
    }
  } else if (NY == 1) {  // for NY == 1 and NX != 1
#pragma unroll
    for (int idx = 0; idx < NX; ++idx) {
      if (IsBoundary) {
        if (idx * stride_nx >= left_size_nx) {
          break;
        }
      }
      dst[idx] = static_cast<Ty>(src[thread_offset + idx * stride_nx]);
    }
  } else {  // for NX != 1 and NY != 1
#pragma unroll
    for (int idx = 0; idx < NX; ++idx) {
      if (IsBoundary) {
        if (idx * stride_nx >= left_size_nx) {
          break;
        }
      }
#pragma unroll
      for (int idy = 0; idy < NY; ++idy) {
        if (IsBoundary) {
          if (idy * stride_ny >= size_ny) {
            break;
          }
        }
        dst[idy * NX + idx] = static_cast<Ty>(
            src[thread_offset + idx * stride_nx + idy * stride_ny]);
      }
    }
  }
}

/**
 * @brief Initialize register with init_data.
 *
 * @template parameters
 * T: Data type of register.
 * NX: Number of data to initialize.
 *
 * @param:
 * dst: The register pointer of the thread, the size is NX.
 * init_data: Initial value written to every one of the NX elements.
 */
template <typename T, int NX>
__device__ __forceinline__ void Init(T* dst, T init_data) {
#pragma unroll
  for (int i = 0; i < NX; i++) {
    dst[i] = init_data;
  }
}

/**
 * @brief Initialize register with init_data: writes init_data to all NX
 * elements of dst.
 *
 * read_lens is not used by this implementation; all NX elements are always
 * written.
 */
template <typename T, int NX>
__device__ __forceinline__ void Init(T* dst, T init_data, int read_lens) {
#pragma unroll
  for (int i = 0; i < NX; i++) {
    dst[i] = init_data;
  }
}

/**
 * @brief Initialize one field of each of the NX elements of dst with init_data.
 *
 * The difference from the single-type Init overloads is that it supports
 * different data types of inputs: each dst[i] is a tuple-like value and only
 * its Index-th field (std::get<Index>) is written.
 *
 * read_lens is not used by this implementation; all NX elements are always
 * written.
 */
template <typename T, typename ArgsT, int Index, int NX>
__device__ __forceinline__ void Init(ArgsT* dst, T init_data, int read_lens) {
#pragma unroll
  for (int i = 0; i < NX; i++) {
    std::get<Index>(dst[i]) = init_data;
  }
}

/**
 * @brief Read 1D data from global memory to register. When IsBoundary = false
 * and (NX % 4 == 0 or NX % 2 == 0), vectorized loads are used to improve
 * memory access efficiency.
 *
 * @template parameters
 * T: The type of data.
 * NX: Each thread load NX data from global memory continuously.
 * NY: Each thread need to load NY rows, only NY = 1 was supported.
 * threadIdx.x is used as the thread index. Currently only GPU was supported.
 * IsBoundary: Whether to make an out-of-bounds judgment on access to memory.
 * When the number of data processed by this block is less than
 * NX x NY x blockDim.x, boundary judgment is required to avoid memory access
 * crossing the boundary.
 *
 * @param:
 * dst: The register pointer of the thread, the size is NX * NY.
 * src: The data pointer of the current block. On the vectorized path it is
 * reinterpreted as an array of details::VectorType<T, kVectorSize>, so it must
 * be aligned to sizeof(T) * kVectorSize bytes (not checked).
 * num: The current block needs to load num data continuously. Only used when
 * IsBoundary = true.
 */
template <typename T, int NX, int NY, bool IsBoundary = false>
__device__ __forceinline__ void ReadData(T* dst,
                                         const T* __restrict__ src,
                                         int num) {
  if (IsBoundary) {  // blockDim.x * NX > num
    int thread_offset = threadIdx.x * NX;
#pragma unroll
    for (int idx = 0; idx < NX; ++idx) {
      if (idx + thread_offset < num) {
        dst[idx] = src[thread_offset + idx];
      }
    }
  } else {  // blockDim.x * NX <= num
    // Widest vector width (4, 2 or 1 elements) that evenly divides NX.
    constexpr int kVectorSize = (NX % 4 == 0) ? 4 : (NX % 2 == 0) ? 2 : 1;
    constexpr int kVectorsPerThread = NX / kVectorSize;
    int thread_offset = threadIdx.x * kVectorsPerThread;

    using VecType = details::VectorType<T, kVectorSize>;
    const VecType* vec_input = reinterpret_cast<const VecType*>(src);

    // Each vector is unpacked into dst right after it is loaded, so no
    // not-yet-loaded (uninitialized) vector slot is ever read and every
    // element is copied exactly once.
#pragma unroll
    for (int i = 0; i < kVectorsPerThread; ++i) {
      VecType vec = vec_input[thread_offset + i];
#pragma unroll
      for (int j = 0; j < kVectorSize; ++j) {
        dst[i * kVectorSize + j] = vec.val[j];
      }
    }
  }
}

/**
 * @brief Read 1D data from global memory to register; same contract as
 * ReadData(dst, src, num).
 *
 * IsBoundary = true: element-wise loads of src[threadIdx.x * NX + idx] for the
 * indices below num. IsBoundary = false: vectorized loads (vector width 4, 2 or
 * 1 depending on NX); src must then be aligned to sizeof(T) * vector width
 * (not checked).
 *
 * read_lens is not used by this implementation.
 */
template <typename T, int NX, int NY, bool IsBoundary = false>
__device__ __forceinline__ void ReadData(T* dst,
                                         const T* __restrict__ src,
                                         int num,
                                         int read_lens) {
  if (IsBoundary) {  // blockDim.x * NX > num
    int thread_offset = threadIdx.x * NX;
#pragma unroll
    for (int idx = 0; idx < NX; ++idx) {
      if (idx + thread_offset < num) {
        dst[idx] = src[thread_offset + idx];
      }
    }
  } else {  // blockDim.x * NX <= num
    constexpr int kVectorSize = (NX % 4 == 0) ? 4 : (NX % 2 == 0) ? 2 : 1;
    constexpr int kVectorsPerThread = NX / kVectorSize;
    int thread_offset = threadIdx.x * kVectorsPerThread;

    using VecType = details::VectorType<T, kVectorSize>;
    const VecType* vec_input = reinterpret_cast<const VecType*>(src);

    // Unpack each vector right after loading it: no uninitialized vector slot
    // is read and every element is copied exactly once.
#pragma unroll
    for (int i = 0; i < kVectorsPerThread; ++i) {
      VecType vec = vec_input[thread_offset + i];
#pragma unroll
      for (int j = 0; j < kVectorSize; ++j) {
        dst[i * kVectorSize + j] = vec.val[j];
      }
    }
  }
}
/**
 * @brief Read 1D data from global memory to register. The difference
 * from the above function is that it supports different data types of inputs:
 * each loaded value is stored into the Index-th field of a tuple-like dst
 * element.
 *
 * @template parameters
 * T: The type of data.
 * NX: Each thread load NX data from global memory continuously.
 * NY: Each thread need to load NY rows, only NY = 1 was supported.
 * ArgsT: The type of dst, ArgsT can be std::tuple<T> or std::tuple<Args>
 * Index: The index of data stored in dst.
 * threadIdx.x is used as the thread index. Currently only GPU was supported.
 * IsBoundary: Whether to make an out-of-bounds judgment on access to memory.
 * When the number of data processed by this block is less than
 * NX x NY x blockDim.x, boundary judgment is required to avoid memory access
 * crossing the boundary.
 *
 * @param:
 * dst: The register pointer of the thread, the size is NX * NY.
 * src: The data pointer of the current block. On the vectorized path
 * (IsBoundary = false) it must be aligned to sizeof(T) * vector width, where
 * the width is 4, 2 or 1 depending on NX (not checked).
 * num: The current block needs to load num data continuously. Only used when
 * IsBoundary = true.
 * read_lens: Not used by this implementation.
 */
template <typename T,
          int NX,
          int NY,
          typename ArgsT,
          int Index,
          bool IsBoundary = false>
__device__ __forceinline__ void ReadData(ArgsT* dst,
                                         const T* __restrict__ src,
                                         int num,
                                         int read_lens) {
  if (IsBoundary) {  // blockDim.x * NX > num
    int thread_offset = threadIdx.x * NX;
#pragma unroll
    for (int idx = 0; idx < NX; ++idx) {
      if (idx + thread_offset < num) {
        std::get<Index>(dst[idx]) = src[thread_offset + idx];
      }
    }
  } else {  // blockDim.x * NX <= num
    constexpr int kVectorSize = (NX % 4 == 0) ? 4 : (NX % 2 == 0) ? 2 : 1;
    constexpr int kVectorsPerThread = NX / kVectorSize;
    int thread_offset = threadIdx.x * kVectorsPerThread;

    using VecType = details::VectorType<T, kVectorSize>;
    const VecType* vec_input = reinterpret_cast<const VecType*>(src);

    // Unpack each vector right after loading it: no uninitialized vector slot
    // is read and every element is copied exactly once.
#pragma unroll
    for (int i = 0; i < kVectorsPerThread; ++i) {
      VecType vec = vec_input[thread_offset + i];
#pragma unroll
      for (int j = 0; j < kVectorSize; ++j) {
        std::get<Index>(dst[i * kVectorSize + j]) = vec.val[j];
      }
    }
  }
}

/**
 * @brief Read 2D data from global memory to registers with broadcast form.
 *
 * @template parameters
 * T: The type of data stored in the global memory.
 * NX: The number of data columns loaded by each thread.
 * NY: The number of data rows loaded by each thread.
 * threadIdx.x is used as the thread index. Currently only GPU was supported.
 * IsBoundary: Indicates whether to perform block access storage out-of-bounds
 * judgment. When the number of data processed by the block is less than
 * NX x NY x blockDim.x, boundary judgment is required to avoid memory access
 * crossing the boundary.
 *
 * @param:
 * dst: The register pointer of the thread, the size is NX * NY. The element in
 * row ny and column nx is stored at dst[nx + ny * NX].
 * src: The original input data pointer of this kernel.
 * block_offset: The data offset of this block, blockDim.x * blockIdx.x * NX.
 * config: Calculation configuration of broadcast. It is used to calculate the
 * coordinate mapping relationship between output data and input data.
 * total_num_output: Total number of original output. Only used when
 * IsBoundary = true.
 * stride_nx: Each read one element stride stride_nx elements in the last dim.
 * stride_ny: Each read one element stride stride_ny elements in the first dim.
 */
template <typename T, int NX, int NY, bool IsBoundary = false>
__device__ __forceinline__ void ReadDataBc(
    T* dst,
    const T* __restrict__ src,
    uint32_t block_offset,
    const details::BroadcastConfig& config,
    int total_num_output,
    int stride_nx,
    int stride_ny) {
  uint32_t thread_offset = block_offset + threadIdx.x;
  uint32_t index_src = 0;

#pragma unroll
  for (int ny = 0; ny < NY; ++ny) {
#pragma unroll
    for (uint32_t nx = 0; nx < NX; ++nx) {
      uint32_t index_output = thread_offset + ny * stride_ny + nx * stride_nx;
      index_src = 0;
      if (IsBoundary) {
        if (index_output >= total_num_output) {
          break;
        }
      }
      // Peel off one output coordinate per dimension (lowest first) and add
      // its contribution to the input offset; broadcast dims have stride 0.
#pragma unroll
      for (int i = 0; i < phi::DDim::kMaxRank; ++i) {
        if (i >= config.kDims) break;
        auto fast_divmoder = config.divmoders[i].Divmod(index_output);
        index_output = fast_divmoder.val[0];
        index_src += fast_divmoder.val[1] * config.strides[i];
      }
      dst[nx + ny * NX] = src[index_src];
    }
  }
}

/**
 * @brief Read 2D data from global memory to register with reduce form.
 *
 * @template parameters
 * Tx: The type of data stored in the global memory.
 * Ty: The type of data stored in registers; each loaded value is stored as
 * static_cast<Ty>(func(value)).
 * NX: The number of data columns loaded by each thread.
 * NY: The number of data rows loaded by each thread.
 * threadIdx.x is used as the thread index. Currently only GPU was supported.
 * Rank: The shape size of out. eg in[1, 35], out[32, 35] then shape size is 2.
 * It is not referenced by this implementation.
 * IndexCal: Callable type; index_cal(i) returns the offset in src of linear
 * index i.
 * Functor: Callable type applied to every element read from src.
 * IsBoundary: Indicates whether to perform block access storage out-of-bounds
 * judgment. When the number of data processed by the block is less than
 * NX x NY x blockDim.x, boundary judgment is required to avoid memory access
 * crossing the boundary.
 *
 * @param:
 * dst: The register pointer of the thread, the size is NX * NY.
 * src: The input data pointer of this block.
 * block_offset: The data offset of this block, blockDim.x * blockIdx.x * NX.
 * It is added to the row index before calling index_cal.
 * index_cal: Calculation configuration of Reduce. It is used to calculate the
 * coordinate mapping relationship between output data and input data.
 * size_nx: The current block needs to load size_nx columns of data. This
 * parameter is only used when IsBoundary = true (and only when NX != 1).
 * size_ny: The current block needs to load size_ny rows of data. This
 * parameter is only used when IsBoundary = true.
 * stride_nx: Each read one element stride stride_nx columns.
 * stride_ny: Each read one element stride stride_ny rows.
 * func: Applied to each element read from src before it is cast to Ty.
 * reduce_last_dim: Used to indicate whether the dimension of reduce contains
 * the lowest dimension. When true, threadIdx.x selects the starting row and
 * threadIdx.y the column; otherwise the roles are swapped.
 */
template <typename Tx,
          typename Ty,
          int NX,
          int NY,
          int Rank,
          typename IndexCal,
          typename Functor,
          bool IsBoundary = false>
__device__ __forceinline__ void ReadDataReduce(Ty* dst,
                                               const Tx* __restrict__ src,
                                               int block_offset,
                                               const IndexCal& index_cal,
                                               int size_nx,
                                               int size_ny,
                                               int stride_nx,
                                               int stride_ny,
                                               Functor func,
                                               bool reduce_last_dim) {
  int thread_offset = 0;
  int left_idx = 0;
  if (reduce_last_dim) {
    thread_offset = threadIdx.x;
    left_idx = threadIdx.y;
  } else {
    thread_offset = threadIdx.y;
    left_idx = threadIdx.x;
  }

  if (NX == 1) {
#pragma unroll
    for (int ny = 0; ny < NY; ++ny) {
      if (IsBoundary) {
        if (thread_offset >= size_ny) {
          break;
        }
      }
      uint32_t index_src = index_cal(thread_offset + block_offset);
      dst[ny] = static_cast<Ty>(func(src[index_src]));
      thread_offset += stride_ny;
    }
  } else {
    // NOTE(review): the index passed to index_cal does not depend on nx, and
    // thread_offset keeps advancing across nx iterations (it is not reset per
    // column). For NX > 1, column nx therefore reads the rows that follow
    // those of column nx - 1 rather than a distinct column — verify the
    // intended semantics before relying on NX > 1.
#pragma unroll
    for (int nx = 0; nx < NX; ++nx) {
#pragma unroll
      for (int ny = 0; ny < NY; ++ny) {
        if (IsBoundary) {
          if ((thread_offset >= size_ny) ||
              (left_idx + nx * stride_nx >= size_nx)) {
            break;
          }
        }
        uint32_t index_src = index_cal(thread_offset + block_offset);
        dst[nx + ny * NX] = static_cast<Ty>(func(src[index_src]));
        thread_offset += stride_ny;
      }
    }
  }
}

/**
 * @brief Write 1D data from registers to global memory. When IsBoundary =
 * false and (NX % 4 == 0 or NX % 2 == 0), the data is written with vectorized
 * stores to improve memory access efficiency.
 *
 * @template parameters
 * T: The type of data.
 * NX: The number of data continuously written by each thread.
 * NY: The number of data rows written by each thread, only NY = 1 was
 * supported.
 * threadIdx.x is used as the thread index. Currently only GPU was supported.
 * IsBoundary: Indicates whether to perform block access storage out-of-bounds
 * judgment. When the number of data processed by the block is less than
 * NX x NY x blockDim.x, boundary judgment is required to avoid memory access
 * crossing the boundary.
 *
 * @param:
 * dst: The data pointer of the current block. On the vectorized path it is
 * reinterpreted as an array of details::VectorType<T, kVectorSize>, so it must
 * be aligned to sizeof(T) * kVectorSize bytes (not checked).
 * src: The register pointer, the size is NX * NY.
 * num: The current block needs to write num elements continuously. Only used
 * when IsBoundary = true.
 */
template <typename T, int NX, int NY, bool IsBoundary = false>
__device__ __forceinline__ void WriteData(T* dst,
                                          T* __restrict__ src,
                                          int num) {
  if (IsBoundary) {
    int thread_offset = threadIdx.x * NX;
#pragma unroll
    for (int idx = 0; idx < NX; ++idx) {
      if ((thread_offset + idx) < num) {
        dst[thread_offset + idx] = src[idx];
      }
    }
  } else {
    // Widest vector width (4, 2 or 1 elements) that evenly divides NX.
    constexpr int kVectorSize = (NX % 4 == 0) ? 4 : (NX % 2 == 0) ? 2 : 1;
    constexpr int kVectorsPerThread = NX / kVectorSize;

    int thread_offset = threadIdx.x * kVectorsPerThread;
    using VecType = details::VectorType<T, kVectorSize>;
    VecType* vec_dst = reinterpret_cast<VecType*>(dst);
#pragma unroll
    for (int i = 0; i < kVectorsPerThread; ++i) {
      // Pack element-wise into a properly aligned VecType instead of
      // reinterpreting src: src only guarantees alignof(T), while VecType
      // requires sizeof(T) * kVectorSize alignment.
      VecType vec;
#pragma unroll
      for (int j = 0; j < kVectorSize; ++j) {
        vec.val[j] = src[i * kVectorSize + j];
      }
      vec_dst[thread_offset + i] = vec;
    }
  }
}

/**
 * @brief Write 1D data from registers to global memory; same contract as
 * WriteData(dst, src, num).
 *
 * IsBoundary = true: element-wise stores to dst[threadIdx.x * NX + idx] for
 * indices below num. IsBoundary = false: vectorized stores (vector width 4, 2
 * or 1 depending on NX); dst must then be aligned to sizeof(T) * vector width
 * (not checked).
 *
 * read_lens is not used by this implementation.
 */
template <typename T, int NX, int NY, bool IsBoundary = false>
__device__ __forceinline__ void WriteData(T* dst,
                                          T* __restrict__ src,
                                          int num,
                                          int read_lens) {
  if (IsBoundary) {
    int thread_offset = threadIdx.x * NX;
#pragma unroll
    for (int idx = 0; idx < NX; ++idx) {
      if ((thread_offset + idx) < num) {
        dst[thread_offset + idx] = src[idx];
      }
    }
  } else {
    constexpr int kVectorSize = (NX % 4 == 0) ? 4 : (NX % 2 == 0) ? 2 : 1;
    constexpr int kVectorsPerThread = NX / kVectorSize;

    int thread_offset = threadIdx.x * kVectorsPerThread;
    using VecType = details::VectorType<T, kVectorSize>;
    VecType* vec_dst = reinterpret_cast<VecType*>(dst);
#pragma unroll
    for (int i = 0; i < kVectorsPerThread; ++i) {
      // Pack element-wise into an aligned VecType rather than casting src,
      // whose alignment is only alignof(T).
      VecType vec;
#pragma unroll
      for (int j = 0; j < kVectorSize; ++j) {
        vec.val[j] = src[i * kVectorSize + j];
      }
      vec_dst[thread_offset + i] = vec;
    }
  }
}

/**
 * @brief Write 2D data from register to global memory according to Tx type, and
 * store it as Ty type.
 *
 * @template parameters
 * Tx: The type of data that needs to be stored in registers.
 * Ty: The type of data that stored in the global memory.
 * NX: The number of data columns written by each thread.
 * NY: The number of data rows written by each thread.
 * threadIdx.x is used as the thread index. Currently only GPU was supported.
 * IsBoundary: Indicates whether to perform block access storage out-of-bounds
 * judgment. When the number of data processed by the block is less than
 * NX x NY x blockDim.x, boundary judgment is required to avoid memory access
 * crossing the boundary.
 *
 * @param:
 * dst: The data pointer of the current block.
 * src: The register pointer of the thread, the size is NX * NY. The element in
 * row idy and column idx is read from src[idy * NX + idx].
 * size_nx: The maximum offset of the current block is size_nx elements in the
 * lowest dimension. Only used when IsBoundary = true.
 * size_ny: The maximum offset of the current block is size_ny elements in the
 * first dimension. Only used when IsBoundary = true.
 * stride_nx: Each written element strides stride_nx elements in the last dim.
 * stride_ny: Each written element strides stride_ny elements in the first dim.
 *
 * With IsBoundary = true, the bounds checked depend on the branch taken:
 * NX == 1 && NY == 1 checks only size_nx, NX == 1 && NY != 1 checks only
 * size_ny, NY == 1 && NX != 1 checks only size_nx, and the general case checks
 * both.
 */
template <typename Tx, typename Ty, int NX, int NY, bool IsBoundary = false>
__device__ __forceinline__ void WriteData(Ty* dst,
                                          const Tx* __restrict__ src,
                                          int size_nx,
                                          int size_ny,
                                          int stride_nx,
                                          int stride_ny) {
  int thread_offset = threadIdx.x;
  // Number of columns remaining for this thread in the lowest dimension.
  int left_size_nx = size_nx - thread_offset;

  // Each branch is added for better performance
  if (NX == 1 && NY == 1) {  // for NX == 1 and NY == 1
    if (IsBoundary) {
      if (left_size_nx > 0) {
        dst[thread_offset] = static_cast<Ty>(src[0]);
      }
    } else {
      dst[thread_offset] = static_cast<Ty>(src[0]);
    }
  } else if (NX == 1) {  // for NX == 1 and NY != 1
    // NOTE(review): this branch does not check left_size_nx, so it relies on
    // callers guaranteeing threadIdx.x < size_nx when IsBoundary = true —
    // confirm against call sites.
#pragma unroll
    for (int idy = 0; idy < NY; ++idy) {
      if (IsBoundary) {
        if (idy * stride_ny >= size_ny) {
          break;
        }
      }
      dst[thread_offset + idy * stride_ny] = static_cast<Ty>(src[idy]);
    }
  } else if (NY == 1) {  // for NY == 1 and NX != 1
#pragma unroll
    for (int idx = 0; idx < NX; ++idx) {
      if (IsBoundary) {
        if (idx * stride_nx >= left_size_nx) {
          break;
        }
      }
      dst[thread_offset + idx * stride_nx] = static_cast<Ty>(src[idx]);
    }
  } else {  // for NX != 1 and NY != 1
#pragma unroll
    for (int idx = 0; idx < NX; ++idx) {
      if (IsBoundary) {
        if (idx * stride_nx >= left_size_nx) {
          break;
        }
      }
#pragma unroll
      for (int idy = 0; idy < NY; ++idy) {
        if (IsBoundary) {
          if (idy * stride_ny >= size_ny) {
            break;
          }
        }
        dst[thread_offset + idx * stride_nx + idy * stride_ny] =
            static_cast<Ty>(src[idy * NX + idx]);
      }
    }
  }
}

/**
 * @brief Copy per-element init values into a thread's register array.
 *
 * @template parameters
 * T: Data type of register.
 * NX: Number of data to initialize.
 * IsBoundary: When true, copying stops after the first `num` elements; when
 * false, all NX elements are copied and `num` is ignored.
 *
 * @param:
 * dst: The register pointer of the thread, the size is NX.
 * init_data: The register pointer of init data, the size is NX.
 * num: Number of valid elements; only consulted when IsBoundary = true.
 */
template <typename T, int NX, bool IsBoundary = false>
__device__ __forceinline__ void Init(T* dst, T* init_data, int num) {
#pragma unroll
  for (int k = 0; k < NX; ++k) {
    if (IsBoundary && k >= num) {
      break;
    }
    dst[k] = init_data[k];
  }
}

/**
 * @brief Read 1D data from global memory to register with broadcast form.
 *
 * @template parameters
 * T: The type of data stored in the global memory.
 * NX: The number of data continuously loaded by each thread.
 * NY: The number of data rows loaded by each thread, only NY = 1 was supported.
 * threadIdx.x is used as the thread index. Currently only GPU was supported.
 * IsBoundary: Indicates whether to perform block access storage out-of-bounds
 * judgment. When the number of data processed by the block is less than
 * NX x NY x blockDim.x, boundary judgment is required to avoid memory access
 * crossing the boundary.
 *
 * @param:
 * dst: The register pointer of the thread, the size is NX * NY.
 * src: The original input data pointer of kernel.
 * block_offset: The data offset of this block, blockDim.x * blockIdx.x * NX;
 * config: Calculation configuration of broadcast. It is used to calculate the
 * coordinate mapping relationship between output data and input data.
 * total_num_output: Total number of original output. Only used when
 * IsBoundary = true.
 * read_lens: Not used by this implementation (defaults to NX).
 */
template <typename T, int NX, int NY, bool IsBoundary = false>
__device__ __forceinline__ void ReadDataBc(
    T* dst,
    const T* __restrict__ src,
    uint32_t block_offset,
    const details::BroadcastConfig& config,
    int total_num_output,
    int read_lens = NX) {
  uint32_t thread_offset = block_offset + threadIdx.x * NX;
  uint32_t index_src = 0;

#pragma unroll
  for (uint32_t nx = 0; nx < NX; ++nx) {
    uint32_t index_output = thread_offset + nx;
    index_src = 0;
    if (IsBoundary) {
      if (index_output >= total_num_output) {
        break;
      }
    }
    // Peel off one output coordinate per dimension (lowest first) and add its
    // contribution to the input offset; broadcast dims have stride 0.
#pragma unroll
    for (int i = 0; i < phi::DDim::kMaxRank; ++i) {
      if (i >= config.kDims) break;
      auto fast_divmoder = config.divmoders[i].Divmod(index_output);
      index_output = fast_divmoder.val[0];
      index_src += fast_divmoder.val[1] * config.strides[i];
    }
    dst[nx] = src[index_src];
  }
}

/**
 * @brief Initialize register with data index.
 *
 * @template parameters
 * T: Data type of register.
 * NX: Number of data to initialize.
 * NY: Number of data rows, NY only can be 1 (not referenced here).
 * threadIdx.x is used as the thread index. Currently only GPU was supported.
 *
 * @param:
 * dst: The register pointer of the thread, the size is NX.
 * block_offset: Linear index of the first element handled by this block.
 * dst[nx] is set to block_offset + threadIdx.x * NX + nx, converted to T.
 */
template <typename T, int NX, int NY>
__device__ __forceinline__ void InitWithDataIndex(T* dst, int block_offset) {
  int thread_offset = block_offset + threadIdx.x * NX;
#pragma unroll
  for (int nx = 0; nx < NX; ++nx) {
    dst[nx] = static_cast<T>(thread_offset + nx);
  }
}

818
}  // namespace kps
819
}  // namespace phi