datamover_primitives.h 22.1 KB
Newer Older
F
Feng Xing 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
16
#ifdef PADDLE_WITH_CUDA
N
niuliling123 已提交
17 18
#include <cuda.h>
#include <cuda_fp16.h>
19 20 21 22
#endif
#ifdef PADDLE_WITH_HIP
#include <hip/hip_fp16.h>
#endif
F
Feng Xing 已提交
23 24 25

namespace paddle {
namespace operators {
N
niuliling123 已提交
26 27 28 29 30 31 32 33 34
namespace kernel_primitives {
namespace details {

#define INT_BITS 32

// Plain aggregate of VecSize values of type T. It is aligned to its full size
// (sizeof(T) * VecSize bytes) so that pointers to it can be used for
// vectorized loads/stores via reinterpret_cast of element pointers.
// NOTE: alignas requires sizeof(T) * VecSize to be a power of two.
template <typename T, int VecSize>
struct alignas(sizeof(T) * VecSize) VectorType {
  // The VecSize packed elements, in memory order.
  T val[VecSize];
};
/**
 * Fast division: replaces integer division by a runtime-invariant divisor with
 * a multiply-high, an add and a shift to improve kernel performance.
 * 1. The constructor precomputes `multiplier` and `shift_val` for the divisor
 *    (Granlund-Montgomery method with shift_val = ceil(log2(divisor))).
 * 2. Div() / Divmod() then compute n / divisor and n % divisor on the GPU.
 *
 * Requirements: the divisor must be non-zero (the constructor divides by it).
 * The final add-and-shift is done in 64-bit arithmetic, so the quotient is
 * exact for every 32-bit n and every non-zero 32-bit divisor.
 */
struct FastDivMod {
  // 1st value represents the result of input number divides by recorded divisor
  // 2nd value represents the result of input number modulo by recorded divisor
  using DivModT = VectorType<uint32_t, 2>;

  FastDivMod() {}
  HOSTDEVICE FastDivMod(uint32_t d) : divisor(d) {
    static_assert(sizeof(unsigned int) == 4,
                  "Only Support 32-bit unsigned int.");

    uint64_t long_one = 1;
    // Smallest shift_val with 2^shift_val >= d, i.e. ceil(log2(d)); it ends
    // at INT_BITS when d > 2^31. The limit is built in 64 bits because
    // `1 << 31` on a signed int overflows (it used to turn negative and push
    // shift_val to INT_BITS for every divisor above 2^30). Comparing against
    // the unsigned `d` instead of the signed `divisor` keeps d >= 2^31 right.
    for (shift_val = 0; shift_val < INT_BITS; ++shift_val) {
      if ((long_one << shift_val) >= d) break;
    }
    // multiplier = floor(2^32 * (2^shift_val - d) / d) + 1, which fits in 32
    // bits because 2^(shift_val - 1) < d <= 2^shift_val.
    uint64_t temp_div =
        ((long_one << INT_BITS) * ((long_one << shift_val) - d)) / d + 1;
    multiplier = static_cast<uint32_t>(temp_div);
  }

  // Returns n / divisor.
  __device__ __forceinline__ uint32_t Div(uint32_t n) const {
    uint32_t t = __umulhi(n, multiplier);
    // t + n may need 33 bits and shift_val may equal 32, so the add and the
    // shift are done in 64 bits (in 32 bits they overflow / are undefined).
    return static_cast<uint32_t>((static_cast<uint64_t>(t) + n) >> shift_val);
  }

  // Returns {n / divisor, n % divisor}.
  __device__ __forceinline__ DivModT Divmod(uint32_t n) const {
    uint32_t q = Div(n);
    DivModT result = {q, n - q * static_cast<uint32_t>(divisor)};
    return result;
  }

  int32_t divisor;
  int32_t shift_val;
  uint32_t multiplier;
};

/**
 * Configuration of broadcast. Calculates the input data index according to
 * the index of the output data. Dims are given lowest dimension first: if the
 * input or output shape is [dim0, dim1], then the dims must be [dim1, dim0].
 *
 * divmoders[i] divides by out_dims[i]; strides[i] is the input stride of dim i,
 * or 0 when in_dims[i] == 1 (that input dim is broadcast).
 */
template <int kDims>
struct BroadcastConfig {
  // strides has kMaxRank slots; more dims than that cannot be stored.
  static_assert(kDims <= framework::DDim::kMaxRank,
                "BroadcastConfig: kDims must not exceed DDim::kMaxRank.");

  FastDivMod divmoders[kDims];
  uint32_t strides[framework::DDim::kMaxRank];
  HOSTDEVICE BroadcastConfig() {}

  /**
   * out_dims / in_dims: output and input shapes, lowest dimension first.
   * dim_size: number of valid entries in out_dims and in_dims. Only the first
   * kDims are stored. When dim_size < kDims the remaining slots get divisor 1
   * and stride 0, so they leave the computed input index unchanged (the old
   * memcpy of kDims entries read past the end of its temporaries then).
   */
  HOSTDEVICE BroadcastConfig(const std::vector<int64_t>& out_dims,
                             const std::vector<int64_t>& in_dims,
                             int dim_size) {
    const int used_dims = dim_size < kDims ? dim_size : kDims;
    int64_t in_stride = 1;  // product of in_dims[0 .. i), kept in 64 bits
    for (int i = 0; i < kDims; ++i) {
      if (i < used_dims) {
        divmoders[i] = FastDivMod(static_cast<uint32_t>(out_dims[i]));
        // A size-1 input dim is broadcast, so it contributes nothing.
        strides[i] = in_dims[i] == 1 ? 0 : static_cast<uint32_t>(in_stride);
        in_stride *= in_dims[i];
      } else {
        divmoders[i] = FastDivMod(1);
        strides[i] = 0;
      }
    }
  }
};

#undef INT_BITS
}  // namespace details

/**
 * @brief Read 2D data from global memory into registers, converting each
 * element from Tx (global memory type) to Ty (register type).
 *
 * @template parameters
 * Tx: The type of the data stored in global memory.
 * Ty: The type of the data stored in registers.
 * NX: The number of data columns loaded by each thread.
 * NY: The number of data rows loaded by each thread.
 * BlockSize: Identifies the current device thread index method. For GPU,
 * threadIdx.x is used as the thread index. Currently only GPU is supported.
 * IsBoundary: Whether to bounds-check the accesses. Needed when the block
 * processes fewer than NX x NY x blockDim.x elements.
 *
 * @param
 * dst: The register array of the thread, holding NX * NY elements laid out as
 * dst[idy * NX + idx].
 * src: The data pointer of the current block.
 * size_nx: Exclusive bound on threadIdx.x + idx * stride_nx; only used when
 * IsBoundary = true.
 * size_ny: Exclusive bound on idy * stride_ny; only used when
 * IsBoundary = true.
 * stride_nx: Element stride between consecutive reads along the last dim.
 * stride_ny: Element stride between consecutive reads along the first dim.
 *
 * Element (idy, idx) is read from
 * src[threadIdx.x + idx * stride_nx + idy * stride_ny].
 * Note: when NX == 1 and NY > 1 only the row bound (size_ny) is checked.
 */
template <typename Tx, typename Ty, int NX, int NY, int BlockSize,
          bool IsBoundary = false>
__device__ __forceinline__ void ReadData(Ty* dst, const Tx* __restrict__ src,
                                         int size_nx, int size_ny,
                                         int stride_nx, int stride_ny) {
  const int tid = threadIdx.x;
  // Columns still available to this thread along the lowest dimension.
  const int remaining_nx = size_nx - tid;

  // Each shape gets its own specialised path for better performance.
  if (NX == 1 && NY == 1) {
    if (!IsBoundary || remaining_nx > 0) {
      dst[0] = static_cast<Ty>(src[tid]);
    }
    return;
  }

  if (NX == 1) {
#pragma unroll
    for (int row = 0; row < NY; ++row) {
      const int row_offset = row * stride_ny;
      if (IsBoundary && row_offset >= size_ny) break;
      dst[row] = static_cast<Ty>(src[tid + row_offset]);
    }
    return;
  }

  if (NY == 1) {
#pragma unroll
    for (int col = 0; col < NX; ++col) {
      const int col_offset = col * stride_nx;
      if (IsBoundary && col_offset >= remaining_nx) break;
      dst[col] = static_cast<Ty>(src[tid + col_offset]);
    }
    return;
  }

#pragma unroll
  for (int col = 0; col < NX; ++col) {
    const int col_offset = col * stride_nx;
    if (IsBoundary && col_offset >= remaining_nx) break;
#pragma unroll
    for (int row = 0; row < NY; ++row) {
      const int row_offset = row * stride_ny;
      if (IsBoundary && row_offset >= size_ny) break;
      dst[row * NX + col] =
          static_cast<Ty>(src[tid + col_offset + row_offset]);
    }
  }
}

/**
 * @brief Initialize registers with a single value.
 *
 * @template parameters
 * T: Data type of the registers.
 * NX: Number of elements to initialize.
 *
 * @param
 * dst: The register array of the thread; it must hold at least NX elements.
 * init_data: Value assigned to every element dst[0 .. NX).
 */
template <typename T, int NX>
__device__ __forceinline__ void Init(T* dst, T init_data) {
#pragma unroll
  for (int k = 0; k < NX; ++k) {
    dst[k] = init_data;
  }
}

/**
 * @brief Read 1D data from global memory into registers. When
 * IsBoundary = false the loads are vectorized (4 or 2 elements per access
 * when NX % 4 == 0 or NX % 2 == 0) to improve memory access efficiency.
 *
 * @template parameters
 * T: The type of data.
 * NX: The number of consecutive elements each thread loads.
 * NY: The number of rows each thread loads; only NY = 1 is supported.
 * BlockSize: Identifies the current device thread index method. For GPU,
 * threadIdx.x is used as the thread index. Currently only GPU is supported.
 * IsBoundary: Whether to bounds-check the accesses. Needed when the block
 * processes fewer than NX x NY x blockDim.x elements.
 *
 * @param
 * dst: The register array of the thread, holding NX * NY elements.
 * src: The data pointer of the current block. On the vectorized path
 * (IsBoundary = false) it must be aligned to kVectorSize * sizeof(T) bytes.
 * num: Number of valid elements starting at src; only used when
 * IsBoundary = true.
 *
 * Thread t reads src[t * NX .. t * NX + NX).
 */
template <typename T, int NX, int NY, int BlockSize, bool IsBoundary = false>
__device__ __forceinline__ void ReadData(T* dst, const T* __restrict__ src,
                                         int num) {
  if (IsBoundary) {  // some of this block's elements may lie past num
    int thread_offset = threadIdx.x * NX;
#pragma unroll
    for (int idx = 0; idx < NX; ++idx) {
      if (idx + thread_offset < num) {
        dst[idx] = src[thread_offset + idx];
      }
    }
  } else {  // every element of every thread is in range
    constexpr int kVectorSize = (NX % 4 == 0) ? 4 : (NX % 2 == 0) ? 2 : 1;
    constexpr int kVectorsPerThread = NX / kVectorSize;
    // Offset in units of VecType, i.e. element offset threadIdx.x * NX.
    int thread_offset = threadIdx.x * kVectorsPerThread;

    using VecType = details::VectorType<T, kVectorSize>;
    const VecType* vec_input = reinterpret_cast<const VecType*>(src);

    // Load each vector once and unpack it straight into its own slot of dst.
    // (Previously all NX elements were re-copied after every vector load,
    // reading entries of the temporary buffer that were not loaded yet.)
#pragma unroll
    for (int i = 0; i < kVectorsPerThread; ++i) {
      VecType vec_temp = vec_input[thread_offset + i];
#pragma unroll
      for (int j = 0; j < kVectorSize; ++j) {
        dst[i * kVectorSize + j] = vec_temp.val[j];
      }
    }
  }
}

/**
 * @brief Read 2D data from global memory into registers in broadcast form:
 * each output position is mapped back to the input element it reads.
 *
 * @template parameters
 * T: The type of data stored in global memory.
 * NX: The number of data columns loaded by each thread.
 * NY: The number of data rows loaded by each thread.
 * BlockSize: Identifies the current device thread index method. For GPU,
 * threadIdx.x is used as the thread index. Currently only GPU is supported.
 * Rank: The shape size of out, e.g. in[1, 35], out[32, 35] gives rank 2.
 * IsBoundary: Whether to skip output positions >= total_num_output. Needed
 * when the block processes fewer than NX x NY x blockDim.x elements.
 *
 * @param
 * dst: The register array of the thread, holding NX * NY elements laid out as
 * dst[ny * NX + nx].
 * src: The original input data pointer of this kernel.
 * block_offset: Output offset of this block; threadIdx.x is added to it.
 * config: Broadcast configuration mapping output coordinates to input ones.
 * total_num_output: Total number of original output elements.
 * stride_nx: Output stride between consecutive reads along the last dim.
 * stride_ny: Output stride between consecutive reads along the first dim.
 */
template <typename T, int NX, int NY, int BlockSize, int Rank,
          bool IsBoundary = false>
__device__ __forceinline__ void ReadDataBc(
    T* dst, const T* __restrict__ src, uint32_t block_offset,
    details::BroadcastConfig<Rank> config, int total_num_output, int stride_nx,
    int stride_ny) {
  const uint32_t base = block_offset + threadIdx.x;

#pragma unroll
  for (int row = 0; row < NY; ++row) {
#pragma unroll
    for (uint32_t col = 0; col < NX; ++col) {
      uint32_t out_idx = base + row * stride_ny + col * stride_nx;
      if (IsBoundary && out_idx >= total_num_output) break;

      // Peel off one output coordinate per dimension (lowest first) and add
      // its contribution to the input offset (stride 0 if broadcast).
      uint32_t in_idx = 0;
#pragma unroll
      for (int dim = 0; dim < Rank; ++dim) {
        auto qr = config.divmoders[dim].Divmod(out_idx);
        out_idx = qr.val[0];
        in_idx += qr.val[1] * config.strides[dim];
      }
      dst[row * NX + col] = src[in_idx];
    }
  }
}

/**
 * @brief Read 2D data from global memory into registers in reduce form: each
 * source index comes from index_cal, and func is applied to every loaded
 * value before it is converted from Tx to Ty.
 *
 * @template parameters
 * Tx: The type of the data in global memory.
 * Ty: The type of the data stored in registers.
 * NX: The number of data columns loaded by each thread.
 * NY: The number of data rows loaded by each thread.
 * BlockSize: Identifies the current device thread index method. For GPU,
 * threadIdx.x is used as the thread index. Currently only GPU is supported.
 * Rank: The shape size of out, e.g. in[1, 35], out[32, 35] gives rank 2.
 * IndexCal: Callable taking an int offset and returning a value convertible
 * to uint32_t (the index into src).
 * Functor: Callable applied to each loaded Tx value.
 * IsBoundary: Whether to bounds-check against size_nx / size_ny. Needed when
 * the block processes fewer than NX x NY x blockDim.x elements.
 *
 * @param
 * dst: The register array of the thread, holding NX * NY elements laid out as
 * dst[nx + ny * NX].
 * src: The input data pointer of this block.
 * block_offset: Added to the running reduce offset before calling index_cal.
 * index_cal: Calculation configuration of Reduce; maps the reduce offset to
 * the index into src.
 * size_nx: Exclusive bound on left_idx + nx * stride_nx; only used when
 * IsBoundary = true.
 * size_ny: Exclusive bound on the running reduce offset; only used when
 * IsBoundary = true.
 * stride_nx: Column stride used in the size_nx bound check.
 * stride_ny: Amount the reduce offset advances after each element read.
 * func: Transformation applied to each loaded value.
 * reduce_last_dim: Whether the reduced dimensions include the lowest one. If
 * true the reduce offset starts at threadIdx.x and left_idx is threadIdx.y;
 * otherwise the two are swapped.
 */
template <typename Tx, typename Ty, int NX, int NY, int BlockSize, int Rank,
          typename IndexCal, typename Functor, bool IsBoundary = false>
__device__ __forceinline__ void ReadDataReduce(
    Ty* dst, const Tx* __restrict__ src, int block_offset,
    const IndexCal& index_cal, int size_nx, int size_ny, int stride_nx,
    int stride_ny, Functor func, bool reduce_last_dim) {
  // Running offset along the reduced dimension. It advances by stride_ny
  // after every element that is read.
  int reduce_offset = reduce_last_dim ? threadIdx.x : threadIdx.y;
  const int left_idx = reduce_last_dim ? threadIdx.y : threadIdx.x;

  if (NX == 1) {
#pragma unroll
    for (int ny = 0; ny < NY; ++ny, reduce_offset += stride_ny) {
      if (IsBoundary && reduce_offset >= size_ny) break;
      uint32_t index_src = index_cal(reduce_offset + block_offset);
      dst[ny] = static_cast<Ty>(func(src[index_src]));
    }
    return;
  }

  // NOTE(review): reduce_offset is not reset between nx iterations and nx only
  // enters the size_nx bound check, not the index passed to index_cal. This
  // matches the existing behavior; confirm it is intended for NX > 1.
#pragma unroll
  for (int nx = 0; nx < NX; ++nx) {
    const bool column_out_of_range = left_idx + nx * stride_nx >= size_nx;
#pragma unroll
    for (int ny = 0; ny < NY; ++ny, reduce_offset += stride_ny) {
      if (IsBoundary && (reduce_offset >= size_ny || column_out_of_range)) {
        break;
      }
      uint32_t index_src = index_cal(reduce_offset + block_offset);
      dst[nx + ny * NX] = static_cast<Ty>(func(src[index_src]));
    }
  }
}

/**
 * @brief Write 1D data from registers to global memory. When
 * IsBoundary = false the data is stored in vectors of 4 or 2 elements (when
 * NX % 4 == 0 or NX % 2 == 0) to improve memory access efficiency.
 *
 * @template parameters
 * T: The type of data.
 * NX: The number of consecutive elements written by each thread.
 * NY: The number of rows written by each thread; only NY = 1 is supported.
 * BlockSize: Identifies the current device thread index method. For GPU,
 * threadIdx.x is used as the thread index. Currently only GPU is supported.
 * IsBoundary: Whether to bounds-check the accesses. Needed when the block
 * processes fewer than NX x NY x blockDim.x elements.
 *
 * @param
 * dst: The data pointer of the current block. On the vectorized path it must
 * be aligned to kVectorSize * sizeof(T) bytes.
 * src: The register array of the thread, holding NX * NY elements.
 * num: Number of valid elements starting at dst; only used when
 * IsBoundary = true.
 *
 * Thread t writes dst[t * NX .. t * NX + NX).
 */
template <typename T, int NX, int NY, int BlockSize, bool IsBoundary = false>
__device__ __forceinline__ void WriteData(T* dst, T* __restrict__ src,
                                          int num) {
  if (!IsBoundary) {
    // Fast path: every element of this thread is in range, so copy whole
    // VectorType chunks.
    constexpr int kVectorSize = (NX % 4 == 0) ? 4 : (NX % 2 == 0) ? 2 : 1;
    constexpr int kVectorsPerThread = NX / kVectorSize;
    using VecType = details::VectorType<T, kVectorSize>;

    const int vec_offset = threadIdx.x * kVectorsPerThread;
    const VecType* vec_src = reinterpret_cast<const VecType*>(src);
    VecType* vec_dst = reinterpret_cast<VecType*>(dst);
#pragma unroll
    for (int v = 0; v < kVectorsPerThread; ++v) {
      vec_dst[vec_offset + v] = vec_src[v];
    }
    return;
  }

  // Boundary path: element-wise copy that skips positions at or beyond num.
  const int base = threadIdx.x * NX;
#pragma unroll
  for (int k = 0; k < NX; ++k) {
    if (base + k < num) {
      dst[base + k] = src[k];
    }
  }
}

/**
 * @brief Write 2D data from registers to global memory, converting each
 * element from Tx (register type) to Ty (global memory type).
 *
 * @template parameters
 * Tx: The type of the data held in registers.
 * Ty: The type of the data stored in global memory.
 * NX: The number of data columns written by each thread.
 * NY: The number of data rows written by each thread.
 * BlockSize: Identifies the current device thread index method. For GPU,
 * threadIdx.x is used as the thread index. Currently only GPU is supported.
 * IsBoundary: Whether to bounds-check the accesses. Needed when the block
 * processes fewer than NX x NY x blockDim.x elements.
 *
 * @param
 * dst: The data pointer of the current block.
 * src: The register array of the thread, holding NX * NY elements laid out as
 * src[idy * NX + idx].
 * size_nx: Exclusive bound on threadIdx.x + idx * stride_nx; only used when
 * IsBoundary = true.
 * size_ny: Exclusive bound on idy * stride_ny; only used when
 * IsBoundary = true.
 * stride_nx: Element stride between consecutive writes along the last dim.
 * stride_ny: Element stride between consecutive writes along the first dim.
 *
 * Element (idy, idx) is written to
 * dst[threadIdx.x + idx * stride_nx + idy * stride_ny].
 * Note: when NX == 1 and NY > 1 only the row bound (size_ny) is checked.
 */
template <typename Tx, typename Ty, int NX, int NY, int BlockSize,
          bool IsBoundary = false>
__device__ __forceinline__ void WriteData(Ty* dst, const Tx* __restrict__ src,
                                          int size_nx, int size_ny,
                                          int stride_nx, int stride_ny) {
  const int tid = threadIdx.x;
  // Columns still available to this thread along the lowest dimension.
  const int remaining_nx = size_nx - tid;

  // Each shape gets its own specialised path for better performance.
  if (NX == 1 && NY == 1) {
    if (!IsBoundary || remaining_nx > 0) {
      dst[tid] = static_cast<Ty>(src[0]);
    }
    return;
  }

  if (NX == 1) {
#pragma unroll
    for (int row = 0; row < NY; ++row) {
      const int row_offset = row * stride_ny;
      if (IsBoundary && row_offset >= size_ny) break;
      dst[tid + row_offset] = static_cast<Ty>(src[row]);
    }
    return;
  }

  if (NY == 1) {
#pragma unroll
    for (int col = 0; col < NX; ++col) {
      const int col_offset = col * stride_nx;
      if (IsBoundary && col_offset >= remaining_nx) break;
      dst[tid + col_offset] = static_cast<Ty>(src[col]);
    }
    return;
  }

#pragma unroll
  for (int col = 0; col < NX; ++col) {
    const int col_offset = col * stride_nx;
    if (IsBoundary && col_offset >= remaining_nx) break;
#pragma unroll
    for (int row = 0; row < NY; ++row) {
      const int row_offset = row * stride_ny;
      if (IsBoundary && row_offset >= size_ny) break;
      dst[tid + col_offset + row_offset] =
          static_cast<Ty>(src[row * NX + col]);
    }
  }
}

/**
 * @brief Initialize registers element-wise from an array of initial values.
 *
 * @template parameters
 * T: Data type of the registers.
 * NX: Maximum number of elements to initialize.
 * IsBoundary: When true, copying stops at the first index >= num.
 *
 * @param
 * dst: The register array of the thread; it must hold at least NX elements.
 * init_data: Array of initial values; dst[i] receives init_data[i].
 * num: Number of valid elements; only consulted when IsBoundary = true.
 */
template <typename T, int NX, bool IsBoundary = false>
__device__ __forceinline__ void Init(T* dst, T* init_data, int num) {
#pragma unroll
  for (int k = 0; k < NX; ++k) {
    if (IsBoundary && k >= num) break;
    dst[k] = init_data[k];
  }
}

/**
 * @brief Read 1D data from global memory into registers in broadcast form:
 * each of the NX consecutive output positions of this thread is mapped back
 * to the input element it reads.
 *
 * @template parameters
 * T: The type of data stored in global memory.
 * NX: The number of consecutive outputs handled by each thread.
 * NY: The number of data rows loaded by each thread; only NY = 1 is supported.
 * BlockSize: Identifies the current device thread index method. For GPU,
 * threadIdx.x is used as the thread index. Currently only GPU is supported.
 * Rank: The shape size of out, e.g. in[1, 35], out[32, 35] gives rank 2.
 * IsBoundary: Whether to skip output positions >= total_num_output. Needed
 * when the block processes fewer than NX x NY x blockDim.x elements.
 *
 * @param
 * dst: The register array of the thread, holding NX * NY elements.
 * src: The original input data pointer of the kernel.
 * block_offset: Output offset of this block; threadIdx.x * NX is added to it.
 * config: Broadcast configuration mapping output coordinates to input ones.
 * total_num_output: Total number of original output elements.
 */
template <typename T, int NX, int NY, int BlockSize, int Rank,
          bool IsBoundary = false>
__device__ __forceinline__ void ReadDataBc(
    T* dst, const T* __restrict__ src, uint32_t block_offset,
    details::BroadcastConfig<Rank> config, int total_num_output) {
  const uint32_t first_output = block_offset + threadIdx.x * NX;

#pragma unroll
  for (uint32_t k = 0; k < NX; ++k) {
    uint32_t out_idx = first_output + k;
    if (IsBoundary && out_idx >= total_num_output) break;

    // Decompose the output index dimension by dimension (lowest first) and
    // map each coordinate to the input through its stride (0 if broadcast).
    uint32_t in_idx = 0;
#pragma unroll
    for (int dim = 0; dim < Rank; ++dim) {
      auto qr = config.divmoders[dim].Divmod(out_idx);
      out_idx = qr.val[0];
      in_idx += qr.val[1] * config.strides[dim];
    }
    dst[k] = src[in_idx];
  }
}

N
niuliling123 已提交
618 619 620
}  // namespace kernel_primitives
}  // namespace operators
}  // namespace paddle