未验证 提交 a5ca2672 编写于 作者: C chenxujun 提交者: GitHub

Fix type conflicts with OpenBLAS by fully qualifying phi::dtype::float16/bfloat16 (#52187)

上级 ad01eccd
...@@ -28,9 +28,6 @@ limitations under the License. */ ...@@ -28,9 +28,6 @@ limitations under the License. */
template <typename T> template <typename T>
using complex = phi::dtype::complex<T>; using complex = phi::dtype::complex<T>;
using float16 = phi::dtype::float16;
using bfloat16 = phi::dtype::bfloat16;
namespace phi { namespace phi {
#define CUDA_ATOMIC_WRAPPER(op, T) \ #define CUDA_ATOMIC_WRAPPER(op, T) \
...@@ -94,36 +91,39 @@ CUDA_ATOMIC_WRAPPER(Add, double) { ...@@ -94,36 +91,39 @@ CUDA_ATOMIC_WRAPPER(Add, double) {
// convert the value into float and do the add arithmetic. // convert the value into float and do the add arithmetic.
// then store the result into a uint32. // then store the result into a uint32.
// Pack helper for the CAS-based float16 atomicAdd fallback: interpret the
// low 16 bits of `val` as a phi::dtype::float16, add the float `x` to it in
// float precision, and return the 32-bit word with the low half replaced.
inline static __device__ uint32_t add_to_low_half(uint32_t val, float x) {
  phi::dtype::float16 low_half;
  // The float16 lives in the lower 16 bits of the 32-bit word.
  low_half.x = static_cast<uint16_t>(val & 0xFFFFu);
  low_half = static_cast<phi::dtype::float16>(static_cast<float>(low_half) + x);
  return (val & 0xFFFF0000u) | low_half.x;
}
// Pack helper for the CAS-based float16 atomicAdd fallback: interpret the
// high 16 bits of `val` as a phi::dtype::float16, add the float `x` to it in
// float precision, and return the 32-bit word with the high half replaced.
inline static __device__ uint32_t add_to_high_half(uint32_t val, float x) {
  phi::dtype::float16 high_half;
  // The float16 lives in the upper 16 bits of the 32-bit word.
  high_half.x = static_cast<uint16_t>(val >> 16);
  high_half =
      static_cast<phi::dtype::float16>(static_cast<float>(high_half) + x);
  return (val & 0xFFFFu) | (static_cast<uint32_t>(high_half.x) << 16);
}
#if CUDA_VERSION >= 10000 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 #if CUDA_VERSION >= 10000 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
// Bit-reinterpret a CUDA __half as Paddle's phi::dtype::float16; the two
// types share the same 16-bit IEEE half-precision layout.
static __device__ __forceinline__ phi::dtype::float16 CUDAFP16ToPDFP16(
    __half x) {
  return *reinterpret_cast<phi::dtype::float16 *>(&x);
}
// Bit-reinterpret Paddle's phi::dtype::float16 as a CUDA __half (inverse of
// CUDAFP16ToPDFP16).
static __device__ __forceinline__ __half PDFP16ToCUDAFP16(
    phi::dtype::float16 x) {
  return *reinterpret_cast<__half *>(&x);
}
// Fast path guarded by CUDA >= 10.0 and sm_70+: use the hardware __half
// atomicAdd directly instead of the 32-bit CAS emulation.
CUDA_ATOMIC_WRAPPER(Add, phi::dtype::float16) {
  return CUDAFP16ToPDFP16(
      atomicAdd(reinterpret_cast<__half *>(address), PDFP16ToCUDAFP16(val)));
}
#else #else
CUDA_ATOMIC_WRAPPER(Add, float16) { CUDA_ATOMIC_WRAPPER(Add, phi::dtype::float16) {
// concrete packed float16 value may exsits in lower or higher 16bits // concrete packed float16 value may exsits in lower or higher 16bits
// of the 32bits address. // of the 32bits address.
uint32_t *address_as_ui = reinterpret_cast<uint32_t *>( uint32_t *address_as_ui = reinterpret_cast<uint32_t *>(
...@@ -140,7 +140,7 @@ CUDA_ATOMIC_WRAPPER(Add, float16) { ...@@ -140,7 +140,7 @@ CUDA_ATOMIC_WRAPPER(Add, float16) {
assumed = old; assumed = old;
old = atomicCAS(address_as_ui, assumed, add_to_low_half(assumed, val_f)); old = atomicCAS(address_as_ui, assumed, add_to_low_half(assumed, val_f));
} while (old != assumed); } while (old != assumed);
float16 ret; phi::dtype::float16 ret;
ret.x = old & 0xFFFFu; ret.x = old & 0xFFFFu;
return ret; return ret;
} else { } else {
...@@ -149,7 +149,7 @@ CUDA_ATOMIC_WRAPPER(Add, float16) { ...@@ -149,7 +149,7 @@ CUDA_ATOMIC_WRAPPER(Add, float16) {
assumed = old; assumed = old;
old = atomicCAS(address_as_ui, assumed, add_to_high_half(assumed, val_f)); old = atomicCAS(address_as_ui, assumed, add_to_high_half(assumed, val_f));
} while (old != assumed); } while (old != assumed);
float16 ret; phi::dtype::float16 ret;
ret.x = old >> 16; ret.x = old >> 16;
return ret; return ret;
} }
...@@ -168,14 +168,17 @@ struct VecAtomicAddHelper : VecAtomicAddHelperBase<T, false, void, void> {}; ...@@ -168,14 +168,17 @@ struct VecAtomicAddHelper : VecAtomicAddHelperBase<T, false, void, void> {};
#if CUDA_VERSION >= 10000 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 #if CUDA_VERSION >= 10000 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
template <> template <>
struct VecAtomicAddHelper<float16> struct VecAtomicAddHelper<phi::dtype::float16>
: VecAtomicAddHelperBase<float16, true, __half, __half2> {}; : VecAtomicAddHelperBase<phi::dtype::float16, true, __half, __half2> {};
#endif #endif
#if CUDA_VERSION >= 11000 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 #if CUDA_VERSION >= 11000 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
template <> template <>
struct VecAtomicAddHelper<bfloat16> struct VecAtomicAddHelper<phi::dtype::bfloat16>
: VecAtomicAddHelperBase<bfloat16, true, __nv_bfloat16, __nv_bfloat162> {}; : VecAtomicAddHelperBase<phi::dtype::bfloat16,
true,
__nv_bfloat16,
__nv_bfloat162> {};
#endif #endif
// The performance of "atomicAdd(half* )" is bad, but for "atomicAdd(half2* )" // The performance of "atomicAdd(half* )" is bad, but for "atomicAdd(half2* )"
...@@ -225,36 +228,40 @@ __device__ __forceinline__ void fastAtomicAdd(T *arr, ...@@ -225,36 +228,40 @@ __device__ __forceinline__ void fastAtomicAdd(T *arr,
// NOTE(zhangbo): cuda do not have atomicCAS for __nv_bfloat16. // NOTE(zhangbo): cuda do not have atomicCAS for __nv_bfloat16.
// Pack helper for the CAS-based bfloat16 atomicAdd fallback: interpret the
// low 16 bits of `val` as a phi::dtype::bfloat16, add the float `x` to it in
// float precision, and return the 32-bit word with the low half replaced.
inline static __device__ uint32_t bf16_add_to_low_half(uint32_t val, float x) {
  phi::dtype::bfloat16 low_half;
  // The bfloat16 lives in the lower 16 bits of the 32-bit word.
  low_half.x = static_cast<uint16_t>(val & 0xFFFFu);
  low_half =
      static_cast<phi::dtype::bfloat16>(static_cast<float>(low_half) + x);
  return (val & 0xFFFF0000u) | low_half.x;
}
// Pack helper for the CAS-based bfloat16 atomicAdd fallback: interpret the
// high 16 bits of `val` as a phi::dtype::bfloat16, add the float `x` to it in
// float precision, and return the 32-bit word with the high half replaced.
inline static __device__ uint32_t bf16_add_to_high_half(uint32_t val, float x) {
  phi::dtype::bfloat16 high_half;
  // The bfloat16 lives in the upper 16 bits of the 32-bit word.
  high_half.x = static_cast<uint16_t>(val >> 16);
  high_half =
      static_cast<phi::dtype::bfloat16>(static_cast<float>(high_half) + x);
  return (val & 0xFFFFu) | (static_cast<uint32_t>(high_half.x) << 16);
}
#if CUDA_VERSION >= 11000 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 #if CUDA_VERSION >= 11000 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
// Bit-reinterpret a CUDA __nv_bfloat16 as Paddle's phi::dtype::bfloat16; the
// two types share the same 16-bit bfloat16 layout.
static __device__ __forceinline__ phi::dtype::bfloat16 CUDABF16ToPDBF16(
    __nv_bfloat16 x) {
  return *reinterpret_cast<phi::dtype::bfloat16 *>(&x);
}
// Bit-reinterpret Paddle's phi::dtype::bfloat16 as a CUDA __nv_bfloat16
// (inverse of CUDABF16ToPDBF16).
static __device__ __forceinline__ __nv_bfloat16 PDBF16ToCUDABF16(
    phi::dtype::bfloat16 x) {
  return *reinterpret_cast<__nv_bfloat16 *>(&x);
}
// Fast path guarded by CUDA >= 11.0 and sm_80+: use the hardware
// __nv_bfloat16 atomicAdd directly instead of the 32-bit CAS emulation.
CUDA_ATOMIC_WRAPPER(Add, phi::dtype::bfloat16) {
  return CUDABF16ToPDBF16(atomicAdd(reinterpret_cast<__nv_bfloat16 *>(address),
                                    PDBF16ToCUDABF16(val)));
}
#else #else
CUDA_ATOMIC_WRAPPER(Add, bfloat16) { CUDA_ATOMIC_WRAPPER(Add, phi::dtype::bfloat16) {
// concrete packed bfloat16 value may exsits in lower or higher 16bits // concrete packed bfloat16 value may exsits in lower or higher 16bits
// of the 32bits address. // of the 32bits address.
uint32_t *address_as_ui = reinterpret_cast<uint32_t *>( uint32_t *address_as_ui = reinterpret_cast<uint32_t *>(
...@@ -272,7 +279,7 @@ CUDA_ATOMIC_WRAPPER(Add, bfloat16) { ...@@ -272,7 +279,7 @@ CUDA_ATOMIC_WRAPPER(Add, bfloat16) {
old = atomicCAS( old = atomicCAS(
address_as_ui, assumed, bf16_add_to_low_half(assumed, val_f)); address_as_ui, assumed, bf16_add_to_low_half(assumed, val_f));
} while (old != assumed); } while (old != assumed);
bfloat16 ret; phi::dtype::bfloat16 ret;
ret.x = old & 0xFFFFu; ret.x = old & 0xFFFFu;
return ret; return ret;
} else { } else {
...@@ -282,7 +289,7 @@ CUDA_ATOMIC_WRAPPER(Add, bfloat16) { ...@@ -282,7 +289,7 @@ CUDA_ATOMIC_WRAPPER(Add, bfloat16) {
old = atomicCAS( old = atomicCAS(
address_as_ui, assumed, bf16_add_to_high_half(assumed, val_f)); address_as_ui, assumed, bf16_add_to_high_half(assumed, val_f));
} while (old != assumed); } while (old != assumed);
bfloat16 ret; phi::dtype::bfloat16 ret;
ret.x = old >> 16; ret.x = old >> 16;
return ret; return ret;
} }
...@@ -389,22 +396,24 @@ CUDA_ATOMIC_WRAPPER(Max, double) { ...@@ -389,22 +396,24 @@ CUDA_ATOMIC_WRAPPER(Max, double) {
#ifdef PADDLE_CUDA_FP16 #ifdef PADDLE_CUDA_FP16
// Pack helper for the CAS-based float16 atomicMax: interpret the low 16 bits
// of `val` as a phi::dtype::float16, take max against `x` in float precision,
// and return the 32-bit word with the low half replaced.
inline static __device__ uint32_t max_to_low_half(uint32_t val, float x) {
  phi::dtype::float16 low_half;
  // The float16 lives in the lower 16 bits of the 32-bit word.
  low_half.x = static_cast<uint16_t>(val & 0xFFFFu);
  low_half =
      static_cast<phi::dtype::float16>(max(static_cast<float>(low_half), x));
  return (val & 0xFFFF0000u) | low_half.x;
}
// Pack helper for the CAS-based float16 atomicMax: interpret the high 16 bits
// of `val` as a phi::dtype::float16, take max against `x` in float precision,
// and return the 32-bit word with the high half replaced.
inline static __device__ uint32_t max_to_high_half(uint32_t val, float x) {
  phi::dtype::float16 high_half;
  // The float16 lives in the upper 16 bits of the 32-bit word.
  high_half.x = static_cast<uint16_t>(val >> 16);
  high_half =
      static_cast<phi::dtype::float16>(max(static_cast<float>(high_half), x));
  return (val & 0xFFFFu) | (static_cast<uint32_t>(high_half.x) << 16);
}
CUDA_ATOMIC_WRAPPER(Max, float16) { CUDA_ATOMIC_WRAPPER(Max, phi::dtype::float16) {
if (*address >= val) { if (*address >= val) {
return *address; return *address;
} }
...@@ -420,7 +429,7 @@ CUDA_ATOMIC_WRAPPER(Max, float16) { ...@@ -420,7 +429,7 @@ CUDA_ATOMIC_WRAPPER(Max, float16) {
assumed = old; assumed = old;
old = atomicCAS(address_as_ui, assumed, max_to_low_half(assumed, val_f)); old = atomicCAS(address_as_ui, assumed, max_to_low_half(assumed, val_f));
} while (old != assumed); } while (old != assumed);
float16 ret; phi::dtype::float16 ret;
ret.x = old & 0xFFFFu; ret.x = old & 0xFFFFu;
return ret; return ret;
} else { } else {
...@@ -429,7 +438,7 @@ CUDA_ATOMIC_WRAPPER(Max, float16) { ...@@ -429,7 +438,7 @@ CUDA_ATOMIC_WRAPPER(Max, float16) {
assumed = old; assumed = old;
old = atomicCAS(address_as_ui, assumed, max_to_high_half(assumed, val_f)); old = atomicCAS(address_as_ui, assumed, max_to_high_half(assumed, val_f));
} while (old != assumed); } while (old != assumed);
float16 ret; phi::dtype::float16 ret;
ret.x = old >> 16; ret.x = old >> 16;
return ret; return ret;
} }
...@@ -522,22 +531,24 @@ CUDA_ATOMIC_WRAPPER(Min, double) { ...@@ -522,22 +531,24 @@ CUDA_ATOMIC_WRAPPER(Min, double) {
#ifdef PADDLE_CUDA_FP16 #ifdef PADDLE_CUDA_FP16
// Pack helper for the CAS-based float16 atomicMin: interpret the low 16 bits
// of `val` as a phi::dtype::float16, take min against `x` in float precision,
// and return the 32-bit word with the low half replaced.
inline static __device__ uint32_t min_to_low_half(uint32_t val, float x) {
  phi::dtype::float16 low_half;
  // The float16 lives in the lower 16 bits of the 32-bit word.
  low_half.x = static_cast<uint16_t>(val & 0xFFFFu);
  low_half =
      static_cast<phi::dtype::float16>(min(static_cast<float>(low_half), x));
  return (val & 0xFFFF0000u) | low_half.x;
}
// Pack helper for the CAS-based float16 atomicMin: interpret the high 16 bits
// of `val` as a phi::dtype::float16, take min against `x` in float precision,
// and return the 32-bit word with the high half replaced.
inline static __device__ uint32_t min_to_high_half(uint32_t val, float x) {
  phi::dtype::float16 high_half;
  // The float16 lives in the upper 16 bits of the 32-bit word.
  high_half.x = static_cast<uint16_t>(val >> 16);
  high_half =
      static_cast<phi::dtype::float16>(min(static_cast<float>(high_half), x));
  return (val & 0xFFFFu) | (static_cast<uint32_t>(high_half.x) << 16);
}
CUDA_ATOMIC_WRAPPER(Min, float16) { CUDA_ATOMIC_WRAPPER(Min, phi::dtype::float16) {
if (*address <= val) { if (*address <= val) {
return *address; return *address;
} }
...@@ -553,7 +564,7 @@ CUDA_ATOMIC_WRAPPER(Min, float16) { ...@@ -553,7 +564,7 @@ CUDA_ATOMIC_WRAPPER(Min, float16) {
assumed = old; assumed = old;
old = atomicCAS(address_as_ui, assumed, min_to_low_half(assumed, val_f)); old = atomicCAS(address_as_ui, assumed, min_to_low_half(assumed, val_f));
} while (old != assumed); } while (old != assumed);
float16 ret; phi::dtype::float16 ret;
ret.x = old & 0xFFFFu; ret.x = old & 0xFFFFu;
return ret; return ret;
} else { } else {
...@@ -562,7 +573,7 @@ CUDA_ATOMIC_WRAPPER(Min, float16) { ...@@ -562,7 +573,7 @@ CUDA_ATOMIC_WRAPPER(Min, float16) {
assumed = old; assumed = old;
old = atomicCAS(address_as_ui, assumed, min_to_high_half(assumed, val_f)); old = atomicCAS(address_as_ui, assumed, min_to_high_half(assumed, val_f));
} while (old != assumed); } while (old != assumed);
float16 ret; phi::dtype::float16 ret;
ret.x = old >> 16; ret.x = old >> 16;
return ret; return ret;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册