/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#ifdef PADDLE_WITH_CUDA
#include <cuda.h>
#endif
#ifdef PADDLE_WITH_HIP
#include <hip/hip_runtime.h>
#endif
#include <stdio.h>
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/float16.h"

namespace paddle {
namespace platform {

#define CUDA_ATOMIC_WRAPPER(op, T) \
  __device__ __forceinline__ T CudaAtomic##op(T *address, const T val)

#define USE_CUDA_ATOMIC(op, T) \
  CUDA_ATOMIC_WRAPPER(op, T) { return atomic##op(address, val); }
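
// Illustrative expansion (not part of the build): USE_CUDA_ATOMIC(Add, float)
// below generates a thin wrapper over the hardware intrinsic, roughly:
//
//   __device__ __forceinline__ float CudaAtomicAdd(float *address,
//                                                  const float val) {
//     return atomicAdd(address, val);
//   }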

// Default thread count per block (or block size).
// TODO(typhoonzero): need to benchmark against setting this value
//                    to 1024.
constexpr int PADDLE_CUDA_NUM_THREADS = 512;
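
// A typical launch configuration derived from this constant (illustrative
// sketch only; NumBlocks is a hypothetical helper, not defined here):
//
//   inline int NumBlocks(int n) {
//     return (n + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS;
//   }
//   // kernel<<<NumBlocks(n), PADDLE_CUDA_NUM_THREADS>>>(args...);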

// For atomicAdd.
USE_CUDA_ATOMIC(Add, float);
USE_CUDA_ATOMIC(Add, int);
USE_CUDA_ATOMIC(Add, unsigned int);
// The CUDA API uses unsigned long long int; we cannot use uint64_t here
// because unsigned long long int is not necessarily the same type as uint64_t.
USE_CUDA_ATOMIC(Add, unsigned long long int);  // NOLINT

CUDA_ATOMIC_WRAPPER(Add, int64_t) {
  // Here we check that long long int has the same size as int64_t.
  static_assert(sizeof(int64_t) == sizeof(long long int),  // NOLINT
                "long long should be int64");
  return CudaAtomicAdd(
      reinterpret_cast<unsigned long long int *>(address),  // NOLINT
      static_cast<unsigned long long int>(val));            // NOLINT
}

#if defined(__HIPCC__) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600)
USE_CUDA_ATOMIC(Add, double);
#else
CUDA_ATOMIC_WRAPPER(Add, double) {
  unsigned long long int *address_as_ull =                  // NOLINT
      reinterpret_cast<unsigned long long int *>(address);  // NOLINT
  unsigned long long int old = *address_as_ull, assumed;    // NOLINT

  do {
    assumed = old;
    old = atomicCAS(address_as_ull, assumed,
                    __double_as_longlong(val + __longlong_as_double(assumed)));

    // Note: uses integer comparison to avoid hang in case of NaN
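    // (A floating-point comparison could spin forever once the stored value
    // is NaN, since NaN != NaN always holds.)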
  } while (assumed != old);

  return __longlong_as_double(old);
}
#endif
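
// Usage sketch (hypothetical kernel, for illustration only): every thread
// accumulates into a single double; on devices older than sm_60 this goes
// through the CAS fallback above.
//
//   __global__ void SumKernel(const double *in, double *out, int n) {
//     int i = blockIdx.x * blockDim.x + threadIdx.x;
//     if (i < n) CudaAtomicAdd(out, in[i]);
//   }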

#ifdef PADDLE_CUDA_FP16
// NOTE(dzhwinter): CUDA does not provide atomicCAS for half. Instead, treat
// the half address as part of an aligned 32-bit word and do the atomicCAS on
// that word. Depending on whether the value is stored in the high or the low
// 16 bits, a different sum-and-CAS path is taken. Since most warp threads
// will fail the atomicCAS, this implementation should be avoided under high
// contention; it will be slower than converting the values to 32 bits and
// doing a full 32-bit atomicCAS.

// Convert the value to float and do the add in float arithmetic,
// then pack the result back into a uint32.
inline static __device__ uint32_t add_to_low_half(uint32_t val, float x) {
  float16 low_half;
  // The float16 stored in the lower 16 bits.
  low_half.x = static_cast<uint16_t>(val & 0xFFFFu);
  low_half = static_cast<float16>(static_cast<float>(low_half) + x);
  return (val & 0xFFFF0000u) | low_half.x;
}
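
// Worked example (illustrative): with val = 0x40003C00 (upper half = 2.0,
// lower half = 1.0 in float16) and x = 1.0f, add_to_low_half returns
// 0x40004000: the lower half becomes float16(2.0); the upper half is
// untouched.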

inline static __device__ uint32_t add_to_high_half(uint32_t val, float x) {
  float16 high_half;
  // The float16 stored in the upper 16 bits.
  high_half.x = static_cast<uint16_t>(val >> 16);
  high_half = static_cast<float16>(static_cast<float>(high_half) + x);
  return (val & 0xFFFFu) | (static_cast<uint32_t>(high_half.x) << 16);
}

CUDA_ATOMIC_WRAPPER(Add, float16) {
  // The packed float16 value may live in either the lower or the upper
  // 16 bits of the aligned 32-bit word containing the address.
  uint32_t *address_as_ui = reinterpret_cast<uint32_t *>(
      reinterpret_cast<char *>(address) -
      (reinterpret_cast<uintptr_t>(address) & 0x02));
  float val_f = static_cast<float>(val);
  uint32_t old = *address_as_ui;
  uint32_t sum;
  uint32_t newval;
  uint32_t assumed;
  if (((uintptr_t)address & 0x02) == 0) {
    // The float16 value lives in the lower 16 bits of the word.
    do {
      assumed = old;
      old = atomicCAS(address_as_ui, assumed, add_to_low_half(assumed, val_f));
    } while (old != assumed);
    float16 ret;
    ret.x = old & 0xFFFFu;
    return ret;
  } else {
    // The float16 value lives in the upper 16 bits of the word.
    do {
      assumed = old;
      old = atomicCAS(address_as_ui, assumed, add_to_high_half(assumed, val_f));
    } while (old != assumed);
    float16 ret;
    ret.x = old >> 16;
    return ret;
  }
}
#endif

CUDA_ATOMIC_WRAPPER(Add, complex<float>) {
  float *real = reinterpret_cast<float *>(address);
  float *imag = real + 1;
  return complex<float>(CudaAtomicAdd(real, val.real),
                        CudaAtomicAdd(imag, val.imag));
}

CUDA_ATOMIC_WRAPPER(Add, complex<double>) {
  double *real = reinterpret_cast<double *>(address);
  double *imag = real + 1;
  return complex<double>(CudaAtomicAdd(real, val.real),
                         CudaAtomicAdd(imag, val.imag));
}
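
// Note: the two component updates above are separate atomics, so a concurrent
// reader may observe a pair whose real and imaginary parts come from different
// logical updates; each component on its own is updated atomically.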

// For atomicMax
USE_CUDA_ATOMIC(Max, int);
USE_CUDA_ATOMIC(Max, unsigned int);
// The CUDA API uses unsigned long long int; we cannot use uint64_t here
// because unsigned long long int is not necessarily the same type as uint64_t.
#if defined(__HIPCC__) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 350)
USE_CUDA_ATOMIC(Max, unsigned long long int);  // NOLINT
#else
CUDA_ATOMIC_WRAPPER(Max, unsigned long long int) {  // NOLINT
  if (*address >= val) {
    return *address;
  }

  unsigned long long int old = *address, assumed;  // NOLINT

  do {
    assumed = old;
    if (assumed >= val) {
      break;
    }

    old = atomicCAS(address, assumed, val);
  } while (assumed != old);

  return old;
}
#endif

CUDA_ATOMIC_WRAPPER(Max, int64_t) {
  // Here we check that long long int has the same size as int64_t.
  static_assert(sizeof(int64_t) == sizeof(long long int),  // NOLINT
                "long long should be int64");
  long long int res = *address;  // NOLINT
  while (val > res) {
    long long int old = res;                                           // NOLINT
    res = (long long int)atomicCAS((unsigned long long int *)address,  // NOLINT
                                   (unsigned long long int)old,        // NOLINT
                                   (unsigned long long int)val);       // NOLINT
    if (res == old) {
      break;
    }
  }
  return res;
}

CUDA_ATOMIC_WRAPPER(Max, float) {
  if (*address >= val) {
    return *address;
  }

  int *const address_as_i = reinterpret_cast<int *>(address);
  int old = *address_as_i, assumed;

  do {
    assumed = old;
    if (__int_as_float(assumed) >= val) {
      break;
    }

    old = atomicCAS(address_as_i, assumed, __float_as_int(val));
  } while (assumed != old);

  return __int_as_float(old);
}

CUDA_ATOMIC_WRAPPER(Max, double) {
  if (*address >= val) {
    return *address;
  }

  unsigned long long int *const address_as_ull =            // NOLINT
      reinterpret_cast<unsigned long long int *>(address);  // NOLINT
  unsigned long long int old = *address_as_ull, assumed;    // NOLINT

  do {
    assumed = old;
    if (__longlong_as_double(assumed) >= val) {
      break;
    }

    old = atomicCAS(address_as_ull, assumed, __double_as_longlong(val));
  } while (assumed != old);

  return __longlong_as_double(old);
}
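
// Usage sketch (hypothetical kernel, for illustration only): a global
// maximum across threads, assuming *out was pre-initialized (e.g. to
// -FLT_MAX) before launch.
//
//   __global__ void MaxKernel(const float *in, float *out, int n) {
//     int i = blockIdx.x * blockDim.x + threadIdx.x;
//     if (i < n) CudaAtomicMax(out, in[i]);
//   }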

// For atomicMin
USE_CUDA_ATOMIC(Min, int);
USE_CUDA_ATOMIC(Min, unsigned int);
// The CUDA API uses unsigned long long int; we cannot use uint64_t here
// because unsigned long long int is not necessarily the same type as uint64_t.
#if defined(__HIPCC__) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 350)
USE_CUDA_ATOMIC(Min, unsigned long long int);  // NOLINT
#else
CUDA_ATOMIC_WRAPPER(Min, unsigned long long int) {  // NOLINT
  if (*address <= val) {
    return *address;
  }

  unsigned long long int old = *address, assumed;  // NOLINT

  do {
    assumed = old;
    if (assumed <= val) {
      break;
    }

    old = atomicCAS(address, assumed, val);
  } while (assumed != old);

  return old;
}
#endif

CUDA_ATOMIC_WRAPPER(Min, int64_t) {
  // Here we check that long long int has the same size as int64_t.
  static_assert(sizeof(int64_t) == sizeof(long long int),  // NOLINT
                "long long should be int64");
  long long int res = *address;  // NOLINT
  while (val < res) {
    long long int old = res;                                           // NOLINT
    res = (long long int)atomicCAS((unsigned long long int *)address,  // NOLINT
                                   (unsigned long long int)old,        // NOLINT
                                   (unsigned long long int)val);       // NOLINT
    if (res == old) {
      break;
    }
  }
  return res;
}

CUDA_ATOMIC_WRAPPER(Min, float) {
  if (*address <= val) {
    return *address;
  }

  int *const address_as_i = reinterpret_cast<int *>(address);
  int old = *address_as_i, assumed;

  do {
    assumed = old;
    if (__int_as_float(assumed) <= val) {
      break;
    }

    old = atomicCAS(address_as_i, assumed, __float_as_int(val));
  } while (assumed != old);

  return __int_as_float(old);
}

CUDA_ATOMIC_WRAPPER(Min, double) {
  if (*address <= val) {
    return *address;
  }

  unsigned long long int *const address_as_ull =            // NOLINT
      reinterpret_cast<unsigned long long int *>(address);  // NOLINT
  unsigned long long int old = *address_as_ull, assumed;    // NOLINT

  do {
    assumed = old;
    if (__longlong_as_double(assumed) <= val) {
      break;
    }

    old = atomicCAS(address_as_ull, assumed, __double_as_longlong(val));
  } while (assumed != old);

  return __longlong_as_double(old);
}

}  // namespace platform
}  // namespace paddle