/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.


Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#ifdef PADDLE_WITH_CUDA
#include <cuda.h>
#endif
#ifdef PADDLE_WITH_HIP
#include <hip/hip_runtime.h>
#endif
#include <stdio.h>

#include <cstdint>
#include <type_traits>

#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/float16.h"


namespace paddle {
namespace platform {

// Expands to the header of a device-side wrapper
//   T CudaAtomic<op>(T *address, const T val)
// The function body follows the macro invocation; inside it the parameters
// are named `address` (location to update) and `val` (operand).
#define CUDA_ATOMIC_WRAPPER(op, T) \
  __device__ __forceinline__ T CudaAtomic##op(T *address, const T val)

// Defines CudaAtomic<op>(T *, T) as a direct forward to the native
// atomic<op> intrinsic for element type T.
#define USE_CUDA_ATOMIC(op, T) \
  CUDA_ATOMIC_WRAPPER(op, T) { return atomic##op(address, val); }

// Default thread count per block (i.e. block size).
// TODO(typhoonzero): need to benchmark against setting this value
//                    to 1024.
constexpr int PADDLE_CUDA_NUM_THREADS = 512;

// For atomicAdd: element types with a native atomicAdd overload simply
// forward to it.
USE_CUDA_ATOMIC(Add, float);
USE_CUDA_ATOMIC(Add, int);
USE_CUDA_ATOMIC(Add, unsigned int);
// The CUDA API uses unsigned long long int, so uint64_t cannot be used here:
// unsigned long long int is not necessarily the same type as uint64_t.
USE_CUDA_ATOMIC(Add, unsigned long long int);  // NOLINT

// int64_t has no native atomicAdd overload. Reuse the unsigned long long int
// overload on the same storage: unsigned addition wraps modulo 2^64, which
// yields the same bit pattern as a two's-complement signed addition.
// Returns the value stored at `address` before the addition.
CUDA_ATOMIC_WRAPPER(Add, int64_t) {
  // Only the widths are checked here: the reinterpret_cast below requires
  // int64_t and long long int to have the same size.
  static_assert(sizeof(int64_t) == sizeof(long long int),  // NOLINT
                "long long should be int64");
  return CudaAtomicAdd(
      reinterpret_cast<unsigned long long int *>(address),  // NOLINT
      static_cast<unsigned long long int>(val));            // NOLINT
}

// atomicAdd(double *) is native on HIP and on CUDA devices of compute
// capability 6.0+. Otherwise (older devices, or the host compilation pass
// where __CUDA_ARCH__ is undefined) it is emulated with an atomicCAS loop on
// the 64-bit bit pattern.
#if defined(__HIPCC__) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600)
USE_CUDA_ATOMIC(Add, double);
#else
CUDA_ATOMIC_WRAPPER(Add, double) {
  unsigned long long int *address_as_ull =                  // NOLINT
      reinterpret_cast<unsigned long long int *>(address);  // NOLINT
  unsigned long long int old = *address_as_ull, assumed;    // NOLINT

  do {
    assumed = old;
    old = atomicCAS(address_as_ull,
                    assumed,
                    __double_as_longlong(val + __longlong_as_double(assumed)));

    // Note: uses integer comparison to avoid a hang in case of NaN
    // (a floating-point NaN never compares equal, so the loop would retry
    // forever).
  } while (assumed != old);

  // The value stored at `address` before this addition.
  return __longlong_as_double(old);
}
#endif

#ifdef PADDLE_CUDA_FP16
// NOTE(dzhwinter): CUDA does not provide atomicCAS for half. Instead, the
// 32-bit word that contains the half value is updated with a 32-bit
// atomicCAS; depending on whether the value is stored in the high or the low
// 16 bits of that word, a different helper below computes the new word.
// Under contention most threads of a warp will fail the atomicCAS and retry,
// so this approach should be avoided in highly concurrent code; it is slower
// than converting the value to 32 bits and doing a full atomicCAS.

// Adds `x` to the float16 held in the low 16 bits of `val`, computing the sum
// in float, and returns `val` with only those low 16 bits replaced.
inline static __device__ uint32_t add_to_low_half(uint32_t val, float x) {
  float16 low_half;
  // The float16 occupies the lower 16 bits.
  low_half.x = static_cast<uint16_t>(val & 0xFFFFu);
  low_half = static_cast<float16>(static_cast<float>(low_half) + x);
  return (val & 0xFFFF0000u) | low_half.x;
}

// Adds `x` to the float16 held in the high 16 bits of `val`, computing the
// sum in float, and returns `val` with only those high 16 bits replaced.
inline static __device__ uint32_t add_to_high_half(uint32_t val, float x) {
  float16 high_half;
  // The float16 occupies the higher 16 bits.
  high_half.x = static_cast<uint16_t>(val >> 16);
  high_half = static_cast<float16>(static_cast<float>(high_half) + x);
  return (val & 0xFFFFu) | (static_cast<uint32_t>(high_half.x) << 16);
}

#if CUDA_VERSION >= 10000 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
// Bit-casts a CUDA __half to Paddle's float16 (relies on both types sharing
// the same 16-bit representation).
static __device__ __forceinline__ float16 CUDAFP16ToPDFP16(__half x) {
  return *reinterpret_cast<float16 *>(&x);
}

// Bit-casts Paddle's float16 to a CUDA __half (inverse of CUDAFP16ToPDFP16).
static __device__ __forceinline__ __half PDFP16ToCUDAFP16(float16 x) {
  return *reinterpret_cast<__half *>(&x);
}

// CUDA 10.0+ on compute capability 7.0+: forward to the native
// atomicAdd(__half *, __half). Returns the value stored before the addition.
CUDA_ATOMIC_WRAPPER(Add, float16) {
  return CUDAFP16ToPDFP16(
      atomicAdd(reinterpret_cast<__half *>(address), PDFP16ToCUDAFP16(val)));
}
#else
// Fallback: emulate the 16-bit atomicAdd with a 32-bit atomicCAS loop on the
// word that contains the float16; the sum is computed in float.
// Returns the float16 stored at `address` before the addition.
CUDA_ATOMIC_WRAPPER(Add, float16) {
  // The float16 may sit in the lower or higher 16 bits of its enclosing
  // 32-bit word. Bit 1 of the address selects which, and subtracting it
  // gives the word's address (assumes `address` is at least 2-byte aligned).
  const uintptr_t half_offset = reinterpret_cast<uintptr_t>(address) & 0x02;
  uint32_t *address_as_ui = reinterpret_cast<uint32_t *>(
      reinterpret_cast<char *>(address) - half_offset);
  const float val_f = static_cast<float>(val);
  uint32_t old = *address_as_ui;
  uint32_t assumed;
  float16 ret;
  if (half_offset == 0) {
    // The float16 value occupies the lower 16 bits of the word.
    do {
      assumed = old;
      old = atomicCAS(address_as_ui, assumed, add_to_low_half(assumed, val_f));
    } while (old != assumed);
    ret.x = old & 0xFFFFu;
  } else {
    // The float16 value occupies the higher 16 bits of the word.
    do {
      assumed = old;
      old = atomicCAS(address_as_ui, assumed, add_to_high_half(assumed, val_f));
    } while (old != assumed);
    ret.x = old >> 16;
  }
  return ret;
}
#endif

// "atomicAdd(half *)" performs poorly while "atomicAdd(half2 *)" is fast, so
// for float16 the value is placed in one lane of a __half2 whose other lane
// is zero, and added to the aligned half2 pair that contains tensor[index].
//
// `numel` is the element count of `tensor`; it keeps the zero-padded partner
// lane inside the buffer (assumes index < numel).
template <typename T,
          typename std::enable_if<
              std::is_same<platform::float16, T>::value>::type * = nullptr>
__device__ __forceinline__ void fastAtomicAdd(T *tensor,
                                              size_t index,
                                              const size_t numel,
                                              T value) {
#if ((CUDA_VERSION < 10000) || \
     (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700)))
  // No native half2 atomicAdd available: use the scalar float16 wrapper.
  CudaAtomicAdd(reinterpret_cast<platform::float16 *>(tensor) + index,
                static_cast<platform::float16>(value));
#else
  // Whether the target address is aligned to sizeof(__half2) (4 bytes), i.e.
  // whether tensor[index] is the low lane of an aligned half2 pair.
  __half *target_addr = reinterpret_cast<__half *>(tensor + index);
  bool aligned_half2 =
      (reinterpret_cast<std::uintptr_t>(target_addr) % sizeof(__half2) == 0);

  // `index + 1 < numel` is `index < numel - 1` without the size_t underflow
  // when numel == 0.
  if (aligned_half2 && index + 1 < numel) {
    // tensor[index] is the low lane; tensor[index + 1] exists and gets +0.
    __half2 value2;
    value2.x = *reinterpret_cast<__half *>(&value);
    value2.y = __int2half_rz(0);
    atomicAdd(reinterpret_cast<__half2 *>(target_addr), value2);

  } else if (!aligned_half2 && index > 0) {
    // tensor[index] is the high lane of the pair starting at index - 1
    // (assumes the float16 storage itself is 2-byte aligned).
    __half2 value2;
    value2.x = __int2half_rz(0);
    value2.y = *reinterpret_cast<__half *>(&value);
    atomicAdd(reinterpret_cast<__half2 *>(target_addr - 1), value2);

  } else {
    // No in-bounds partner element: fall back to the scalar atomicAdd.
    atomicAdd(reinterpret_cast<__half *>(tensor) + index,
              *reinterpret_cast<__half *>(&value));
  }
#endif
}

// Generic overload for every element type other than float16: no half2
// packing applies, so forward to the per-type CudaAtomicAdd on arr[index].
// `numel` is accepted only to keep the signature uniform with the float16
// overload and is unused here.
template <typename T,
          typename std::enable_if<
              !std::is_same<platform::float16, T>::value>::type * = nullptr>
__device__ __forceinline__ void fastAtomicAdd(T *arr,
                                              size_t index,
                                              const size_t numel,
                                              T value) {
  CudaAtomicAdd(arr + index, value);
}

#ifdef PADDLE_WITH_CUDA
/*
 * One thread block performs an elementwise atomicAdd of a vector of length
 * len:
 * @in: [x1, x2, x3, ...]
 * @out:[y1+x1, y2+x2, y3+x3, ...]
 * Thread `tid` handles indices tid, tid + threads_per_block, ...
 * Indices are int64_t so that len values above INT_MAX do not overflow.
 * */
template <typename T,
          typename std::enable_if<
              !std::is_same<platform::float16, T>::value>::type * = nullptr>
__device__ __forceinline__ void VectorizedAtomicAddPerBlock(
    const int64_t len, int tid, int threads_per_block, const T *in, T *out) {
  for (int64_t i = tid; i < len; i += threads_per_block) {
    CudaAtomicAdd(&out[i], in[i]);
  }
}

// float16 overload. When `out` is aligned to sizeof(__half2), element pairs
// (i, i + 1) are added with a single atomicAdd(__half2 *); if len is odd, the
// trailing element is added via fastAtomicAdd. When `out` is not aligned,
// every element goes through fastAtomicAdd.
template <typename T,
          typename std::enable_if<
              std::is_same<platform::float16, T>::value>::type * = nullptr>
__device__ __forceinline__ void VectorizedAtomicAddPerBlock(
    const int64_t len, int tid, int threads_per_block, const T *in, T *out) {
#if ((CUDA_VERSION < 10000) || \
     (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700)))
  for (int64_t i = tid; i < len; i += threads_per_block) {
    CudaAtomicAdd(&out[i], in[i]);
  }
#else
  // Largest even count <= len: the prefix covered by half2 pairs.
  const int64_t loops = len / 2 * 2;

  bool aligned_half2 =
      (reinterpret_cast<std::uintptr_t>(out) % sizeof(__half2) == 0);

  if (aligned_half2) {
    int64_t i = static_cast<int64_t>(tid) * 2;
    for (; i < loops; i += static_cast<int64_t>(threads_per_block) * 2) {
      __half2 value2;
      T value_1 = in[i];
      T value_2 = in[i + 1];
      value2.x = *reinterpret_cast<__half *>(&value_1);
      value2.y = *reinterpret_cast<__half *>(&value_2);
      atomicAdd(reinterpret_cast<__half2 *>(&out[i]), value2);
    }
    // Each thread's `i` stays even, so exactly one thread stops at
    // i == loops; when len is odd that thread handles the last element.
    for (; i < len; i += threads_per_block) {
      fastAtomicAdd(out, i, len, in[i]);
    }
  } else {
    for (int64_t j = tid; j < len; j += threads_per_block) {
      fastAtomicAdd(out, j, len, in[j]);
    }
  }
#endif
}
#endif  // PADDLE_WITH_CUDA
#endif
// NOTE(zhangbo): CUDA does not provide atomicCAS for __nv_bfloat16. The
// bfloat16 atomicAdd fallback therefore CASes the enclosing 32-bit word, and
// these helpers compute the updated word.

// Adds `x` to the bfloat16 held in the low 16 bits of `val`, computing the
// sum in float, and returns `val` with only those low 16 bits replaced.
inline static __device__ uint32_t bf16_add_to_low_half(uint32_t val, float x) {
  bfloat16 low_half;
  // The bfloat16 occupies the lower 16 bits.
  low_half.x = static_cast<uint16_t>(val & 0xFFFFu);
  low_half = static_cast<bfloat16>(static_cast<float>(low_half) + x);
  return (val & 0xFFFF0000u) | low_half.x;
}

// Adds `x` to the bfloat16 held in the high 16 bits of `val`, computing the
// sum in float, and returns `val` with only those high 16 bits replaced.
inline static __device__ uint32_t bf16_add_to_high_half(uint32_t val, float x) {
  bfloat16 high_half;
  // The bfloat16 occupies the higher 16 bits.
  high_half.x = static_cast<uint16_t>(val >> 16);
  high_half = static_cast<bfloat16>(static_cast<float>(high_half) + x);
  return (val & 0xFFFFu) | (static_cast<uint32_t>(high_half.x) << 16);
}

#if CUDA_VERSION >= 11000 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
// Bit-casts a CUDA __nv_bfloat16 to Paddle's bfloat16 (relies on both types
// sharing the same 16-bit representation).
static __device__ __forceinline__ bfloat16 CUDABF16ToPDBF16(__nv_bfloat16 x) {
  return *reinterpret_cast<bfloat16 *>(&x);
}

// Bit-casts Paddle's bfloat16 to a CUDA __nv_bfloat16 (inverse of the above).
static __device__ __forceinline__ __nv_bfloat16 PDBF16ToCUDABF16(bfloat16 x) {
  return *reinterpret_cast<__nv_bfloat16 *>(&x);
}

// CUDA 11.0+ on compute capability 8.0+: forward to the native
// atomicAdd(__nv_bfloat16 *, __nv_bfloat16). Returns the previous value.
CUDA_ATOMIC_WRAPPER(Add, bfloat16) {
  return CUDABF16ToPDBF16(atomicAdd(reinterpret_cast<__nv_bfloat16 *>(address),
                                    PDBF16ToCUDABF16(val)));
}
#else
// Fallback: emulate the 16-bit atomicAdd with a 32-bit atomicCAS loop on the
// word that contains the bfloat16; the sum is computed in float.
// Returns the bfloat16 stored at `address` before the addition.
CUDA_ATOMIC_WRAPPER(Add, bfloat16) {
  // The bfloat16 may sit in the lower or higher 16 bits of its enclosing
  // 32-bit word. Bit 1 of the address selects which, and subtracting it
  // gives the word's address (assumes `address` is at least 2-byte aligned).
  const uintptr_t half_offset = reinterpret_cast<uintptr_t>(address) & 0x02;
  uint32_t *address_as_ui = reinterpret_cast<uint32_t *>(
      reinterpret_cast<char *>(address) - half_offset);
  const float val_f = static_cast<float>(val);
  uint32_t old = *address_as_ui;
  uint32_t assumed;
  bfloat16 ret;
  if (half_offset == 0) {
    // The bfloat16 value occupies the lower 16 bits of the word.
    do {
      assumed = old;
      old = atomicCAS(
          address_as_ui, assumed, bf16_add_to_low_half(assumed, val_f));
    } while (old != assumed);
    ret.x = old & 0xFFFFu;
  } else {
    // The bfloat16 value occupies the higher 16 bits of the word.
    do {
      assumed = old;
      old = atomicCAS(
          address_as_ui, assumed, bf16_add_to_high_half(assumed, val_f));
    } while (old != assumed);
    ret.x = old >> 16;
  }
  return ret;
}
#endif

// Adds a complex<float> by atomically adding the real part and the imaginary
// part separately. The two updates are independent atomics, so the pair is
// not updated atomically as a whole: another thread may observe one part
// updated and the other not yet. Relies on `real` being immediately followed
// by `imag` in memory. Returns the previous real and imaginary values.
CUDA_ATOMIC_WRAPPER(Add, complex<float>) {
  float *real = reinterpret_cast<float *>(address);
  float *imag = real + 1;
  return complex<float>(CudaAtomicAdd(real, val.real),
                        CudaAtomicAdd(imag, val.imag));
}

// Adds a complex<double> by atomically adding the real part and the
// imaginary part separately. The two updates are independent atomics, so the
// pair is not updated atomically as a whole. Relies on `real` being
// immediately followed by `imag` in memory. Returns the previous real and
// imaginary values.
CUDA_ATOMIC_WRAPPER(Add, complex<double>) {
  double *real = reinterpret_cast<double *>(address);
  double *imag = real + 1;
  return complex<double>(CudaAtomicAdd(real, val.real),
                         CudaAtomicAdd(imag, val.imag));
}

// For atomicMax
USE_CUDA_ATOMIC(Max, int);
USE_CUDA_ATOMIC(Max, unsigned int);
// The CUDA API uses unsigned long long int, so uint64_t cannot be used here:
// unsigned long long int is not necessarily the same type as uint64_t.
#if defined(__HIPCC__) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 350)
USE_CUDA_ATOMIC(Max, unsigned long long int);  // NOLINT
#else
// atomicCAS-based emulation. Like the native atomicMax, returns the value
// stored at `address` before the call.
CUDA_ATOMIC_WRAPPER(Max, unsigned long long int) {  // NOLINT
  // Read once so that the early-exit value returned is the one compared.
  unsigned long long int old = *address, assumed;  // NOLINT

  do {
    assumed = old;
    if (assumed >= val) {
      // Already >= val: nothing to store.
      break;
    }

    old = atomicCAS(address, assumed, val);
  } while (assumed != old);

  // The original fell off the end here, which is undefined behavior for a
  // function returning a value.
  return old;
}
#endif


// int64_t has no native atomicMax overload; emulate it with a 64-bit
// atomicCAS loop. Returns the previous value when `val` was stored,
// otherwise the observed value (which is already >= val).
CUDA_ATOMIC_WRAPPER(Max, int64_t) {
  // Only the widths are checked: the casts below reuse the
  // unsigned long long int atomicCAS overload on int64_t storage.
  static_assert(sizeof(int64_t) == sizeof(long long int),  // NOLINT
                "long long should be int64");
  long long int res = *address;  // NOLINT
  while (val > res) {
    long long int old = res;  // NOLINT
    res = static_cast<long long int>(                                   // NOLINT
        atomicCAS(reinterpret_cast<unsigned long long int *>(address),  // NOLINT
                  static_cast<unsigned long long int>(old),             // NOLINT
                  static_cast<unsigned long long int>(val)));           // NOLINT
    if (res == old) {
      // CAS succeeded: `val` is stored and `res` holds the previous value.
      break;
    }
  }
  return res;
}

// float has no native atomicMax; emulate it with a 32-bit atomicCAS loop on
// the value's bit pattern. Returns the value stored at `address` before the
// call (when no store was needed, the observed value that was >= val).
CUDA_ATOMIC_WRAPPER(Max, float) {
  int *const address_as_i = reinterpret_cast<int *>(address);
  // Read once so that the early-exit value returned is the one compared.
  int old = *address_as_i, assumed;

  do {
    assumed = old;
    if (__int_as_float(assumed) >= val) {
      // Already >= val: nothing to store.
      break;
    }

    old = atomicCAS(address_as_i, assumed, __float_as_int(val));
  } while (assumed != old);

  return __int_as_float(old);
}

// double has no native atomicMax; emulate it with a 64-bit atomicCAS loop on
// the value's bit pattern. Returns the value stored at `address` before the
// call (when no store was needed, the observed value that was >= val).
CUDA_ATOMIC_WRAPPER(Max, double) {
  unsigned long long int *const address_as_ull =            // NOLINT
      reinterpret_cast<unsigned long long int *>(address);  // NOLINT
  // Read once so that the early-exit value returned is the one compared.
  unsigned long long int old = *address_as_ull, assumed;    // NOLINT

  do {
    assumed = old;
    if (__longlong_as_double(assumed) >= val) {
      // Already >= val: nothing to store.
      break;
    }

    old = atomicCAS(address_as_ull, assumed, __double_as_longlong(val));
  } while (assumed != old);

  return __longlong_as_double(old);
}

// For atomicMin
USE_CUDA_ATOMIC(Min, int);
USE_CUDA_ATOMIC(Min, unsigned int);
// The CUDA API uses unsigned long long int, so uint64_t cannot be used here:
// unsigned long long int is not necessarily the same type as uint64_t.
#if defined(__HIPCC__) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 350)
USE_CUDA_ATOMIC(Min, unsigned long long int);  // NOLINT
#else
// atomicCAS-based emulation. Like the native atomicMin, returns the value
// stored at `address` before the call.
CUDA_ATOMIC_WRAPPER(Min, unsigned long long int) {  // NOLINT
  // Read once so that the early-exit value returned is the one compared.
  unsigned long long int old = *address, assumed;  // NOLINT

  do {
    assumed = old;
    if (assumed <= val) {
      // Already <= val: nothing to store.
      break;
    }

    old = atomicCAS(address, assumed, val);
  } while (assumed != old);

  // The original fell off the end here, which is undefined behavior for a
  // function returning a value.
  return old;
}
#endif


// int64_t has no native atomicMin overload; emulate it with a 64-bit
// atomicCAS loop. Returns the previous value when `val` was stored,
// otherwise the observed value (which is already <= val).
CUDA_ATOMIC_WRAPPER(Min, int64_t) {
  // Only the widths are checked: the casts below reuse the
  // unsigned long long int atomicCAS overload on int64_t storage.
  static_assert(sizeof(int64_t) == sizeof(long long int),  // NOLINT
                "long long should be int64");
  long long int res = *address;  // NOLINT
  while (val < res) {
    long long int old = res;  // NOLINT
    res = static_cast<long long int>(                                   // NOLINT
        atomicCAS(reinterpret_cast<unsigned long long int *>(address),  // NOLINT
                  static_cast<unsigned long long int>(old),             // NOLINT
                  static_cast<unsigned long long int>(val)));           // NOLINT
    if (res == old) {
      // CAS succeeded: `val` is stored and `res` holds the previous value.
      break;
    }
  }
  return res;
}

// float has no native atomicMin; emulate it with a 32-bit atomicCAS loop on
// the value's bit pattern. Returns the value stored at `address` before the
// call (when no store was needed, the observed value that was <= val).
CUDA_ATOMIC_WRAPPER(Min, float) {
  int *const address_as_i = reinterpret_cast<int *>(address);
  // Read once so that the early-exit value returned is the one compared.
  int old = *address_as_i, assumed;

  do {
    assumed = old;
    if (__int_as_float(assumed) <= val) {
      // Already <= val: nothing to store.
      break;
    }

    old = atomicCAS(address_as_i, assumed, __float_as_int(val));
  } while (assumed != old);

  return __int_as_float(old);
}

// double has no native atomicMin; emulate it with a 64-bit atomicCAS loop on
// the value's bit pattern. Returns the value stored at `address` before the
// call (when no store was needed, the observed value that was <= val).
CUDA_ATOMIC_WRAPPER(Min, double) {
  unsigned long long int *const address_as_ull =            // NOLINT
      reinterpret_cast<unsigned long long int *>(address);  // NOLINT
  // Read once so that the early-exit value returned is the one compared.
  unsigned long long int old = *address_as_ull, assumed;    // NOLINT

  do {
    assumed = old;
    if (__longlong_as_double(assumed) <= val) {
      // Already <= val: nothing to store.
      break;
    }

    old = atomicCAS(address_as_ull, assumed, __double_as_longlong(val));
  } while (assumed != old);

  return __longlong_as_double(old);
}

}  // namespace platform
}  // namespace paddle