/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <cuda.h>
#include <stdio.h>

#include "paddle/fluid/platform/complex128.h"
#include "paddle/fluid/platform/complex64.h"
#include "paddle/fluid/platform/float16.h"

namespace paddle {
namespace platform {

#define CUDA_ATOMIC_WRAPPER(op, T) \
  __device__ __forceinline__ T CudaAtomic##op(T *address, const T val)

#define USE_CUDA_ATOMIC(op, T) \
  CUDA_ATOMIC_WRAPPER(op, T) { return atomic##op(address, val); }
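
// For reference, USE_CUDA_ATOMIC(Add, float) expands to:
//
//   __device__ __forceinline__ float CudaAtomicAdd(float *address,
//                                                  const float val) {
//     return atomicAdd(address, val);
//   }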

// Default thread count per block (or block size).
// TODO(typhoonzero): need to benchmark against setting this value
//                    to 1024.
constexpr int PADDLE_CUDA_NUM_THREADS = 512;
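
// Illustrative usage (not part of this header), launching a hypothetical
// kernel SomeKernel over n elements with the default block size:
//
//   int grid = (n + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS;
//   SomeKernel<<<grid, PADDLE_CUDA_NUM_THREADS>>>(data, n);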

// For atomicAdd.
USE_CUDA_ATOMIC(Add, float);
USE_CUDA_ATOMIC(Add, int);
USE_CUDA_ATOMIC(Add, unsigned int);
// The CUDA API uses unsigned long long int; we cannot use uint64_t here
// because unsigned long long int is not necessarily uint64_t.
USE_CUDA_ATOMIC(Add, unsigned long long int);  // NOLINT

CUDA_ATOMIC_WRAPPER(Add, int64_t) {
  // Here, we check that long long int has the same size as int64_t.
  static_assert(sizeof(int64_t) == sizeof(long long int),  // NOLINT
                "long long should be int64");
  return CudaAtomicAdd(
      reinterpret_cast<unsigned long long int *>(address),  // NOLINT
      static_cast<unsigned long long int>(val));            // NOLINT
}

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600
USE_CUDA_ATOMIC(Add, double);
#else
CUDA_ATOMIC_WRAPPER(Add, double) {
  unsigned long long int *address_as_ull =                  // NOLINT
      reinterpret_cast<unsigned long long int *>(address);  // NOLINT
  unsigned long long int old = *address_as_ull, assumed;    // NOLINT

  do {
    assumed = old;
    old = atomicCAS(address_as_ull, assumed,
                    __double_as_longlong(val + __longlong_as_double(assumed)));

    // Note: uses integer comparison to avoid hang in case of NaN
  } while (assumed != old);

  return __longlong_as_double(old);
}
#endif
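
// Illustrative sketch (not part of this header): accumulating per-thread
// contributions into a single double with CudaAtomicAdd. On devices older
// than sm_60 this takes the atomicCAS fallback above; on sm_60+ it maps to
// the native atomicAdd for double.
//
//   __global__ void AccumulateKernel(const double *x, int n, double *sum) {
//     int i = blockIdx.x * blockDim.x + threadIdx.x;
//     if (i < n) {
//       CudaAtomicAdd(sum, x[i]);
//     }
//   }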

#ifdef PADDLE_CUDA_FP16
// NOTE(dzhwinter): CUDA does not provide atomicCAS for half.
// Instead, treat the half address as part of an aligned 32-bit word and
// perform atomicCAS on that word, updating either the high or the low
// 16 bits depending on where the value is stored.
// Since most warp threads will fail the atomicCAS, this implementation
// should be avoided under high contention: it is slower than converting
// the value to 32 bits and doing a full-width atomicCAS.

// Convert the value to float, do the add arithmetic,
// then store the result back into a uint32.
inline static __device__ uint32_t add_to_low_half(uint32_t val, float x) {
  float16 low_half;
  // The float16 stored in the lower 16 bits.
  low_half.x = static_cast<uint16_t>(val & 0xFFFFu);
  low_half = static_cast<float16>(static_cast<float>(low_half) + x);
  return (val & 0xFFFF0000u) | low_half.x;
}

inline static __device__ uint32_t add_to_high_half(uint32_t val, float x) {
  float16 high_half;
  // The float16 stored in the higher 16 bits.
  high_half.x = static_cast<uint16_t>(val >> 16);
  high_half = static_cast<float16>(static_cast<float>(high_half) + x);
  return (val & 0xFFFFu) | (static_cast<uint32_t>(high_half.x) << 16);
}

CUDA_ATOMIC_WRAPPER(Add, float16) {
  // The packed float16 value may live in either the lower or the higher
  // 16 bits of the aligned 32-bit word computed below.
  uint32_t *address_as_ui = reinterpret_cast<uint32_t *>(
      reinterpret_cast<char *>(address) -
      (reinterpret_cast<uintptr_t>(address) & 0x02));
  float val_f = static_cast<float>(val);
  uint32_t old = *address_as_ui;
  uint32_t sum;
  uint32_t newval;
  uint32_t assumed;
  if (((uintptr_t)address & 0x02) == 0) {
    // The float16 value sits in the lower 16 bits of the word.
    do {
      assumed = old;
      old = atomicCAS(address_as_ui, assumed, add_to_low_half(assumed, val_f));
    } while (old != assumed);
    float16 ret;
    ret.x = old & 0xFFFFu;
    return ret;
  } else {
    // The float16 value sits in the higher 16 bits of the word.
    do {
      assumed = old;
      old = atomicCAS(address_as_ui, assumed, add_to_high_half(assumed, val_f));
    } while (old != assumed);
    float16 ret;
    ret.x = old >> 16;
    return ret;
  }
}
#endif
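
// Illustrative sketch (not part of this header): float16 accumulation via
// the CAS-based wrapper above (available when PADDLE_CUDA_FP16 is defined).
// Because contending threads must retry the CAS, consider accumulating in
// float and casting once if contention is expected to be high.
//
//   __global__ void Fp16SumKernel(const float16 *x, int n, float16 *sum) {
//     int i = blockIdx.x * blockDim.x + threadIdx.x;
//     if (i < n) {
//       CudaAtomicAdd(sum, x[i]);
//     }
//   }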

CUDA_ATOMIC_WRAPPER(Add, complex64) {
  float *real = reinterpret_cast<float *>(address);
  float *imag = real + 1;
  return complex64(CudaAtomicAdd(real, val.real),
                   CudaAtomicAdd(imag, val.imag));
}

CUDA_ATOMIC_WRAPPER(Add, complex128) {
  double *real = reinterpret_cast<double *>(address);
  double *imag = real + 1;
  return complex128(CudaAtomicAdd(real, val.real),
                    CudaAtomicAdd(imag, val.imag));
}
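
// Both wrappers above rely on complex64/complex128 storing the real part
// immediately followed by the imaginary part, so each component can be
// updated with its own atomic add. Note that the two component updates are
// atomic individually, not as a single complex-valued transaction.
// Illustrative sketch (not part of this header):
//
//   __global__ void ComplexSumKernel(const complex64 *x, int n,
//                                    complex64 *sum) {
//     int i = blockIdx.x * blockDim.x + threadIdx.x;
//     if (i < n) {
//       CudaAtomicAdd(sum, x[i]);
//     }
//   }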

// For atomicMax
USE_CUDA_ATOMIC(Max, int);
USE_CUDA_ATOMIC(Max, unsigned int);
// The CUDA API uses unsigned long long int; we cannot use uint64_t here
// because unsigned long long int is not necessarily uint64_t.
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 350
USE_CUDA_ATOMIC(Max, unsigned long long int);  // NOLINT
#else
CUDA_ATOMIC_WRAPPER(Max, unsigned long long int) {  // NOLINT
  if (*address >= val) {
    return *address;
  }

  unsigned long long int old = *address, assumed;  // NOLINT

  do {
    assumed = old;
    if (assumed >= val) {
      break;
    }

    old = atomicCAS(address, assumed, val);
  } while (assumed != old);

  return old;
}
#endif

CUDA_ATOMIC_WRAPPER(Max, int64_t) {
  // Here, we check that long long int has the same size as int64_t.
  static_assert(sizeof(int64_t) == sizeof(long long int),  // NOLINT
                "long long should be int64");
  return CudaAtomicMax(
      reinterpret_cast<unsigned long long int *>(address),  // NOLINT
      static_cast<unsigned long long int>(val));            // NOLINT
}

CUDA_ATOMIC_WRAPPER(Max, float) {
  if (*address >= val) {
    return *address;
  }

  int *const address_as_i = reinterpret_cast<int *>(address);
  int old = *address_as_i, assumed;

  do {
    assumed = old;
    if (__int_as_float(assumed) >= val) {
      break;
    }

    old = atomicCAS(address_as_i, assumed, __float_as_int(val));
  } while (assumed != old);

  return __int_as_float(old);
}

CUDA_ATOMIC_WRAPPER(Max, double) {
  if (*address >= val) {
    return *address;
  }

  unsigned long long int *const address_as_ull =            // NOLINT
      reinterpret_cast<unsigned long long int *>(address);  // NOLINT
  unsigned long long int old = *address_as_ull, assumed;    // NOLINT

  do {
    assumed = old;
    if (__longlong_as_double(assumed) >= val) {
      break;
    }

    old = atomicCAS(address_as_ull, assumed, __double_as_longlong(val));
  } while (assumed != old);

  return __longlong_as_double(old);
}
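
// Illustrative sketch (not part of this header): a grid-wide max reduction
// with CudaAtomicMax. *max_val is assumed to be initialized before launch
// to a value no greater than any input (e.g. -FLT_MAX).
//
//   __global__ void MaxKernel(const float *x, int n, float *max_val) {
//     int i = blockIdx.x * blockDim.x + threadIdx.x;
//     if (i < n) {
//       CudaAtomicMax(max_val, x[i]);
//     }
//   }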

// For atomicMin
USE_CUDA_ATOMIC(Min, int);
USE_CUDA_ATOMIC(Min, unsigned int);
// The CUDA API uses unsigned long long int; we cannot use uint64_t here
// because unsigned long long int is not necessarily uint64_t.
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 350
USE_CUDA_ATOMIC(Min, unsigned long long int);  // NOLINT
#else
CUDA_ATOMIC_WRAPPER(Min, unsigned long long int) {  // NOLINT
  if (*address <= val) {
    return *address;
  }

  unsigned long long int old = *address, assumed;  // NOLINT

  do {
    assumed = old;
    if (assumed <= val) {
      break;
    }

    old = atomicCAS(address, assumed, val);
  } while (assumed != old);

  return old;
}
#endif

CUDA_ATOMIC_WRAPPER(Min, int64_t) {
  // Here, we check that long long int has the same size as int64_t.
  static_assert(sizeof(int64_t) == sizeof(long long int),  // NOLINT
                "long long should be int64");
  return CudaAtomicMin(
      reinterpret_cast<unsigned long long int *>(address),  // NOLINT
      static_cast<unsigned long long int>(val));            // NOLINT
}

CUDA_ATOMIC_WRAPPER(Min, float) {
  if (*address <= val) {
    return *address;
  }

  int *const address_as_i = reinterpret_cast<int *>(address);
  int old = *address_as_i, assumed;

  do {
    assumed = old;
    if (__int_as_float(assumed) <= val) {
      break;
    }

    old = atomicCAS(address_as_i, assumed, __float_as_int(val));
  } while (assumed != old);

  return __int_as_float(old);
}

CUDA_ATOMIC_WRAPPER(Min, double) {
  if (*address <= val) {
    return *address;
  }

  unsigned long long int *const address_as_ull =            // NOLINT
      reinterpret_cast<unsigned long long int *>(address);  // NOLINT
  unsigned long long int old = *address_as_ull, assumed;    // NOLINT

  do {
    assumed = old;
    if (__longlong_as_double(assumed) <= val) {
      break;
    }

    old = atomicCAS(address_as_ull, assumed, __double_as_longlong(val));
  } while (assumed != old);

  return __longlong_as_double(old);
}

}  // namespace platform
}  // namespace paddle