/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <cuda.h>

#include <cstdint>

namespace paddle {
namespace platform {

// CUDA_ATOMIC_WRAPPER(OP, TYPE) expands to the signature of a device-only,
// force-inlined function named CudaAtomic<OP> that takes `TYPE* address` and
// `const TYPE val` and returns TYPE. The macro must be followed by the
// function body, which refers to the parameters as `address` and `val`.
#define CUDA_ATOMIC_WRAPPER(OP, TYPE)             \
  __device__ __forceinline__ TYPE CudaAtomic##OP( \
      TYPE* address, const TYPE val)

// USE_CUDA_ATOMIC(OP, TYPE) defines CudaAtomic<OP> for TYPE as a direct
// pass-through to the CUDA built-in atomic<OP>(address, val).
#define USE_CUDA_ATOMIC(OP, TYPE)  \
  CUDA_ATOMIC_WRAPPER(OP, TYPE) {  \
    return atomic##OP(address, val); \
  }

// Default thread count per block (i.e. the default block size).
// TODO(typhoonzero): need to benchmark against setting this value
//                    to 1024.
constexpr int PADDLE_CUDA_NUM_THREADS = 512;

// A block size should be a whole number of 32-thread warps (otherwise the
// last warp of each block is only partially populated), and CUDA limits a
// block to at most 1024 threads.
static_assert(PADDLE_CUDA_NUM_THREADS > 0 &&
                  PADDLE_CUDA_NUM_THREADS % 32 == 0 &&
                  PADDLE_CUDA_NUM_THREADS <= 1024,
              "PADDLE_CUDA_NUM_THREADS must be a positive multiple of 32 "
              "and at most 1024");

// For atomicAdd: expose CUDA's native atomicAdd overloads under the common
// CudaAtomicAdd name.
USE_CUDA_ATOMIC(Add, float);
USE_CUDA_ATOMIC(Add, int);
USE_CUDA_ATOMIC(Add, unsigned int);
// The CUDA API declares its 64-bit atomicAdd on `unsigned long long int`, so
// that exact type must be used here instead of uint64_t: the two are not
// guaranteed to be the same type (uint64_t may be `unsigned long`).
USE_CUDA_ATOMIC(Add, unsigned long long int);  // NOLINT

// CudaAtomicAdd for int64_t. CUDA provides no signed 64-bit atomicAdd
// overload, so the operation is forwarded to the `unsigned long long int`
// overload. Two's-complement addition yields the same bit pattern for signed
// and unsigned operands, so adding negative values works as expected.
// Returns the value that was stored at `address` before the addition.
CUDA_ATOMIC_WRAPPER(Add, int64_t) {
  // The reinterpretation below is only valid when int64_t and long long int
  // have the same width (int64_t may be `long` rather than `long long`).
  static_assert(sizeof(int64_t) == sizeof(long long int),  // NOLINT
                "long long should be int64");
  return static_cast<int64_t>(CudaAtomicAdd(
      reinterpret_cast<unsigned long long int*>(address),  // NOLINT
      static_cast<unsigned long long int>(val)));          // NOLINT
}

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600
// Compute capability 6.0+ has a native double-precision atomicAdd.
USE_CUDA_ATOMIC(Add, double);
#else
// Fallback used when __CUDA_ARCH__ is undefined or below 600: emulate a
// double-precision atomicAdd with an atomicCAS retry loop on the value's
// 64-bit bit pattern. Returns the value that was stored at `address` before
// the addition.
CUDA_ATOMIC_WRAPPER(Add, double) {
  unsigned long long int* address_as_ull =                 // NOLINT
      reinterpret_cast<unsigned long long int*>(address);  // NOLINT
  unsigned long long int old = *address_as_ull;            // NOLINT
  unsigned long long int assumed;                          // NOLINT

  do {
    assumed = old;
    // Try to store (observed value + val). atomicCAS returns the value it
    // actually found; if another thread changed it in the meantime, the
    // returned value differs from `assumed` and the loop retries.
    old = atomicCAS(address_as_ull, assumed,
                    __double_as_longlong(val + __longlong_as_double(assumed)));

    // Note: uses integer comparison to avoid hang in case of NaN
    // (NaN != NaN as doubles, which would make a floating-point compare
    // retry forever).
  } while (assumed != old);

  return __longlong_as_double(old);
}
#endif

}  // namespace platform
}  // namespace paddle