From ef96ffb6bc9930cc48b37c29e688f07c0cab5a3a Mon Sep 17 00:00:00 2001 From: Li Min <11663212+limin2021@users.noreply.github.com> Date: Fri, 25 Feb 2022 10:31:36 +0800 Subject: [PATCH] [Fix bug] fix fp16 atomicAdd compiler error on different cuda_arch. (#39886) * Fix compile error on cuda_arch less than 700. --- paddle/fluid/platform/device/gpu/gpu_primitives.h | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/paddle/fluid/platform/device/gpu/gpu_primitives.h b/paddle/fluid/platform/device/gpu/gpu_primitives.h index 8616e969f69..8aec8e840f3 100644 --- a/paddle/fluid/platform/device/gpu/gpu_primitives.h +++ b/paddle/fluid/platform/device/gpu/gpu_primitives.h @@ -210,6 +210,12 @@ template ::value>::type * = nullptr> __device__ __forceinline__ void VectorizedAtomicAddPerBlock( const int64_t len, int tid, int threads_per_block, const T *in, T *out) { +#if ((CUDA_VERSION < 10000) || \ + (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700))) + for (int i = tid; i < len; i += threads_per_block) { + CudaAtomicAdd(&out[i], in[i]); + } +#else int i = 0; int loops = len / 2 * 2; @@ -233,6 +239,7 @@ __device__ __forceinline__ void VectorizedAtomicAddPerBlock( fastAtomicAdd(out, i, len, in[i]); } } +#endif } #endif #endif -- GitLab