// cuda_helper.h
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <cuda.h>

namespace paddle {
namespace platform {

// Expands to the signature of a device-side wrapper named CudaAtomic<OpName>
// operating on element type `Type`. The generated parameters are always named
// `address` (the memory location) and `val` (the operand); the function body
// written after this macro refers to those two names.
#define CUDA_ATOMIC_WRAPPER(OpName, Type)              \
  __device__ __forceinline__ Type CudaAtomic##OpName( \
      Type* address, const Type val)

// Defines CudaAtomic<OpName> for `Type` as a plain forward to CUDA's built-in
// atomic<OpName> (for example, USE_CUDA_ATOMIC(Add, float) forwards to
// atomicAdd).
#define USE_CUDA_ATOMIC(OpName, Type)     \
  CUDA_ATOMIC_WRAPPER(OpName, Type) {     \
    return atomic##OpName(address, val);  \
  }

// Default number of threads per block, i.e. the default block size.
// TODO(typhoonzero): need to benchmark against setting this value
//                    to 1024.
constexpr int PADDLE_CUDA_NUM_THREADS{512};

// CudaAtomicAdd overloads that forward straight to CUDA's built-in atomicAdd.
USE_CUDA_ATOMIC(Add, float);
USE_CUDA_ATOMIC(Add, int);
USE_CUDA_ATOMIC(Add, unsigned int);
USE_CUDA_ATOMIC(Add, unsigned long long int);

// int64_t is routed through the unsigned long long int overload above. The
// static_assert guarantees the two types have the same width, and
// two's-complement addition yields the same bit pattern whether the operands
// are read as signed or unsigned.
CUDA_ATOMIC_WRAPPER(Add, int64_t) {
  static_assert(sizeof(int64_t) == sizeof(long long int),
                "long long should be int64");
  using ull_t = unsigned long long int;
  ull_t* const as_unsigned = reinterpret_cast<ull_t*>(address);
  const ull_t previous = CudaAtomicAdd(as_unsigned, static_cast<ull_t>(val));
  return static_cast<int64_t>(previous);
}

// Double-precision CudaAtomicAdd.
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600
// Compute capability 6.0+ provides a native atomicAdd(double*).
USE_CUDA_ATOMIC(Add, double);
#else
// Fallback used for device code below SM 6.0 and for the host compilation
// pass, where __CUDA_ARCH__ is not defined. The addition is emulated with a
// compare-and-swap loop on the 64-bit bit pattern of the double. Returns the
// value that was stored at `address` immediately before this thread's add.
CUDA_ATOMIC_WRAPPER(Add, double) {
  unsigned long long int* const bits =
      reinterpret_cast<unsigned long long int*>(address);
  unsigned long long int seen = *bits;
  unsigned long long int expected;

  do {
    expected = seen;
    // Retry until no other thread modified the word between our read and our
    // CAS.
    seen = atomicCAS(bits, expected,
                     __double_as_longlong(val + __longlong_as_double(expected)));
    // The loop condition compares integer bit patterns rather than doubles,
    // so a NaN (which never compares equal to itself) cannot cause a hang.
  } while (seen != expected);

  return __longlong_as_double(seen);
}
#endif

// Warp-shuffle compatibility layer.
//
// __shfl_down has been deprecated as of CUDA 9.0 in favour of
// __shfl_down_sync, which takes an explicit participant mask. On older
// toolkits, a __shfl_down_sync is supplied that ignores the mask and forwards
// to the legacy __shfl_down, so callers can use the CUDA 9 spelling everywhere.
#if CUDA_VERSION < 9000
template <typename T>
__forceinline__ __device__ T __shfl_down_sync(unsigned, T val, int delta) {
  return __shfl_down(val, delta);
}
// Pre-9.0 shuffles take no mask, so `mask` is simply zeroed and `predicate`
// is not evaluated. There is deliberately no trailing semicolon. This makes
// `CREATE_SHFL_MASK(m, p);` expand to exactly one statement, matching the
// CUDA 9+ definition and staying safe inside unbraced if/else.
#define CREATE_SHFL_MASK(mask, predicate) mask = 0u
#else
#define FULL_WARP_MASK 0xFFFFFFFF
// Sets `mask` to the lanes of the calling warp for which `predicate` is true.
// __ballot_sync is issued with the full-warp mask, so every lane of the warp
// is expected to execute this (see the CUDA warp-vote documentation).
#define CREATE_SHFL_MASK(mask, predicate) \
  mask = __ballot_sync(FULL_WARP_MASK, (predicate))
#endif

// Block-wide sum reduction of `val`.
//
// Contains __syncthreads() and full-warp ballots, so it must be called by
// every thread of the block. The reduction runs in two stages:
//   1. Each warp reduces its lanes with __shfl_down_sync, leaving the warp's
//      partial sum in its lane 0.
//   2. Lane 0 of each warp stores that partial in shared memory. The first
//      warp then loads the (up to 32) partials and reduces them with a
//      second round of shuffles.
// Only the thread with tid == 0 is guaranteed to hold the complete sum on
// return; other threads return intermediate partial sums.
//
// Parameters:
//   val - this thread's contribution.
//   tid - the thread's index within the block. NOTE(review): the tid / 32
//         and tid % 32 arithmetic assumes this is the linear in-block index
//         (e.g. threadIdx.x for 1-D blocks); confirm at call sites.
//   len - threads with tid < len form the first-stage shuffle mask (CUDA 9+).
//         NOTE(review): lanes outside that mask still execute the shuffle,
//         which CUDA leaves undefined. Callers presumably pass the additive
//         identity for tid >= len (pre-9.0 shuffles ignore the mask entirely);
//         confirm.
// Supports at most 32 warps (1024 threads), since shm holds one slot per warp.
template <typename T>
__device__ T reduceSum(T val, int tid, int len) {
  // Warp size is fixed at 32. It is named kWarpSize so that it does not
  // shadow CUDA's built-in `warpSize`.
  constexpr int kWarpSize = 32;
  __shared__ T shm[kWarpSize];

  unsigned mask = 0u;
  CREATE_SHFL_MASK(mask, tid < len);

  // Stage 1: intra-warp tree reduction; lane 0 ends with the warp's sum.
  for (int offset = kWarpSize / 2; offset > 0; offset /= 2)
    val += __shfl_down_sync(mask, val, offset);

  // Zero every per-warp slot, so slots for warps that do not exist in this
  // block contribute nothing in stage 2.
  if (tid < kWarpSize) shm[tid] = 0;

  __syncthreads();

  if (tid % kWarpSize == 0) {
    shm[tid / kWarpSize] = val;
  }

  // Publish every warp's partial before the first warp reads them. Without
  // this barrier, warp 0 could read shm[i] before warp i has written it.
  __syncthreads();

  CREATE_SHFL_MASK(mask, tid < kWarpSize);

  // Stage 2: the first warp reduces the per-warp partials.
  if (tid < kWarpSize) {
    val = shm[tid];
    for (int offset = kWarpSize / 2; offset > 0; offset /= 2)
      val += __shfl_down_sync(mask, val, offset);
  }

  return val;
}

}  // namespace platform
}  // namespace paddle