// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#ifdef PADDLE_WITH_CUDA
#include <cuda.h>
#include <cuda_fp16.h>
#endif
#ifdef PADDLE_WITH_HIP
#include <hip/hip_fp16.h>
#endif

#include <cstring>
#include <functional>
#include <numeric>
#include <vector>

namespace paddle {
namespace operators {
namespace kernel_primitives {
namespace details {

#define INT_BITS 32

template <typename T, int VecSize>
struct alignas(sizeof(T) * VecSize) VectorType {
  T val[VecSize];
};
/**
 * Fast division: replaces division in CUDA kernels with multiplication to
 * improve performance.
 * 1. The division setup is done on the CPU, and the results are recorded in
 * multiplier and shift_val.
 * 2. On the GPU, Div() uses the precomputed values to complete the division.
 */
struct FastDivMod {
  // 1st value represents the result of the input number divided by the divisor
  // 2nd value represents the result of the input number modulo the divisor
  using DivModT = VectorType<uint32_t, 2>;

  FastDivMod() {}
  HOSTDEVICE FastDivMod(uint32_t d) : divisor(d) {
    static_assert(sizeof(unsigned int) == 4,
                  "Only Support 32-bit unsigned int.");

    for (shift_val = 0; shift_val < INT_BITS; ++shift_val) {
      auto shift_limit = 1 << shift_val;
      if (shift_limit >= divisor) break;
    }
    uint64_t long_one = 1;
    uint64_t temp_div =
        ((long_one << INT_BITS) * ((long_one << shift_val) - divisor)) /
            divisor +
        1;
    multiplier = temp_div;
  }

  __device__ __forceinline__ uint32_t Div(uint32_t n) const {
    uint32_t t = __umulhi(n, multiplier);
    return (t + n) >> shift_val;
  }

  __device__ __forceinline__ DivModT Divmod(uint32_t n) const {
    uint32_t q = Div(n);
    DivModT result = {q, n - q * divisor};
    return result;
  }

  int32_t divisor;
  int32_t shift_val;
  uint32_t multiplier;
};
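
// Illustrative sketch (not from the original file): typical use of FastDivMod
// to replace a runtime division/modulo pair; `dim_size` and `linear_index`
// are hypothetical names.
//
//   FastDivMod divmod(dim_size);             // precomputed on the host
//   ...
//   auto qr = divmod.Divmod(linear_index);   // on the device
//   uint32_t quotient = qr.val[0];           // linear_index / dim_size
//   uint32_t remainder = qr.val[1];          // linear_index % dim_size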

/**
 * Configuration of broadcast. Calculates the input data index according to
 * the index of the output data. If the input or output shape is [dim0, dim1],
 * then dims must be passed as [dim1, dim0].
 */
template <int kDims>
struct BroadcastConfig {
  FastDivMod divmoders[kDims];
  uint32_t strides[framework::DDim::kMaxRank];
  HOSTDEVICE BroadcastConfig() {}

  HOSTDEVICE BroadcastConfig(const std::vector<int64_t>& out_dims,
                             const std::vector<int64_t>& in_dims,
                             int dim_size) {
    std::vector<uint32_t> strides_in;
    std::vector<FastDivMod> divmoders_in;
    // for divmoders
    divmoders_in.resize(dim_size);
    for (int i = 0; i < dim_size; ++i) {
      divmoders_in[i] = FastDivMod(out_dims[i]);
    }
    // for strides
    strides_in.resize(dim_size, 1);
    for (int i = 0; i < dim_size; ++i) {
      strides_in[i] = in_dims[i] == 1 ? 0 : strides_in[i];
      strides_in[i] =
          (i != 0 && strides_in[i] != 0)
              ? std::accumulate(in_dims.begin(), in_dims.begin() + i, 1,
                                std::multiplies<int64_t>())
              : strides_in[i];
    }

    memcpy(strides, strides_in.data(), kDims * sizeof(uint32_t));
    memcpy(divmoders, divmoders_in.data(), kDims * sizeof(FastDivMod));
  }
};
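
// Illustrative sketch (not from the original file): building a BroadcastConfig
// on the host for broadcasting an input of shape [1, 35] against an output of
// shape [32, 35]. As described above, the dims are passed in reversed order.
//
//   std::vector<int64_t> out_dims = {35, 32};   // reversed [32, 35]
//   std::vector<int64_t> in_dims = {35, 1};     // reversed [1, 35]
//   details::BroadcastConfig<2> config(out_dims, in_dims, /*dim_size=*/2);
//   // `config` is then passed by value to ReadDataBc inside the kernel.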

#undef INT_BITS
}  // namespace details

/**
 * @brief Read 2D data from global memory to registers according to type Tx,
 * and store it as type Ty.
 *
 * @template parameters
 * Tx: The type of data stored in global memory.
 * Ty: The type of data that needs to be stored in registers.
 * NX: The number of data columns loaded by each thread.
 * NY: The number of data rows loaded by each thread.
 * BlockSize: Identifies the current device thread index method. For GPU,
 * threadIdx.x is used as the thread index, and for XPU, core_id() is used as
 * the index. Currently only GPU is supported.
 * IsBoundary: Indicates whether to perform out-of-bounds checks on memory
 * accesses. When the number of data elements processed by the block is less
 * than NX x NY x blockDim, boundary checks are required to avoid memory
 * accesses crossing the boundary.
 *
 * @param:
 * dst: The register pointer of the thread, the size is NX * NY.
 * src: Data pointer of the current block.
 * size_nx: The number of columns the current block needs to load; only used
 * when IsBoundary = true.
 * size_ny: The number of rows the current block needs to load; only used
 * when IsBoundary = true.
 * stride_nx: The stride of the columns.
 * stride_ny: The stride of the rows.
 */
template <typename Tx, typename Ty, int NX, int NY, int BlockSize,
          bool IsBoundary = false>
__device__ __forceinline__ void ReadData(Ty* dst, const Tx* __restrict__ src,
                                         int size_nx, int size_ny,
                                         int stride_nx, int stride_ny) {
  int thread_offset = threadIdx.x * NX;
  int left_size_nx = size_nx - thread_offset;

  // Each branch is added for better performance
  if (NX == 1 && NY == 1) {  // for NX == 1 and NY == 1
    if (IsBoundary) {
      if (left_size_nx > 0) {
        dst[0] = static_cast<Ty>(src[thread_offset]);
      }
    } else {
      dst[0] = static_cast<Ty>(src[thread_offset]);
    }
  } else if (NX == 1) {  // for NX == 1 and NY != 1
#pragma unroll
    for (int idy = 0; idy < NY; ++idy) {
      if (IsBoundary) {
        if (idy >= size_ny) {
          break;
        }
      }
      dst[idy] = static_cast<Ty>(src[thread_offset + idy * stride_ny]);
    }
  } else if (NY == 1) {  // for NY == 1 and NX != 1
#pragma unroll
    for (int idx = 0; idx < NX; ++idx) {
      if (IsBoundary) {
        if (idx >= left_size_nx) {
          break;
        }
      }
      dst[idx] = static_cast<Ty>(src[thread_offset + idx * stride_nx]);
    }
  } else {  // for NX != 1 and NY != 1
#pragma unroll
    for (int idx = 0; idx < NX; ++idx) {
      if (IsBoundary) {
        if (idx >= left_size_nx) {
          break;
        }
      }
#pragma unroll
      for (int idy = 0; idy < NY; ++idy) {
        if (IsBoundary) {
          if (idy >= size_ny) {
            break;
          }
        }
        dst[idy * NX + idx] = static_cast<Ty>(
            src[thread_offset + idx * stride_nx + idy * stride_ny]);
      }
    }
  }
}
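
// Illustrative sketch (not from the original file): loading an NX x NY tile
// per thread from a row-major [rows, cols] block. `block_input`, `rows`,
// `cols`, and `kBlockSize` are hypothetical; columns advance by 1 element and
// rows advance by `cols` elements.
//
//   float reg[NY * NX];
//   ReadData<float, float, NX, NY, kBlockSize, true>(
//       reg, block_input, cols, rows, /*stride_nx=*/1, /*stride_ny=*/cols);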

/**
 * @brief Initialize registers with init_data.
 *
 * @template parameters
 * T: Data type of the register.
 * NX: Number of data elements to initialize.
 *
 * @param:
 * dst: The register pointer of the thread, the size is NX.
 * init_data: Initial value.
 */
template <typename T, int NX>
__device__ __forceinline__ void Init(T* dst, T init_data) {
#pragma unroll
  for (int i = 0; i < NX; i++) {
    dst[i] = init_data;
  }
}

/**
 * @brief Read 2D data from global memory to registers. When IsBoundary = false
 * and (NX % 4 == 0 or NX % 2 == 0), vectorized loads are used to improve
 * memory access efficiency.
 *
 * @template parameters
 * T: Data type of src and dst.
 * NX: The number of data elements continuously loaded by each thread.
 * NY: The number of data rows loaded by each thread; only NY = 1 is supported.
 * BlockSize: Identifies the current device thread index method. For GPU,
 * threadIdx.x is used as the thread index, and for XPU, core_id() is used as
 * the index. Currently only GPU is supported.
 * IsBoundary: Whether to perform out-of-bounds checks on memory accesses.
 * When the number of data elements processed by this block is less than
 * NX x NY x blockDim, boundary checks are required to avoid memory accesses
 * crossing the boundary.
 *
 * @param:
 * dst: The register pointer of the thread, the size is NX * NY.
 * src: Data pointer of the current block.
 * num: The number of data elements the current block needs to load
 * continuously.
 */
template <typename T, int NX, int NY, int BlockSize, bool IsBoundary = false>
__device__ __forceinline__ void ReadData(T* dst, const T* __restrict__ src,
                                         int num) {
  if (IsBoundary) {  // blockDim.x * NX > num
    int thread_offset = threadIdx.x * NX;
#pragma unroll
    for (int idx = 0; idx < NX; ++idx) {
      if (idx + thread_offset < num) {
        dst[idx] = src[thread_offset + idx];
      }
    }
  } else {  // blockDim.x * NX < num
    const int kVectorSize = (NX % 4 == 0) ? 4 : (NX % 2 == 0) ? 2 : 1;
    const int kVectorsPerThread = NX / kVectorSize;
    int thread_offset = threadIdx.x * kVectorsPerThread;

    using VecType = details::VectorType<T, kVectorSize>;
    const VecType* vec_input = reinterpret_cast<const VecType*>(src);
    VecType vec_temp[kVectorsPerThread];

#pragma unroll
    for (int i = 0; i < kVectorsPerThread; ++i) {
      vec_temp[i] = vec_input[thread_offset + i];
#pragma unroll
      for (int idx = 0; idx < NX; ++idx) {
        dst[idx] = *(reinterpret_cast<T*>(vec_temp) + idx);
      }
    }
  }
}
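
// Illustrative sketch (not from the original file): the usual dispatch between
// the boundary and the vectorized path of the contiguous ReadData above.
// `VecSize`, `in`, `data_offset`, and `remaining` are hypothetical.
//
//   float args[VecSize];
//   Init<float, VecSize>(args, 0.0f);
//   if (remaining >= blockDim.x * VecSize) {
//     ReadData<float, VecSize, 1, 1, false>(args, in + data_offset, remaining);
//   } else {
//     ReadData<float, VecSize, 1, 1, true>(args, in + data_offset, remaining);
//   }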

/**
 * @brief Read 2D data from global memory to registers for broadcast.
 *
 * @template parameters
 * T: The type of data stored in global memory.
 * NX: The number of data columns loaded by each thread.
 * NY: The number of data rows loaded by each thread.
 * BlockSize: Identifies the current device thread index method. For GPU,
 * threadIdx.x is used as the thread index, and for XPU, core_id() is used as
 * the index. Currently only GPU is supported.
 * Rank: The shape size of out, e.g. for in[1, 35] and out[32, 35] the shape
 * size is 2.
 * IsBoundary: Indicates whether to perform out-of-bounds checks on memory
 * accesses. When the number of data elements processed by the block is less
 * than NX x NY x blockDim, boundary checks are required to avoid memory
 * accesses crossing the boundary.
 *
 * @param:
 * dst: The register pointer of the thread, the size is NX * NY.
 * src: The raw input data pointer of the kernel.
 * block_offset: The data offset of this block, blockDim.x * blockIdx.x * NX.
 * config: Broadcast configuration used to calculate the coordinate mapping
 * between output and input data. Please refer to the sample code for specific
 * usage.
 * total_num_output: Total number of output elements.
 * stride_nx: The stride of the columns.
 * stride_ny: The stride of the rows.
 */
template <typename T, int NX, int NY, int BlockSize, int Rank,
          bool IsBoundary = false>
__device__ __forceinline__ void ReadDataBc(
    T* dst, const T* __restrict__ src, uint32_t block_offset,
    details::BroadcastConfig<Rank> config, int total_num_output, int stride_nx,
    int stride_ny) {
  uint32_t thread_offset = block_offset + threadIdx.x * NX;
  uint32_t index_src = 0;

#pragma unroll
  for (int ny = 0; ny < NY; ++ny) {
#pragma unroll
    for (uint32_t nx = 0; nx < NX; ++nx) {
      uint32_t index_output = thread_offset + ny * stride_ny + nx * stride_nx;
      index_src = 0;
      if (IsBoundary) {
        if (index_output >= total_num_output) {
          break;
        }
      }
#pragma unroll
      for (int i = 0; i < Rank; ++i) {
        auto fast_divmoder = config.divmoders[i].Divmod(index_output);
        index_output = fast_divmoder.val[0];
        index_src += fast_divmoder.val[1] * config.strides[i];
      }
      dst[nx + ny * NX] = src[index_src];
    }
  }
}
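
// Illustrative sketch (not from the original file): reading a broadcast input
// inside a kernel with a BroadcastConfig built on the host (see the sketch
// after BroadcastConfig). `VecSize`, `in`, `config`, and `total_num_output`
// are hypothetical.
//
//   float args[VecSize];
//   uint32_t block_offset = blockIdx.x * blockDim.x * VecSize;
//   ReadDataBc<float, VecSize, 1, 1, /*Rank=*/2, true>(
//       args, in, block_offset, config, total_num_output,
//       /*stride_nx=*/1, /*stride_ny=*/1);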

/**
 * @brief Read 2D data from global memory to registers for reduce.
 *
 * @template parameters
 * T: The type of data stored in global memory.
 * NX: The number of data columns loaded by each thread.
 * NY: The number of data rows loaded by each thread.
 * BlockSize: Identifies the current device thread index method. For GPU,
 * threadIdx.x is used as the thread index, and for XPU, core_id() is used as
 * the index. Currently only GPU is supported.
 * Rank: The shape size of out, e.g. for in[1, 35] and out[32, 35] the shape
 * size is 2.
 * IsBoundary: Indicates whether to perform out-of-bounds checks on memory
 * accesses. When the number of data elements processed by the block is less
 * than NX x NY x blockDim, boundary checks are required to avoid memory
 * accesses crossing the boundary.
 *
 * @param:
 * dst: The register pointer of the thread, the size is NX * NY.
 * src: The raw input data pointer of the kernel.
 * block_offset: The data offset of this block, blockDim.x * blockIdx.x * NX.
 * index_cal: Functor that maps a logical offset to the global index in src.
 * It is used to calculate the coordinate mapping between output and input
 * data; note that its configuration is declared on the host. Please refer to
 * the sample code for specific usage.
 * size_nx: The number of columns the current block needs to load; only used
 * when IsBoundary = true.
 * size_ny: The number of rows the current block needs to load; only used
 * when IsBoundary = true.
 * stride_nx: The stride of the columns.
 * stride_ny: The stride of the rows.
 * reduce_last_dim: Indicates whether the reduce dimensions contain the lowest
 * dimension.
 */
template <typename T, int NX, int NY, int BlockSize, int Rank,
          typename IndexCal, bool IsBoundary = false>
__device__ __forceinline__ void ReadDataReduce(
    T* dst, const T* __restrict__ src, int block_offset,
    const IndexCal& index_cal, int size_nx, int size_ny, int stride_nx,
    int stride_ny, bool reduce_last_dim) {
  int thread_offset = 0;
  if (reduce_last_dim) {
    thread_offset = block_offset + threadIdx.x;
  } else {
    thread_offset = block_offset + threadIdx.y;
  }

  if (NX == 1) {
#pragma unroll
    for (int ny = 0; ny < NY; ++ny) {
      if (IsBoundary) {
        if (thread_offset >= size_ny) {
          break;
        }
      }
      uint32_t index_src = index_cal(thread_offset);
      dst[ny] = src[index_src];
      thread_offset += stride_ny;
    }
  } else {
#pragma unroll
    for (int nx = 0; nx < NX; ++nx) {
      if (IsBoundary) {
        if (nx * stride_nx >= size_nx) {
          break;
        }
      }
#pragma unroll
      for (int ny = 0; ny < NY; ++ny) {
        if (IsBoundary) {
          if (nx * stride_nx >= size_nx) {
            break;
          }
        }
        uint32_t index_src = index_cal(thread_offset);
        dst[nx + ny * NX] = src[index_src];
        thread_offset += stride_ny;
      }
      thread_offset += stride_nx;
    }
  }
}
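
// Illustrative sketch (not from the original file): gathering NY elements per
// thread along the reduce dimension. `IndexCalculator` stands for any functor
// that maps a logical offset to a global index in src; the variable names are
// hypothetical.
//
//   float inputs[NY];
//   ReadDataReduce<float, 1, NY, 1, Rank, IndexCalculator, true>(
//       inputs, src, block_offset, index_calculator,
//       /*size_nx=*/1, /*size_ny=*/remaining_ny,
//       /*stride_nx=*/1, /*stride_ny=*/blockDim.x, reduce_last_dim);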

/**
 * @brief Write 2D data from registers to global memory. When IsBoundary =
 * false and (NX % 4 == 0 or NX % 2 == 0), vectorized stores are used to
 * improve memory access efficiency.
 *
 * @template parameters
 * T: The type of data.
 * NX: The number of data elements continuously written by each thread.
 * NY: The number of data rows written by each thread; only NY = 1 is
 * supported.
 * BlockSize: Identifies the current device thread index method. For GPU,
 * threadIdx.x is used as the thread index, and for XPU, core_id() is used as
 * the index. Currently only GPU is supported.
 * IsBoundary: Indicates whether to perform out-of-bounds checks on memory
 * accesses. When the number of data elements processed by the block is less
 * than NX x NY x blockDim, boundary checks are required to avoid memory
 * accesses crossing the boundary.
 *
 * @param:
 * dst: Data pointer of the current block.
 * src: The register pointer of the thread, the size is NX * NY.
 * num: The number of data elements the current block needs to store
 * continuously.
 */
template <typename T, int NX, int NY, int BlockSize, bool IsBoundary = false>
__device__ __forceinline__ void WriteData(T* dst, T* __restrict__ src,
                                          int num) {
  if (IsBoundary) {
    int thread_offset = threadIdx.x * NX;
#pragma unroll
    for (int idx = 0; idx < NX; ++idx) {
      if ((thread_offset + idx) < num) {
        dst[thread_offset + idx] = src[idx];
      }
    }
  } else {
    // Vector type
    const int kVectorSize = (NX % 4 == 0) ? 4 : (NX % 2 == 0) ? 2 : 1;
    const int kVectorsPerThread = NX / kVectorSize;

    int thread_offset = threadIdx.x * kVectorsPerThread;
    using VecType = details::VectorType<T, kVectorSize>;
    VecType* vec_dst = reinterpret_cast<VecType*>(dst);
    VecType vec_temp[kVectorsPerThread];
#pragma unroll
    for (int idx = 0; idx < kVectorsPerThread; ++idx) {
      vec_temp[idx] = *(reinterpret_cast<VecType*>(src) + idx);
      vec_dst[thread_offset + idx] = vec_temp[idx];
    }
  }
}
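
// Illustrative end-to-end sketch (not from the original file): a minimal
// elementwise kernel built from the primitives above. `VecSize` and the
// pointer names are hypothetical.
//
//   template <typename T, int VecSize, bool IsBoundary>
//   __global__ void ScaleKernel(const T* in, T* out, int num, T scale) {
//     int data_offset = blockIdx.x * blockDim.x * VecSize;
//     T args[VecSize];
//     Init<T, VecSize>(args, static_cast<T>(0));
//     ReadData<T, VecSize, 1, 1, IsBoundary>(args, in + data_offset,
//                                            num - data_offset);
//   #pragma unroll
//     for (int i = 0; i < VecSize; ++i) {
//       args[i] *= scale;
//     }
//     WriteData<T, VecSize, 1, 1, IsBoundary>(out + data_offset, args,
//                                             num - data_offset);
//   }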

}  // namespace kernel_primitives
}  // namespace operators
}  // namespace paddle