From a5ca2672ba240ce475759c0b30a90af1ee01f6fa Mon Sep 17 00:00:00 2001 From: chenxujun Date: Wed, 29 Mar 2023 17:06:25 +0800 Subject: [PATCH] Fix the type conflicts against the openblas (#52187) --- paddle/phi/backends/gpu/gpu_primitives.h | 97 +++++++++++++----------- 1 file changed, 54 insertions(+), 43 deletions(-) diff --git a/paddle/phi/backends/gpu/gpu_primitives.h b/paddle/phi/backends/gpu/gpu_primitives.h index 67b34aa289f..252ed90e441 100644 --- a/paddle/phi/backends/gpu/gpu_primitives.h +++ b/paddle/phi/backends/gpu/gpu_primitives.h @@ -28,9 +28,6 @@ limitations under the License. */ template using complex = phi::dtype::complex; -using float16 = phi::dtype::float16; -using bfloat16 = phi::dtype::bfloat16; - namespace phi { #define CUDA_ATOMIC_WRAPPER(op, T) \ @@ -94,36 +91,39 @@ CUDA_ATOMIC_WRAPPER(Add, double) { // convert the value into float and do the add arithmetic. // then store the result into a uint32. inline static __device__ uint32_t add_to_low_half(uint32_t val, float x) { - float16 low_half; + phi::dtype::float16 low_half; // the float16 in lower 16bits low_half.x = static_cast(val & 0xFFFFu); - low_half = static_cast(static_cast(low_half) + x); + low_half = static_cast(static_cast(low_half) + x); return (val & 0xFFFF0000u) | low_half.x; } inline static __device__ uint32_t add_to_high_half(uint32_t val, float x) { - float16 high_half; + phi::dtype::float16 high_half; // the float16 in higher 16bits high_half.x = static_cast(val >> 16); - high_half = static_cast(static_cast(high_half) + x); + high_half = + static_cast(static_cast(high_half) + x); return (val & 0xFFFFu) | (static_cast(high_half.x) << 16); } #if CUDA_VERSION >= 10000 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 -static __device__ __forceinline__ float16 CUDAFP16ToPDFP16(__half x) { - return *reinterpret_cast(&x); +static __device__ __forceinline__ phi::dtype::float16 CUDAFP16ToPDFP16( + __half x) { + return *reinterpret_cast(&x); } -static __device__ __forceinline__ __half PDFP16ToCUDAFP16(float16 x) { +static __device__ __forceinline__ __half +PDFP16ToCUDAFP16(phi::dtype::float16 x) { return *reinterpret_cast<__half *>(&x); } -CUDA_ATOMIC_WRAPPER(Add, float16) { +CUDA_ATOMIC_WRAPPER(Add, phi::dtype::float16) { return CUDAFP16ToPDFP16( atomicAdd(reinterpret_cast<__half *>(address), PDFP16ToCUDAFP16(val))); } #else -CUDA_ATOMIC_WRAPPER(Add, float16) { +CUDA_ATOMIC_WRAPPER(Add, phi::dtype::float16) { // concrete packed float16 value may exsits in lower or higher 16bits // of the 32bits address. uint32_t *address_as_ui = reinterpret_cast( @@ -140,7 +140,7 @@ CUDA_ATOMIC_WRAPPER(Add, float16) { assumed = old; old = atomicCAS(address_as_ui, assumed, add_to_low_half(assumed, val_f)); } while (old != assumed); - float16 ret; + phi::dtype::float16 ret; ret.x = old & 0xFFFFu; return ret; } else { @@ -149,7 +149,7 @@ CUDA_ATOMIC_WRAPPER(Add, float16) { assumed = old; old = atomicCAS(address_as_ui, assumed, add_to_high_half(assumed, val_f)); } while (old != assumed); - float16 ret; + phi::dtype::float16 ret; ret.x = old >> 16; return ret; } @@ -168,14 +168,17 @@ struct VecAtomicAddHelper : VecAtomicAddHelperBase {}; #if CUDA_VERSION >= 10000 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 template <> -struct VecAtomicAddHelper - : VecAtomicAddHelperBase {}; +struct VecAtomicAddHelper + : VecAtomicAddHelperBase {}; #endif #if CUDA_VERSION >= 11000 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 template <> -struct VecAtomicAddHelper - : VecAtomicAddHelperBase {}; +struct VecAtomicAddHelper + : VecAtomicAddHelperBase {}; #endif // The performance of "atomicAdd(half* )" is bad, but for "atomicAdd(half2* )" @@ -225,36 +228,40 @@ __device__ __forceinline__ void fastAtomicAdd(T *arr, // NOTE(zhangbo): cuda do not have atomicCAS for __nv_bfloat16. inline static __device__ uint32_t bf16_add_to_low_half(uint32_t val, float x) { - bfloat16 low_half; + phi::dtype::bfloat16 low_half; // the bfloat16 in lower 16bits low_half.x = static_cast(val & 0xFFFFu); - low_half = static_cast(static_cast(low_half) + x); + low_half = + static_cast(static_cast(low_half) + x); return (val & 0xFFFF0000u) | low_half.x; } inline static __device__ uint32_t bf16_add_to_high_half(uint32_t val, float x) { - bfloat16 high_half; + phi::dtype::bfloat16 high_half; // the bfloat16 in higher 16bits high_half.x = static_cast(val >> 16); - high_half = static_cast(static_cast(high_half) + x); + high_half = + static_cast(static_cast(high_half) + x); return (val & 0xFFFFu) | (static_cast(high_half.x) << 16); } #if CUDA_VERSION >= 11000 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 -static __device__ __forceinline__ bfloat16 CUDABF16ToPDBF16(__nv_bfloat16 x) { - return *reinterpret_cast(&x); +static __device__ __forceinline__ phi::dtype::bfloat16 CUDABF16ToPDBF16( + __nv_bfloat16 x) { + return *reinterpret_cast(&x); } -static __device__ __forceinline__ __nv_bfloat16 PDBF16ToCUDABF16(bfloat16 x) { +static __device__ __forceinline__ __nv_bfloat16 +PDBF16ToCUDABF16(phi::dtype::bfloat16 x) { return *reinterpret_cast<__nv_bfloat16 *>(&x); } -CUDA_ATOMIC_WRAPPER(Add, bfloat16) { +CUDA_ATOMIC_WRAPPER(Add, phi::dtype::bfloat16) { return CUDABF16ToPDBF16(atomicAdd(reinterpret_cast<__nv_bfloat16 *>(address), PDBF16ToCUDABF16(val))); } #else -CUDA_ATOMIC_WRAPPER(Add, bfloat16) { +CUDA_ATOMIC_WRAPPER(Add, phi::dtype::bfloat16) { // concrete packed bfloat16 value may exsits in lower or higher 16bits // of the 32bits address. uint32_t *address_as_ui = reinterpret_cast( @@ -272,7 +279,7 @@ CUDA_ATOMIC_WRAPPER(Add, bfloat16) { old = atomicCAS( address_as_ui, assumed, bf16_add_to_low_half(assumed, val_f)); } while (old != assumed); - bfloat16 ret; + phi::dtype::bfloat16 ret; ret.x = old & 0xFFFFu; return ret; } else { @@ -282,7 +289,7 @@ CUDA_ATOMIC_WRAPPER(Add, bfloat16) { old = atomicCAS( address_as_ui, assumed, bf16_add_to_high_half(assumed, val_f)); } while (old != assumed); - bfloat16 ret; + phi::dtype::bfloat16 ret; ret.x = old >> 16; return ret; } @@ -389,22 +396,24 @@ CUDA_ATOMIC_WRAPPER(Max, double) { #ifdef PADDLE_CUDA_FP16 inline static __device__ uint32_t max_to_low_half(uint32_t val, float x) { - float16 low_half; + phi::dtype::float16 low_half; // The float16 in lower 16bits low_half.x = static_cast(val & 0xFFFFu); - low_half = static_cast(max(static_cast(low_half), x)); + low_half = + static_cast(max(static_cast(low_half), x)); return (val & 0xFFFF0000u) | low_half.x; } inline static __device__ uint32_t max_to_high_half(uint32_t val, float x) { - float16 high_half; + phi::dtype::float16 high_half; // The float16 in higher 16bits high_half.x = static_cast(val >> 16); - high_half = static_cast(max(static_cast(high_half), x)); + high_half = + static_cast(max(static_cast(high_half), x)); return (val & 0xFFFFu) | (static_cast(high_half.x) << 16); } -CUDA_ATOMIC_WRAPPER(Max, float16) { +CUDA_ATOMIC_WRAPPER(Max, phi::dtype::float16) { if (*address >= val) { return *address; } @@ -420,7 +429,7 @@ CUDA_ATOMIC_WRAPPER(Max, float16) { assumed = old; old = atomicCAS(address_as_ui, assumed, max_to_low_half(assumed, val_f)); } while (old != assumed); - float16 ret; + phi::dtype::float16 ret; ret.x = old & 0xFFFFu; return ret; } else { @@ -429,7 +438,7 @@ CUDA_ATOMIC_WRAPPER(Max, float16) { assumed = old; old = atomicCAS(address_as_ui, assumed, max_to_high_half(assumed, val_f)); } while (old != assumed); - float16 ret; + phi::dtype::float16 ret; ret.x = old >> 16; return ret; } @@ -522,22 +531,24 @@ CUDA_ATOMIC_WRAPPER(Min, double) { #ifdef PADDLE_CUDA_FP16 inline static __device__ uint32_t min_to_low_half(uint32_t val, float x) { - float16 low_half; + phi::dtype::float16 low_half; // The float16 in lower 16bits low_half.x = static_cast(val & 0xFFFFu); - low_half = static_cast(min(static_cast(low_half), x)); + low_half = + static_cast(min(static_cast(low_half), x)); return (val & 0xFFFF0000u) | low_half.x; } inline static __device__ uint32_t min_to_high_half(uint32_t val, float x) { - float16 high_half; + phi::dtype::float16 high_half; // The float16 in higher 16bits high_half.x = static_cast(val >> 16); - high_half = static_cast(min(static_cast(high_half), x)); + high_half = + static_cast(min(static_cast(high_half), x)); return (val & 0xFFFFu) | (static_cast(high_half.x) << 16); } -CUDA_ATOMIC_WRAPPER(Min, float16) { +CUDA_ATOMIC_WRAPPER(Min, phi::dtype::float16) { if (*address <= val) { return *address; } @@ -553,7 +564,7 @@ CUDA_ATOMIC_WRAPPER(Min, float16) { assumed = old; old = atomicCAS(address_as_ui, assumed, min_to_low_half(assumed, val_f)); } while (old != assumed); - float16 ret; + phi::dtype::float16 ret; ret.x = old & 0xFFFFu; return ret; } else { @@ -562,7 +573,7 @@ CUDA_ATOMIC_WRAPPER(Min, float16) { assumed = old; old = atomicCAS(address_as_ui, assumed, min_to_high_half(assumed, val_f)); } while (old != assumed); - float16 ret; + phi::dtype::float16 ret; ret.x = old >> 16; return ret; } -- GitLab