datamover_primitives.h 13.2 KB
Newer Older
F
Feng Xing 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
16
#ifdef PADDLE_WITH_CUDA
N
niuliling123 已提交
17 18
#include <cuda.h>
#include <cuda_fp16.h>
19 20 21 22
#endif
#ifdef PADDLE_WITH_HIP
#include <hip/hip_fp16.h>
#endif
F
Feng Xing 已提交
23 24 25

namespace paddle {
namespace operators {
N
niuliling123 已提交
26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108
namespace kernel_primitives {
namespace details {

#define INT_BITS 32

/**
 * A bundle of kLanes elements of type ElemT, aligned to its total byte size
 * (e.g. 4 floats -> 16-byte alignment). The whole bundle can then be loaded or
 * stored as one aligned unit; ReadData/WriteData reinterpret pointers to it
 * for vectorized accesses.
 * Precondition: sizeof(ElemT) * kLanes must be a valid alignment (a power of
 * two).
 */
template <typename ElemT, int kLanes>
struct alignas(sizeof(ElemT) * kLanes) VectorType {
  // Element storage; val[i] is lane i.
  ElemT val[kLanes];
};

struct FastDivMod {
  // Result of Divmod(n):
  //   val[0] is the quotient  n / divisor,
  //   val[1] is the remainder n - (n / divisor) * divisor.
  using DivModT = VectorType<uint32_t, 2>;

  // Leaves all members uninitialized. Callable on host and device so that
  // HOSTDEVICE aggregates holding FastDivMod arrays can be default-constructed
  // from either side. Assign a real divisor before calling Div()/Divmod().
  HOSTDEVICE FastDivMod() {}

  // Precomputes a "magic number" division by the fixed divisor d:
  //   shift_val  = ceil(log2(d))      (smallest s with 2^s >= d)
  //   multiplier = floor(2^32 * (2^shift_val - d) / d) + 1
  // so that Div(n) == (__umulhi(n, multiplier) + n) >> shift_val.
  // Precondition (not checked): 1 <= d <= INT32_MAX. d == 0 divides by zero
  // below.
  HOSTDEVICE FastDivMod(uint32_t d) : divisor(d) {
    static_assert(sizeof(unsigned int) == 4,
                  "Only Support 32-bit unsigned int.");

    // Shift an unsigned one and compare against the unsigned argument. A
    // signed `1 << 31` wraps negative, so divisors in (2^30, 2^31) would skip
    // s == 31 and end with shift_val == 32. That makes the 32-bit
    // `>> shift_val` in Div() undefined.
    for (shift_val = 0; shift_val < INT_BITS; ++shift_val) {
      if ((1u << shift_val) >= d) break;
    }
    // Computed in 64 bits; the result fits in 32 bits for
    // 1 <= d <= INT32_MAX.
    uint64_t long_one = 1;
    uint64_t temp_div =
        ((long_one << INT_BITS) * ((long_one << shift_val) - d)) / d + 1;
    multiplier = temp_div;
  }

  // Returns n / divisor. Intended for n <= INT32_MAX. Because
  // multiplier < 2^32, t <= n, so the 32-bit sum t + n cannot overflow in that
  // range.
  __device__ __forceinline__ uint32_t Div(uint32_t n) const {
    uint32_t t = __umulhi(n, multiplier);
    return (t + n) >> shift_val;
  }

  // Returns {n / divisor, n % divisor}, with the remainder derived from the
  // quotient.
  __device__ __forceinline__ DivModT Divmod(uint32_t n) const {
    uint32_t q = Div(n);
    DivModT result = {q, n - q * divisor};
    return result;
  }

  int32_t divisor;
  int32_t shift_val;
  uint32_t multiplier;
};

/**
 * Host-prepared mapping from a linear output index to an input offset for
 * broadcasting; it is passed to ReadDataBc.
 *
 * For each dimension d in [0, kDims):
 *   divmoders[d] divides by out_dims[d]. ReadDataBc peels dimension 0 off the
 *                linear output index first, so d == 0 varies fastest.
 *   strides[d]   is the input stride of dimension d: the product of
 *                in_dims[0 .. d), or 0 when in_dims[d] == 1 (a broadcast
 *                dimension).
 * Dimensions d >= dim_size get extent 1 and stride 0, so they leave both the
 * index and the offset unchanged. Only strides[0 .. kDims) are written.
 */
template <int kDims>
struct BroadcastConfig {
  static_assert(kDims <= framework::DDim::kMaxRank,
                "BroadcastConfig: kDims must not exceed DDim::kMaxRank.");

  FastDivMod divmoders[kDims];
  uint32_t strides[framework::DDim::kMaxRank];
  HOSTDEVICE BroadcastConfig() {}

  // out_dims / in_dims: output and input extents, both indexed the same way
  // as divmoders/strides. dim_size: number of valid entries in the vectors.
  HOSTDEVICE BroadcastConfig(const std::vector<int64_t>& out_dims,
                             const std::vector<int64_t>& in_dims,
                             int dim_size) {
    // Running product of in_dims[0 .. i). Kept in 64 bits so the prefix
    // product is not truncated to int partway through.
    int64_t in_stride = 1;
    for (int i = 0; i < kDims; ++i) {
      if (i < dim_size) {
        divmoders[i] = FastDivMod(out_dims[i]);
        strides[i] = in_dims[i] == 1 ? 0 : in_stride;
        in_stride *= in_dims[i];
      } else {
        // Padding for kDims > dim_size. Dividing by 1 keeps the index and
        // stride 0 adds nothing, instead of reading past the vectors' end.
        divmoders[i] = FastDivMod(1);
        strides[i] = 0;
      }
    }
  }
};

#undef INT_BITS
}  // namespace details

109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174
/**
 * @brief Load an NY x NX tile from src (1D or 2D) into dst with no boundary
 * checks. Use it only when every accessed element is known to be in range.
 * @typename:
 * Tx: data type of src
 * Ty: data type of dst
 * NX: the cols of src, dst
 * NY: the rows of src, dst
 * BlockSize: block configuration parameter (not used by this implementation)
 * @param:
 * stride_nx: the stride of cols
 * stride_ny: the stride of rows
 *
 * Thread threadIdx.x reads, for idx in [0, NX) and idy in [0, NY):
 *   dst[idy * NX + idx] =
 *       src[threadIdx.x * NX + idx * stride_nx + idy * stride_ny]
 */
template <typename Tx, typename Ty, int NX, int NY, int BlockSize>
__device__ __forceinline__ void ReadData(Ty* dst, const Tx* __restrict__ src,
                                         int stride_nx, int stride_ny) {
  // Start of this thread's tile, applied in every branch so threads read
  // disjoint data (equal to threadIdx.x when NX == 1). The NY == 1 branch
  // previously ignored it, so all threads read the same NX elements.
  int dx = threadIdx.x * NX;

  if (NY == 1 && NX == 1) {
    dst[0] = static_cast<Ty>(src[dx]);
  } else if (NX == 1) {
#pragma unroll
    for (int idy = 0; idy < NY; ++idy) {
      dst[idy] = static_cast<Ty>(src[dx + idy * stride_ny]);
    }
  } else if (NY == 1) {
#pragma unroll
    for (int idx = 0; idx < NX; ++idx) {
      dst[idx] = static_cast<Ty>(src[dx + idx * stride_nx]);
    }
  } else {
#pragma unroll
    for (int idx = 0; idx < NX; ++idx) {
#pragma unroll
      for (int idy = 0; idy < NY; ++idy) {
        dst[idy * NX + idx] =
            static_cast<Ty>(src[dx + idx * stride_nx + idy * stride_ny]);
      }
    }
  }
}

/**
 * @brief Load an NY x NX tile from src (1D or 2D) into dst. Reads can
 * optionally be limited to the block's valid region.
 * @typename:
 * Tx: data type of src
 * Ty: data type of dst
 * NX: the cols of src, dst
 * NY: the rows of src, dst
 * BlockSize: block configuration parameter (not used by this implementation)
 * IsBoundary: when true, reads are limited by size_nx / size_ny; defaults to
 *   false (no checks)
 * @param:
 * size_nx: number of columns to be processed by the current block
 * size_ny: number of rows to be processed by the current block
 * stride_nx: the stride of cols
 * stride_ny: the stride of rows
 *
 * Thread threadIdx.x reads
 *   dst[idy * NX + idx] =
 *       src[threadIdx.x * NX + idx * stride_nx + idy * stride_ny]
 * With IsBoundary, column idx is read only while
 * idx < size_nx - threadIdx.x * NX, and row idy only while idy < size_ny.
 * Skipped dst entries are left untouched.
 */
template <typename Tx, typename Ty, int NX, int NY, int BlockSize,
          bool IsBoundary = false>
__device__ __forceinline__ void ReadData(Ty* dst, const Tx* __restrict__ src,
                                         int size_nx, int size_ny,
                                         int stride_nx, int stride_ny) {
  const int thread_offset = threadIdx.x * NX;
  // Columns still available to this thread (only consulted if IsBoundary).
  const int cols_left = size_nx - thread_offset;

  // NX and NY are template constants, so the branch is decided at compile
  // time. Each tile shape gets its own loop nest for performance.
  if (NX == 1 && NY == 1) {
    if (!IsBoundary || thread_offset < size_nx) {
      dst[0] = static_cast<Ty>(src[thread_offset]);
    }
  } else if (NX == 1) {
    // NOTE(review): unlike the NX == 1 && NY == 1 case, this branch never
    // checks the column (thread_offset < size_nx). Confirm that callers never
    // have threads past size_nx here.
#pragma unroll
    for (int row = 0; row < NY; ++row) {
      if (IsBoundary && row >= size_ny) break;
      dst[row] = static_cast<Ty>(src[thread_offset + row * stride_ny]);
    }
  } else if (NY == 1) {
#pragma unroll
    for (int col = 0; col < NX; ++col) {
      if (IsBoundary && col >= cols_left) break;
      dst[col] = static_cast<Ty>(src[thread_offset + col * stride_nx]);
    }
  } else {
#pragma unroll
    for (int col = 0; col < NX; ++col) {
      if (IsBoundary && col >= cols_left) break;
#pragma unroll
      for (int row = 0; row < NY; ++row) {
        if (IsBoundary && row >= size_ny) break;
        dst[row * NX + col] = static_cast<Ty>(
            src[thread_offset + col * stride_nx + row * stride_ny]);
      }
    }
  }
}

229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250
/**
 * @brief Set dst[0 .. NX) to init_data.
 * @typename:
 * T: element type
 * NX: number of elements written
 */
template <typename T, int NX>
__device__ __forceinline__ void Init(T* dst, T init_data) {
#pragma unroll
  for (int k = 0; k < NX; ++k) dst[k] = init_data;
}

/** @brief: ReadData
 * @brief Load NX contiguous elements per thread from a 1D src into dst.
 * Set IsBoundary to true when a boundary check is required; it defaults to
 * false.
 * @typename:
 * T : the data type of src and dst
 * NX: number of elements loaded by each thread
 * NY: in this function NY only can be 1
 * BlockSize: block configuration parameter (not used by this implementation)
 * IsBoundary: whether to check each element against num (needed when
 *   blockDim.x * NX > num)
 * @param:
 * num: number of elements to be processed by the current block
 *
 * Thread threadIdx.x reads src[threadIdx.x * NX + i] into dst[i], i in [0, NX).
 * The unchecked path loads through details::VectorType<T, kVectorSize>
 * (kVectorSize is 4, 2 or 1, depending on NX). src must therefore be aligned
 * to sizeof(T) * kVectorSize bytes.
 */
template <typename T, int NX, int NY, int BlockSize, bool IsBoundary = false>
__device__ __forceinline__ void ReadData(T* dst, const T* __restrict__ src,
                                         int num) {
  if (IsBoundary) {  // blockDim.x * NX > num: check every element
    int dx = threadIdx.x * NX;
#pragma unroll
    for (int idx = 0; idx < NX; ++idx) {
      if (idx + dx < num) {
        dst[idx] = src[idx + dx];
      }
    }
  } else {  // blockDim.x * NX <= num: vectorized loads, no checks
    constexpr int kVectorSize = (NX % 4 == 0) ? 4 : (NX % 2 == 0) ? 2 : 1;
    constexpr int kVectorsPerThread = NX / kVectorSize;
    // Index of this thread's first vector, in units of VecType.
    int tid = threadIdx.x * kVectorsPerThread;

    using VecType = details::VectorType<T, kVectorSize>;
    const VecType* vec_input = reinterpret_cast<const VecType*>(src);
    VecType vec_temp[kVectorsPerThread];

#pragma unroll
    for (int i = 0; i < kVectorsPerThread; ++i) {
      vec_temp[i] = vec_input[i + tid];
    }
    // Unpack once, after every vector has been loaded, so no vec_temp entry
    // is read before it is written.
#pragma unroll
    for (int idx = 0; idx < NX; ++idx) {
      dst[idx] = vec_temp[idx / kVectorSize].val[idx % kVectorSize];
    }
  }
}

281 282 283 284 285 286 287 288 289 290
/**
 * @brief: Read data for broadcast. Each output position handled by this
 * thread is mapped to its source offset through config.
 * @typename:
 * T : the data type of src and dst
 * NX: the cols of dst
 * NY: the rows of dst (row r advances the output index by r * stride_ny)
 * BlockSize: block configuration parameter (not used by this implementation)
 * ShapeSize: the shape size of out, e.g. in[1, 35], out[32, 35] gives shape
 *   size 2
 * IsBoundary: whether to skip output indices >= num
 * @param:
 * fix: data offset of this block, blockDim.x * blockIdx.x * NX
 * config: maps a linear output index to an offset in src; built on the host
 * num: the number of output elements
 * stride_nx: the stride of cols
 * stride_ny: the stride of rows
 *
 * For each (row, col), the output index is
 *   fix + threadIdx.x * NX + row * stride_ny + col * stride_nx
 * and dst[col + row * NX] receives src at the offset config computes for it.
 */
template <typename T, int NX, int NY, int BlockSize, int ShapeSize,
          bool IsBoundary = false>
__device__ __forceinline__ void ReadDataBc(
    T* dst, const T* __restrict__ src, uint32_t fix,
    details::BroadcastConfig<ShapeSize> config, int num, int stride_nx,
    int stride_ny) {
  const uint32_t thread_base = fix + threadIdx.x * NX;

#pragma unroll
  for (int row = 0; row < NY; ++row) {
#pragma unroll
    for (uint32_t col = 0; col < NX; ++col) {
      uint32_t linear = thread_base + row * stride_ny + col * stride_nx;
      if (IsBoundary && linear >= num) {
        break;
      }
      // Split the output index one dimension at a time. The remainder is the
      // coordinate along dimension d; the quotient carries into d + 1.
      uint32_t src_offset = 0;
#pragma unroll
      for (int d = 0; d < ShapeSize; ++d) {
        auto qr = config.divmoders[d].Divmod(linear);
        linear = qr.val[0];
        src_offset += qr.val[1] * config.strides[d];
      }
      dst[col + row * NX] = src[src_offset];
    }
  }
}

329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363
/**
 * @brief: Read data for reduce. Source offsets are produced by index_cal.
 * @typename:
 * T : the data type of src and dst
 * NX: the cols of dst
 * NY: the rows of dst; each row advances the running index by stride_ny
 * BlockSize: block configuration parameter (not used by this implementation)
 * ShapeSize: the shape size of out, e.g. in[1, 35], out[32, 35] gives shape
 *   size 2 (not used by this implementation)
 * IndexCal: callable mapping an int index to a uint32_t offset in src
 * IsBoundary: whether to make boundary judgment
 * @param:
 * fix: data offset of this block
 * index_cal: maps the running index to an offset in src; built on the host
 * size_nx: number of columns to be processed by the current block
 * size_ny: bound on the running index (fix + thread index + k * stride_ny)
 * stride_nx: the stride of cols
 * stride_ny: the stride of rows
 * reduce_last_dim: selects the thread index added to fix (threadIdx.x when
 *   true, threadIdx.y otherwise)
 *
 * The running index starts at fix plus the selected thread index. It advances
 * by stride_ny after every element read and is not reset between columns.
 * With IsBoundary, reading stops once the running index reaches size_ny.
 * Columns with nx * stride_nx >= size_nx are skipped.
 */
template <typename T, int NX, int NY, int BlockSize, int ShapeSize,
          typename IndexCal, bool IsBoundary = false>
__device__ __forceinline__ void ReadDataReduce(
    T* dst, const T* __restrict__ src, int fix, const IndexCal& index_cal,
    int size_nx, int size_ny, int stride_nx, int stride_ny,
    bool reduce_last_dim) {
  int base_offset = fix;
  if (reduce_last_dim) {
    base_offset += threadIdx.x;
  } else {
    base_offset += threadIdx.y;
  }

  if (NX == 1) {
#pragma unroll
    for (int ny = 0; ny < NY; ++ny) {
      if (IsBoundary) {
        if (base_offset >= size_ny) {
          break;
        }
      }
      uint32_t offset = index_cal(base_offset);
      dst[ny] = src[offset];
      base_offset += stride_ny;
    }
  } else {
#pragma unroll
    for (int nx = 0; nx < NX; ++nx) {
      if (IsBoundary) {
        if (nx * stride_nx >= size_nx) {
          break;
        }
      }
#pragma unroll
      for (int ny = 0; ny < NY; ++ny) {
        // Row bound, the same test as in the NX == 1 branch. The column test
        // above is loop-invariant here, so repeating it never stopped reads
        // past size_ny.
        if (IsBoundary) {
          if (base_offset >= size_ny) {
            break;
          }
        }
        uint32_t offset = index_cal(base_offset);
        dst[nx + ny * NX] = src[offset];
        base_offset += stride_ny;
      }
    }
  }
}
N
niuliling123 已提交
397

398 399 400 401 402 403 404 405 406 407 408 409 410 411
/** @brief: WriteData
 * @brief Store NX contiguous elements per thread from src into a 1D dst.
 * Set IsBoundary to true when a boundary check is required; it defaults to
 * false.
 * @typename:
 * T : the data type of src and dst
 * NX: number of elements stored by each thread
 * NY: in this function NY only can be 1
 * BlockSize: block configuration parameter (not used by this implementation)
 * IsBoundary: whether to check each element against num (needed when
 *   blockDim.x * NX > num)
 * @param:
 * num: number of elements to be processed by the current block
 *
 * Thread threadIdx.x writes src[i] to dst[threadIdx.x * NX + i], i in [0, NX).
 * The unchecked path stores through details::VectorType<T, kVectorSize>
 * (kVectorSize is 4, 2 or 1, depending on NX). dst must therefore be aligned
 * to sizeof(T) * kVectorSize bytes. src only needs the alignment of T.
 */
template <typename T, int NX, int NY, int BlockSize, bool IsBoundary = false>
__device__ __forceinline__ void WriteData(T* dst, const T* __restrict__ src,
                                          int num) {
  if (IsBoundary) {
    int dx = threadIdx.x * NX;
#pragma unroll
    for (int idx = 0; idx < NX; ++idx) {
      if ((idx + dx) < num) {
        dst[idx + dx] = src[idx];
      }
    }
  } else {
    // Vector type
    constexpr int kVectorSize = (NX % 4 == 0) ? 4 : (NX % 2 == 0) ? 2 : 1;
    constexpr int kVectorsPerThread = NX / kVectorSize;

    // Index of this thread's first vector, in units of VecType.
    int dx = threadIdx.x * kVectorsPerThread;
    using VecType = details::VectorType<T, kVectorSize>;
    VecType* vec_dst = reinterpret_cast<VecType*>(dst);
    VecType vec_temp[kVectorsPerThread];

    // Pack src element by element. src is only guaranteed alignof(T), so
    // reinterpreting it as a VecType* could be a misaligned vector access.
#pragma unroll
    for (int idx = 0; idx < NX; ++idx) {
      vec_temp[idx / kVectorSize].val[idx % kVectorSize] = src[idx];
    }
#pragma unroll
    for (int idx = 0; idx < kVectorsPerThread; ++idx) {
      vec_dst[dx + idx] = vec_temp[idx];
    }
  }
}
N
niuliling123 已提交
438 439 440 441

}  // namespace kernel_primitives
}  // namespace operators
}  // namespace paddle