Commit e877cdb8 authored by Kexin Zhao

add float16 arithmetic on arm cpu

Parent 9d8b3059
...@@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
// need to define PADDLE_ARM_FP16
#pragma once
#include <cstdint>
...@@ -24,6 +26,18 @@ limitations under the License. */
#include "Eigen/src/Core/arch/CUDA/Half.h"
#endif
#ifdef __GNUC__
#define PADDLE_GNUC_VER (__GNUC__ * 10 + __GNUC_MINOR__)
#else
#define PADDLE_GNUC_VER 0
#endif // __GNUC__
#ifdef __clang__
#define PADDLE_CLANG_VER (__clang_major__ * 10 + __clang_minor__)
#else
#define PADDLE_CLANG_VER 0
#endif // __clang__
#ifdef __CUDACC__
#define PADDLE_HOSTDEVICE __host__ __device__
#if CUDA_VERSION >= 7050
...@@ -48,6 +62,7 @@ limitations under the License. */
#if defined(__ARM_NEON) || defined(__ARM_NEON__)
#define PADDLE_NEON
#include <arm_neon.h>
#endif
#if defined(PADDLE_NEON) && defined(PADDLE_ARM_32)
...@@ -58,26 +73,16 @@ limitations under the License. */
#define PADDLE_NEON_64
#endif
#if defined(PADDLE_ARM) && defined(PADDLE_NEON)
#ifdef PADDLE_ARM
#include <arm_neon.h>
#endif
#if !defined(__ANDROID__) && !defined(__APPLE__) && !defined(PADDLE_ARM)
#include <immintrin.h>
#else
#ifdef __F16C__
#undef __F16C__
#endif
#endif // __F16C__
#endif
#else
#include <immintrin.h>
#endif // PADDLE_ARM
#define PADDLE_ALIGN(x) __attribute__((aligned(x)))
// https://github.com/pytorch/pytorch/blob/master/torch/lib/ATen/Half.h
template <typename To, typename From>
To convert(From f) {
return static_cast<To>(f);
}
namespace paddle {
struct float16;
...@@ -86,13 +91,12 @@ namespace fp16_impl {
// convert from float to half precision in round-to-nearest-even mode
PADDLE_HOSTDEVICE inline float16 float_to_half_rn(float f);
PADDLE_HOSTDEVICE inline float half_to_float(float16 h);
PADDLE_HOSTDEVICE inline float16 uint16_to_half(uint16_t x);
} // namespace fp16_impl
// Use PADDLE_ALIGNED(2) to ensure that each float16 will be allocated
// and aligned at least on a 2-byte boundary, which leads to efficient
// memory access of float16 struct and also makes float16 compatible
// with CUDA half and Eigen::half data types.
// with CUDA half, ARM float16_t, and Eigen::half data types.
struct PADDLE_ALIGN(2) float16 {
uint16_t x;
...@@ -103,7 +107,7 @@ struct PADDLE_ALIGN(2) float16 {
PADDLE_HOSTDEVICE inline float16(const float16& h) : x(h.x) {}
#ifdef PADDLE_CUDA_FP16
PADDLE_HOSTDEVICE inline float16(const half h) {
PADDLE_HOSTDEVICE inline float16(const half& h) {
#if CUDA_VERSION >= 9000
x = reinterpret_cast<__half_raw*>(&h)->x;
#else
...@@ -111,40 +115,72 @@ struct PADDLE_ALIGN(2) float16 {
#endif // CUDA_VERSION >= 9000
}
#endif // PADDLE_CUDA_FP16
/*
#ifdef PADDLE_CUDA_FP16
#if CUDA_VERSION < 9000
PADDLE_HOSTDEVICE inline float16(const half& h) : x(h.x) {}
#else
PADDLE_HOSTDEVICE inline float16(const __half_raw& h) : x(h.x) {}
PADDLE_HOSTDEVICE inline float16(const half& h)
: x(*reinterpret_cast<uint16_t*>(&h)) {}
#endif // CUDA_VERSION < 9000
#endif // PADDLE_CUDA_FP16
*/
#ifdef USE_EIGEN
PADDLE_HOSTDEVICE inline float16(const Eigen::half& h) : x(h.x) {}
#endif // USE_EIGEN
#if defined(PADDLE_ARM) && defined(PADDLE_NEON)
#ifdef PADDLE_NEON
// __fp16 is a native half precision data type for arm cpu,
// float16_t is an alias for __fp16 in arm_fp16.h
// which is included in arm_neon.h
PADDLE_HOSTDEVICE inline float16(const float16_t h) {
x = *reinterpret_cast<uint16_t*>(&h);
// float16_t is an alias for __fp16 in arm_fp16.h,
// which is included in arm_neon.h.
// According to gcc, __fp16 can only be used as an argument to fp16
// intrinsic defined in arm_neon.h or as a storage type. It cannot
// be used as a formal function argument.
// TODO (kexinzhao): test it on RPI
PADDLE_HOSTDEVICE inline float16(const float16_t* h) {
x = *reinterpret_cast<uint16_t*>(h);
}
#endif
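(Editorial note: the gcc restriction quoted in the comment above is why the new constructor takes a pointer rather than a float16_t value. Below is a hypothetical helper, not part of this commit, showing __fp16 used purely as a storage type; it assumes an ARM build where PADDLE_NEON is defined and arm_neon.h provides float16_t.)
// Hypothetical helper (assumes PADDLE_NEON): float16_t/__fp16 is used only as
// a storage type here; arithmetic happens after promotion to float.
inline float widen_fp16(const float16_t* p) {
  float16_t stored = *p;              // load the 16-bit value
  return static_cast<float>(stored);  // promote to float before computing
}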
PADDLE_HOSTDEVICE inline explicit float16(bool b) : x(b ? 0x3c00 : 0) {}
PADDLE_HOSTDEVICE inline explicit float16(int8_t val) {
float16 res = fp16_impl::float_to_half_rn(static_cast<float>(val));
x = res.x;
}
PADDLE_HOSTDEVICE inline explicit float16(uint8_t val) {
float16 res = fp16_impl::float_to_half_rn(static_cast<float>(val));
x = res.x;
}
PADDLE_HOSTDEVICE inline explicit float16(int16_t val) {
float16 res = fp16_impl::float_to_half_rn(static_cast<float>(val));
x = res.x;
}
PADDLE_HOSTDEVICE inline explicit float16(uint16_t val) {
float16 res = fp16_impl::float_to_half_rn(static_cast<float>(val));
x = res.x;
}
PADDLE_HOSTDEVICE inline explicit float16(int32_t val) {
float16 res = fp16_impl::float_to_half_rn(static_cast<float>(val));
x = res.x;
}
PADDLE_HOSTDEVICE inline explicit float16(uint32_t val) {
float16 res = fp16_impl::float_to_half_rn(static_cast<float>(val));
x = res.x;
}
PADDLE_HOSTDEVICE inline explicit float16(int64_t val) {
float16 res = fp16_impl::float_to_half_rn(static_cast<float>(val));
x = res.x;
}
PADDLE_HOSTDEVICE inline explicit float16(uint64_t val) {
float16 res = fp16_impl::float_to_half_rn(static_cast<float>(val));
x = res.x;
}
PADDLE_HOSTDEVICE inline explicit float16(float val) {
float16 res = fp16_impl::float_to_half_rn(val);
x = res.x;
}
template <class T>
PADDLE_HOSTDEVICE inline explicit float16(const T& val) {
PADDLE_HOSTDEVICE inline explicit float16(double val) {
float16 res = fp16_impl::float_to_half_rn(static_cast<float>(val));
x = res.x;
}
...@@ -155,7 +191,7 @@ struct PADDLE_ALIGN(2) float16 {
}
#ifdef PADDLE_CUDA_FP16
PADDLE_HOSTDEVICE inline float16& operator=(const half rhs) {
PADDLE_HOSTDEVICE inline float16& operator=(const half& rhs) {
#if CUDA_VERSION >= 9000
x = reinterpret_cast<__half_raw*>(&rhs)->x;
#else
...@@ -172,27 +208,80 @@ struct PADDLE_ALIGN(2) float16 {
}
#endif // USE_EIGEN
#if defined(PADDLE_ARM) && defined(PADDLE_NEON)
PADDLE_HOSTDEVICE inline float16& operator=(const float16_t rhs) {
x = *reinterpret_cast<uint16_t*>(&rhs);
#ifdef PADDLE_NEON
PADDLE_HOSTDEVICE inline float16& operator=(const float16_t* rhs) {
x = *reinterpret_cast<uint16_t*>(rhs);
return *this;
}
#endif
/*
PADDLE_HOSTDEVICE inline explicit float16(int val) {
PADDLE_HOSTDEVICE inline float16& operator=(bool b) {
x = b ? 0x3c00 : 0;
return *this;
}
PADDLE_HOSTDEVICE inline float16& operator=(int8_t val) {
float16 res = fp16_impl::float_to_half_rn(static_cast<float>(val));
x = res.x;
return *this;
}
PADDLE_HOSTDEVICE inline explicit float16(double val) {
PADDLE_HOSTDEVICE inline float16& operator=(uint8_t val) {
float16 res = fp16_impl::float_to_half_rn(static_cast<float>(val));
x = res.x;
return *this;
}
PADDLE_HOSTDEVICE inline float16& operator=(int16_t val) {
float16 res = fp16_impl::float_to_half_rn(static_cast<float>(val));
x = res.x;
return *this;
}
PADDLE_HOSTDEVICE inline float16& operator=(uint16_t val) {
float16 res = fp16_impl::float_to_half_rn(static_cast<float>(val));
x = res.x;
return *this;
}
PADDLE_HOSTDEVICE inline float16& operator=(int32_t val) {
float16 res = fp16_impl::float_to_half_rn(static_cast<float>(val));
x = res.x;
return *this;
}
PADDLE_HOSTDEVICE inline float16& operator=(uint32_t val) {
float16 res = fp16_impl::float_to_half_rn(static_cast<float>(val));
x = res.x;
return *this;
}
PADDLE_HOSTDEVICE inline float16& operator=(int64_t val) {
float16 res = fp16_impl::float_to_half_rn(static_cast<float>(val));
x = res.x;
return *this;
}
PADDLE_HOSTDEVICE inline float16& operator=(uint64_t val) {
float16 res = fp16_impl::float_to_half_rn(static_cast<float>(val));
x = res.x;
return *this;
}
PADDLE_HOSTDEVICE inline float16& operator=(float val) {
float16 res = fp16_impl::float_to_half_rn(val);
x = res.x;
return *this;
}
PADDLE_HOSTDEVICE inline float16& operator=(double val) {
float16 res = fp16_impl::float_to_half_rn(static_cast<float>(val));
x = res.x;
return *this;
} }
*/
#ifdef PADDLE_CUDA_FP16
PADDLE_HOSTDEVICE inline operator half() {
PADDLE_HOSTDEVICE inline operator half() const {
#if CUDA_VERSION >= 9000
__half_raw h;
h.x = x;
...@@ -206,82 +295,270 @@ struct PADDLE_ALIGN(2) float16 {
#endif // PADDLE_CUDA_FP16
#ifdef USE_EIGEN
PADDLE_HOSTDEVICE inline operator Eigen::half() {
PADDLE_HOSTDEVICE inline operator Eigen::half() const {
Eigen::half h;
h.x = x;
return h;
}
#endif // USE_EIGEN
#if defined(PADDLE_ARM) && defined(PADDLE_NEON)
PADDLE_HOSTDEVICE inline operator float16_t() {
#ifdef PADDLE_NEON
// check whether it works or not
PADDLE_HOSTDEVICE inline operator float16_t() const {
float16 h = *this;
return *reinterpret_cast<float16_t*>(&h);
}
#endif
PADDLE_HOSTDEVICE inline explicit operator bool() {
PADDLE_HOSTDEVICE inline explicit operator bool() const {
return (x & 0x7fff) != 0;
}
PADDLE_HOSTDEVICE inline explicit operator int8_t() {
return static_cat<int8_t>(fp16_impl::half_to_float(*this));
PADDLE_HOSTDEVICE inline explicit operator int8_t() const {
return static_cast<int8_t>(fp16_impl::half_to_float(*this));
}
PADDLE_HOSTDEVICE inline explicit operator uint8_t() {
return static_cat<uint8_t>(fp16_impl::half_to_float(*this));
PADDLE_HOSTDEVICE inline explicit operator uint8_t() const {
return static_cast<uint8_t>(fp16_impl::half_to_float(*this));
}
PADDLE_HOSTDEVICE inline explicit operator int16_t() {
return static_cat<int16_t>(fp16_impl::half_to_float(*this));
PADDLE_HOSTDEVICE inline explicit operator int16_t() const {
return static_cast<int16_t>(fp16_impl::half_to_float(*this));
}
PADDLE_HOSTDEVICE inline explicit operator uint16_t() {
return static_cat<uint16_t>(fp16_impl::half_to_float(*this));
PADDLE_HOSTDEVICE inline explicit operator uint16_t() const {
return static_cast<uint16_t>(fp16_impl::half_to_float(*this));
}
PADDLE_HOSTDEVICE inline explicit operator int32_t() {
return static_cat<int32_t>(fp16_impl::half_to_float(*this));
PADDLE_HOSTDEVICE inline explicit operator int32_t() const {
return static_cast<int32_t>(fp16_impl::half_to_float(*this));
}
PADDLE_HOSTDEVICE inline explicit operator uint32_t() {
return static_cat<uint32_t>(fp16_impl::half_to_float(*this));
PADDLE_HOSTDEVICE inline explicit operator uint32_t() const {
return static_cast<uint32_t>(fp16_impl::half_to_float(*this));
}
PADDLE_HOSTDEVICE inline explicit operator int64_t() {
return static_cat<int64_t>(fp16_impl::half_to_float(*this));
PADDLE_HOSTDEVICE inline explicit operator int64_t() const {
return static_cast<int64_t>(fp16_impl::half_to_float(*this));
}
PADDLE_HOSTDEVICE inline explicit operator uint64_t() {
return static_cat<uint64_t>(fp16_impl::half_to_float(*this));
PADDLE_HOSTDEVICE inline explicit operator uint64_t() const {
return static_cast<uint64_t>(fp16_impl::half_to_float(*this));
}
PADDLE_HOSTDEVICE inline explicit operator float() {
PADDLE_HOSTDEVICE inline explicit operator float() const {
return fp16_impl::half_to_float(*this);
}
PADDLE_HOSTDEVICE inline explicit operator double() {
return static_cat<double>(fp16_impl::half_to_float(*this));
PADDLE_HOSTDEVICE inline explicit operator double() const {
return static_cast<double>(fp16_impl::half_to_float(*this));
}
};
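(Editorial note: with the constructors, assignment operators, and conversion operators above, float16 behaves much like a built-in scalar. Below is a minimal host-side usage sketch, not part of this commit, assuming the software conversion path and using only members visible in this diff.)
#include <cassert>
void float16_usage_sketch() {
  paddle::float16 a(1.0f);            // explicit construction from float
  paddle::float16 b(a);               // copy construction
  b = 2.0;                            // assignment from double via operator=(double)
  float fa = static_cast<float>(a);   // explicit conversion back to float
  assert(fa == 1.0f);                 // 1.0f round-trips exactly through binary16
  assert(static_cast<bool>(b));       // non-zero payload converts to true
}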
// arithmetic operators
#if defined(PADDLE_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
__device__ inline float16 operator+(const float16& a, const float16& b) {
return float16(__hadd(a, b));
return float16(__hadd(half(a), half(b)));
}
__device__ inline float16 operator-(const float16& a, const float16& b) {
return __hsub(a, b);
return float16(__hsub(half(a), half(b)));
}
__device__ inline float16 operator*(const float16& a, const float16& b) {
return __hmul(a, b);
return float16(__hmul(half(a), half(b)));
}
#elif // on arm cpu
__device__ inline float16 operator/(const float16& a, const float16& b) {
// TODO(kexinzhao): check the cuda version that starts to support __hdiv
// instinsic
float num = __half2float(half(a));
float denom = __half2float(half(b));
return float16(num / denom);
}
#else
__device__ inline float16 operator-(const float16& a) {
return float16(__hneg(half(a)));
}
__device__ inline float16& operator+=(float16& a, const float16& b) {
a = a + b;
return a;
}
__device__ inline float16& operator-=(float16& a, const float16& b) {
a = a - b;
return a;
}
__device__ inline float16& operator*=(float16& a, const float16& b) {
a = a * b;
return a;
}
__device__ inline float16& operator/=(float16& a, const float16& b) {
a = a / b;
return a;
}
__device__ inline bool operator==(const float16& a, const float16& b) {
return __heq(half(a), half(b));
}
__device__ inline bool operator!=(const float16& a, const float16& b) {
return __hne(half(a), half(b));
}
__device__ inline bool operator<(const float16& a, const float16& b) {
return __hlt(half(a), half(b));
}
__device__ inline bool operator<=(const float16& a, const float16& b) {
return __hle(half(a), half(b));
}
__device__ inline bool operator>(const float16& a, const float16& b) {
return __hgt(half(a), half(b));
}
__device__ inline bool operator>=(const float16& a, const float16& b) {
return __hge(half(a), half(b));
}
// On ARMv8.2-A CPU
#elif (PADDLE_GNUC_VER >= 71 || PADDLE_CLANG_VER >= 39) && \
defined(PADDLE_NEON_64) && defined(PADDLE_ARM_FP16)
__host__ inline float16 operator+(const float16& a, const float16& b) {
return float16(vaddh_f16(float16_t(a), float16_t(b)));
}
__host__ inline float16 operator-(const float16& a, const float16& b) {
return float16(vsubh_f16(float16_t(a), float16_t(b)));
}
__host__ inline float16 operator*(const float16& a, const float16& b) {
return float16(vmulh_f16(float16_t(a), float16_t(b)));
}
__host__ inline float16 operator/(const float16& a, const float16& b) {
return float16(vdivh_f16(float16_t(a), float16_t(b)));
}
__host__ inline float16 operator-(const float16& a) {
return float16(vnegh_f16(float16_t(a)));
}
__host__ inline float16& operator+=(float16& a, const float16& b) {
a = a + b;
return a;
}
__host__ inline float16& operator-=(float16& a, const float16& b) {
a = a - b;
return a;
}
__host__ inline float16& operator*=(float16& a, const float16& b) {
a = a * b;
return a;
}
__host__ inline float16& operator/=(float16& a, const float16& b) {
a = a / b;
return a;
}
__host__ inline bool operator==(const float16& a, const float16& b) {
return static_cast<bool>(vceqh_f16(float16_t(a), float16_t(b)));
}
__host__ inline bool operator!=(const float16& a, const float16& b) {
return !(a == b);
}
// compare only available in NEON_64
__host__ inline bool operator<(const float16& a, const float16& b) {
return static_cast<bool>(vclth_f16(float16_t(a), float16_t(b)));
}
__host__ inline bool operator<=(const float16& a, const float16& b) {
return static_cast<bool>(vcleh_f16(float16_t(a), float16_t(b)));
}
__host__ inline bool operator>(const float16& a, const float16& b) {
return static_cast<bool>(vcgth_f16(float16_t(a), float16_t(b)));
}
__host__ inline bool operator>=(const float16& a, const float16& b) {
return static_cast<bool>(vcgeh_f16(float16_t(a), float16_t(b)));
}
#else // software emulation on other cpu
PADDLE_HOSTDEVICE inline float16 operator+(const float16& a, const float16& b) {
return float16(float(a) + float(b));
}
PADDLE_HOSTDEVICE inline float16 operator-(const float16& a, const float16& b) {
return float16(float(a) - float(b));
}
PADDLE_HOSTDEVICE inline float16 operator*(const float16& a, const float16& b) {
return float16(float(a) * float(b));
}
PADDLE_HOSTDEVICE inline float16 operator/(const float16& a, const float16& b) {
return float16(float(a) / float(b));
}
PADDLE_HOSTDEVICE inline float16 operator-(const float16& a) {
float16 res;
res.x = a.x ^ 0x8000;
return res;
}
PADDLE_HOSTDEVICE inline float16& operator+=(float16& a, const float16& b) {
a = float16(float(a) + float(b));
return a;
}
PADDLE_HOSTDEVICE inline float16& operator-=(float16& a, const float16& b) {
a = float16(float(a) - float(b));
return a;
}
PADDLE_HOSTDEVICE inline float16& operator*=(float16& a, const float16& b) {
a = float16(float(a) * float(b));
return a;
}
PADDLE_HOSTDEVICE inline float16& operator/=(float16& a, const float16& b) {
a = float16(float(a) / float(b));
return a;
}
PADDLE_HOSTDEVICE inline bool operator==(const float16& a, const float16& b) {
return float(a) == float(b);
}
PADDLE_HOSTDEVICE inline bool operator!=(const float16& a, const float16& b) {
return float(a) != float(b);
}
PADDLE_HOSTDEVICE inline bool operator<(const float16& a, const float16& b) {
return float(a) < float(b);
}
PADDLE_HOSTDEVICE inline bool operator<=(const float16& a, const float16& b) {
return float(a) <= float(b);
}
PADDLE_HOSTDEVICE inline bool operator>(const float16& a, const float16& b) {
return float(a) > float(b);
}
PADDLE_HOSTDEVICE inline bool operator>=(const float16& a, const float16& b) {
return float(a) >= float(b);
}
#endif
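(Editorial note: on the software-emulated path every operation except unary minus round-trips through float, while negation just flips the sign bit. The constants rely on the IEEE 754 binary16 layout of 1 sign bit, 5 exponent bits with bias 15, and 10 fraction bits, so 1.0 encodes as 0 01111 0000000000 = 0x3c00, which is also why float16(bool) uses that value. A small host-side check, not part of this commit:)
#include <cassert>
void float16_bit_sketch() {
  paddle::float16 one(1.0f);
  assert(one.x == 0x3c00);            // 1.0 in binary16, matching float16(bool)
  paddle::float16 minus_one = -one;   // software path computes a.x ^ 0x8000
  assert(minus_one.x == 0xbc00);      // only the sign bit differs
  assert(static_cast<float>(minus_one) == -1.0f);
}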
...@@ -320,16 +597,11 @@ PADDLE_HOSTDEVICE inline float16 float_to_half_rn(float f) {
half tmp = __float2half(f);
return *reinterpret_cast<float16*>(&(tmp));
#elif defined(__F16C__)
#elif defined(PADDLE_NEON_64) // test on RPI
float16 res;
res.x = _cvtss_sh(f, 0);
return res;
#elif defined(PADDLE_ARM_64) // test on RPI
float16 res;
asm volatile(
"ld1 {v0.s}[0], [%[float_ptr]]\n"
"FCVT h0, s0\n"
"fcvt h0, s0\n"
"st1 {v0.h}[0], [%[half_ptr]]\n"
: // outputs
: // inputs
...@@ -339,6 +611,25 @@ PADDLE_HOSTDEVICE inline float16 float_to_half_rn(float f) {
"memory", "v0");
return res;
#elif defined(PADDLE_NEON_32) // test on RPI
float16 res;
asm volatile(
"vld1.32 {d0[0]}, [%[float_ptr]]\n"
"vcvt.f16.f32 d0, q0\n"
"vst1.16 {d0[0]}, [%[half_ptr]]\n"
: // outputs
: // inputs
[float_ptr] "r"(&f),
[half_ptr] "r"(&(res.x))
: // clobbers
"memory", "d0");
return res;
#elif defined(__F16C__)
float16 res;
res.x = _cvtss_sh(f, 0);
return res;
#else
// Conversion routine adapted from
// http://stackoverflow.com/questions/1659440/32-bit-to-16-bit-floating-point-conversion
...@@ -367,10 +658,7 @@ PADDLE_HOSTDEVICE inline float half_to_float(float16 h) {
half tmp = *reinterpret_cast<half*>(&h);
return __half2float(h);
#elif defined(__F16C__)
#elif defined(PADDLE_NEON_64)
return _cvtsh_ss(h.x);
#elif defined(PADDLE_ARM_64) // test on RPI
float res;
asm volatile(
"ld1 {v0.h}[0], [%[half_ptr]]\n"
...@@ -384,6 +672,23 @@ PADDLE_HOSTDEVICE inline float half_to_float(float16 h) {
"memory", "v0");
return res;
#elif defined(PADDLE_NEON_32)
float res;
asm volatile(
"vld1.16 {d0[0]}, [%[half_ptr]]\n"
"vcvt.f32.f16 q0, d0\n"
"vst1.32 {d0[0]}, [%[float_ptr]]\n"
: // outputs
: // inputs
[half_ptr] "r"(&(h.x)),
[float_ptr] "r"(&res)
: // clobbers
"memory", "v0");
return res;
#elif defined(__F16C__)
return _cvtsh_ss(h.x);
#else
// Conversion routine adapted from
// http://stackoverflow.com/questions/1659440/32-bit-to-16-bit-floating-point-conversion
...@@ -406,12 +711,6 @@ PADDLE_HOSTDEVICE inline float half_to_float(float16 h) {
#endif
}
PADDLE_HOSTDEVICE inline float16 uint16_to_half(uint16_t x) {
float16 res;
res.x = x;
return res;
}
} // namespace half_impl
} // namespace paddle
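(Editorial note: the hunks above elide the software fallback bodies adapted from the Stack Overflow reference. For orientation only, here is a self-contained sketch of decoding binary16 bits into a float; it illustrates the encoding and is not the routine used in this file.)
// Editorial sketch (not Paddle code): decode IEEE 754 binary16 bits to float,
// handling zero/subnormals, infinities, NaN, and normal values explicitly.
#include <cmath>
#include <cstdint>
#include <limits>
inline float half_bits_to_float_sketch(uint16_t h) {
  const uint32_t sign = (h >> 15) & 0x1;
  const uint32_t exp = (h >> 10) & 0x1f;
  const uint32_t frac = h & 0x3ff;
  const float s = sign ? -1.0f : 1.0f;
  if (exp == 0) {
    // zero or subnormal: value = sign * frac * 2^-24
    return s * std::ldexp(static_cast<float>(frac), -24);
  }
  if (exp == 0x1f) {
    // all-ones exponent: infinity when the fraction is zero, NaN otherwise
    return frac == 0 ? s * std::numeric_limits<float>::infinity()
                     : std::numeric_limits<float>::quiet_NaN();
  }
  // normal: value = sign * (1 + frac/1024) * 2^(exp - 15)
  return s * std::ldexp(1.0f + static_cast<float>(frac) / 1024.0f,
                        static_cast<int>(exp) - 15);
}
For example, 0x3c00 decodes to 1.0f and 0xc000 to -2.0f under this scheme.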