diff --git a/paddle/fluid/platform/device/gpu/gpu_primitives.h b/paddle/fluid/platform/device/gpu/gpu_primitives.h
index d443e78ed874f3783c747d45bef60c78b4b31f07..3e070da546b2ae85c40bb0e9cae05cc30d6d22c1 100644
--- a/paddle/fluid/platform/device/gpu/gpu_primitives.h
+++ b/paddle/fluid/platform/device/gpu/gpu_primitives.h
@@ -101,6 +101,30 @@ inline static __device__ uint32_t add_to_high_half(uint32_t val, float x) {
   return (val & 0xFFFFu) | (static_cast<uint32_t>(high_half.x) << 16);
 }
 
+#if CUDA_VERSION >= 10000 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
+// Fast path for CUDA >= 10.0 device code targeting SM70+: forward float16
+// atomic adds to the atomicAdd(__half *, __half) overload.
+// NOTE(review): the casts below assume float16 and __half share the same size
+// and bit layout - confirm against the float16 definition.
+
+// Reinterprets the bits of a CUDA __half as a float16 (no numeric conversion).
+static __device__ __forceinline__ float16 CUDAFP16ToPDFP16(__half x) {
+  return *reinterpret_cast<float16 *>(&x);
+}
+
+// Reinterprets the bits of a float16 as a CUDA __half (no numeric conversion).
+static __device__ __forceinline__ __half PDFP16ToCUDAFP16(float16 x) {
+  return *reinterpret_cast<__half *>(&x);
+}
+
+// Adds val to *address with atomicAdd and returns atomicAdd's result (the
+// previous value, per the CUDA atomicAdd contract) converted back to float16.
+CUDA_ATOMIC_WRAPPER(Add, float16) {
+  return CUDAFP16ToPDFP16(
+      atomicAdd(reinterpret_cast<__half *>(address), PDFP16ToCUDAFP16(val)));
+}
+#else
+// Fallback used when the condition above does not hold.
 CUDA_ATOMIC_WRAPPER(Add, float16) {
   // concrete packed float16 value may exsits in lower or higher 16bits
   // of the 32bits address.
@@ -133,6 +157,7 @@ CUDA_ATOMIC_WRAPPER(Add, float16) {
   }
 }
 #endif
+#endif
 
 CUDA_ATOMIC_WRAPPER(Add, complex<float>) {
   float *real = reinterpret_cast<float *>(address);