提交 f2bbbb2b 编写于 作者: K Kexin Zhao

fix arithmetic operator

上级 18d616ed
...@@ -484,72 +484,107 @@ DEVICE inline bool operator>=(const half& a, const half& b) { ...@@ -484,72 +484,107 @@ DEVICE inline bool operator>=(const half& a, const half& b) {
#endif // PADDLE_CUDA_FP16 #endif // PADDLE_CUDA_FP16
// Arithmetic operators for float16 on GPU // Arithmetic operators for float16 on GPU
#if defined(PADDLE_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530 #if defined(PADDLE_CUDA_FP16)
DEVICE inline float16 operator+(const float16& a, const float16& b) { HOSTDEVICE inline float16 operator+(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
return float16(__hadd(half(a), half(b))); return float16(__hadd(half(a), half(b)));
#else
return float16(float(a) + float(b));
} }
DEVICE inline float16 operator-(const float16& a, const float16& b) { HOSTDEVICE inline float16 operator-(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
return float16(__hsub(half(a), half(b))); return float16(__hsub(half(a), half(b)));
#else
return float16(float(a) - float(b));
} }
DEVICE inline float16 operator*(const float16& a, const float16& b) { HOSTDEVICE inline float16 operator*(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
return float16(__hmul(half(a), half(b))); return float16(__hmul(half(a), half(b)));
#else
return float16(float(a) * float(b));
} }
DEVICE inline float16 operator/(const float16& a, const float16& b) { HOSTDEVICE inline float16 operator/(const float16& a, const float16& b) {
// TODO(kexinzhao): check the cuda version that starts to support __hdiv #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300
// TODO(kexinzhao): check which cuda version starts to support __hdiv
float num = __half2float(half(a)); float num = __half2float(half(a));
float denom = __half2float(half(b)); float denom = __half2float(half(b));
return float16(num / denom); return float16(num / denom);
#else
return float16(float(a) / float(b));
} }
DEVICE inline float16 operator-(const float16& a) { HOSTDEVICE inline float16 operator-(const float16& a) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
return float16(__hneg(half(a))); return float16(__hneg(half(a)));
#else
float16 res;
res.x = a.x ^ 0x8000;
return res;
} }
DEVICE inline float16& operator+=(float16& a, const float16& b) { HOSTDEVICE inline float16& operator+=(float16& a, const float16& b) {
a = a + b; a = a + b;
return a; return a;
} }
DEVICE inline float16& operator-=(float16& a, const float16& b) { HOSTDEVICE inline float16& operator-=(float16& a, const float16& b) {
a = a - b; a = a - b;
return a; return a;
} }
DEVICE inline float16& operator*=(float16& a, const float16& b) { HOSTDEVICE inline float16& operator*=(float16& a, const float16& b) {
a = a * b; a = a * b;
return a; return a;
} }
DEVICE inline float16& operator/=(float16& a, const float16& b) { HOSTDEVICE inline float16& operator/=(float16& a, const float16& b) {
a = a / b; a = a / b;
return a; return a;
} }
DEVICE inline bool operator==(const float16& a, const float16& b) { HOSTDEVICE inline bool operator==(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
return __heq(half(a), half(b)); return __heq(half(a), half(b));
#else
return float(a) == float(b);
} }
DEVICE inline bool operator!=(const float16& a, const float16& b) { HOSTDEVICE inline bool operator!=(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
return __hne(half(a), half(b)); return __hne(half(a), half(b));
#else
return float(a) != float(b);
} }
DEVICE inline bool operator<(const float16& a, const float16& b) { HOSTDEVICE inline bool operator<(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
return __hlt(half(a), half(b)); return __hlt(half(a), half(b));
#else
return float(a) < float(b);
} }
DEVICE inline bool operator<=(const float16& a, const float16& b) { HOSTDEVICE inline bool operator<=(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
return __hle(half(a), half(b)); return __hle(half(a), half(b));
#else
return float(a) <= float(b);
} }
DEVICE inline bool operator>(const float16& a, const float16& b) { HOSTDEVICE inline bool operator>(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
return __hgt(half(a), half(b)); return __hgt(half(a), half(b));
#else
return float(a) > float(b);
} }
DEVICE inline bool operator>=(const float16& a, const float16& b) { HOSTDEVICE inline bool operator>=(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
return __hge(half(a), half(b)); return __hge(half(a), half(b));
#else
return float(a) >= float(b);
} }
// Arithmetic operators for float16 on ARMv8.2-A CPU // Arithmetic operators for float16 on ARMv8.2-A CPU
...@@ -737,71 +772,71 @@ HOST inline bool operator>=(const float16& a, const float16& b) { ...@@ -737,71 +772,71 @@ HOST inline bool operator>=(const float16& a, const float16& b) {
return (res & 0xffff) != 0; return (res & 0xffff) != 0;
} }
// Arithmetic operators for float16, software emulated on other CPU/GPU // Arithmetic operators for float16, software emulated on other CPU
#else #else
HOSTDEVICE inline float16 operator+(const float16& a, const float16& b) { HOST inline float16 operator+(const float16& a, const float16& b) {
return float16(float(a) + float(b)); return float16(float(a) + float(b));
} }
HOSTDEVICE inline float16 operator-(const float16& a, const float16& b) { HOST inline float16 operator-(const float16& a, const float16& b) {
return float16(float(a) - float(b)); return float16(float(a) - float(b));
} }
HOSTDEVICE inline float16 operator*(const float16& a, const float16& b) { HOST inline float16 operator*(const float16& a, const float16& b) {
return float16(float(a) * float(b)); return float16(float(a) * float(b));
} }
HOSTDEVICE inline float16 operator/(const float16& a, const float16& b) { HOST inline float16 operator/(const float16& a, const float16& b) {
return float16(float(a) / float(b)); return float16(float(a) / float(b));
} }
HOSTDEVICE inline float16 operator-(const float16& a) { HOST inline float16 operator-(const float16& a) {
float16 res; float16 res;
res.x = a.x ^ 0x8000; res.x = a.x ^ 0x8000;
return res; return res;
} }
HOSTDEVICE inline float16& operator+=(float16& a, const float16& b) { HOST inline float16& operator+=(float16& a, const float16& b) {
a = float16(float(a) + float(b)); a = float16(float(a) + float(b));
return a; return a;
} }
HOSTDEVICE inline float16& operator-=(float16& a, const float16& b) { HOST inline float16& operator-=(float16& a, const float16& b) {
a = float16(float(a) - float(b)); a = float16(float(a) - float(b));
return a; return a;
} }
HOSTDEVICE inline float16& operator*=(float16& a, const float16& b) { HOST inline float16& operator*=(float16& a, const float16& b) {
a = float16(float(a) * float(b)); a = float16(float(a) * float(b));
return a; return a;
} }
HOSTDEVICE inline float16& operator/=(float16& a, const float16& b) { HOST inline float16& operator/=(float16& a, const float16& b) {
a = float16(float(a) / float(b)); a = float16(float(a) / float(b));
return a; return a;
} }
HOSTDEVICE inline bool operator==(const float16& a, const float16& b) { HOST inline bool operator==(const float16& a, const float16& b) {
return float(a) == float(b); return float(a) == float(b);
} }
HOSTDEVICE inline bool operator!=(const float16& a, const float16& b) { HOST inline bool operator!=(const float16& a, const float16& b) {
return float(a) != float(b); return float(a) != float(b);
} }
HOSTDEVICE inline bool operator<(const float16& a, const float16& b) { HOST inline bool operator<(const float16& a, const float16& b) {
return float(a) < float(b); return float(a) < float(b);
} }
HOSTDEVICE inline bool operator<=(const float16& a, const float16& b) { HOST inline bool operator<=(const float16& a, const float16& b) {
return float(a) <= float(b); return float(a) <= float(b);
} }
HOSTDEVICE inline bool operator>(const float16& a, const float16& b) { HOST inline bool operator>(const float16& a, const float16& b) {
return float(a) > float(b); return float(a) > float(b);
} }
HOSTDEVICE inline bool operator>=(const float16& a, const float16& b) { HOST inline bool operator>=(const float16& a, const float16& b) {
return float(a) >= float(b); return float(a) >= float(b);
} }
#endif #endif
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册