未验证 提交 1d1555e2 编写于 作者: K kexinzhao 提交者: GitHub

Merge pull request #5716 from kexinzhao/float16

Add half precision float16 data type
......@@ -58,6 +58,7 @@ option(GLIDE_INSTALL "Download and install go dependencies " ON)
# Optional build features; each of these switches defaults to OFF.
option(USE_NNPACK "Compile PaddlePaddle with NNPACK library" OFF)
option(WITH_DISTRIBUTE "Compile with grpc distributed support" OFF)
option(USE_EIGEN_FOR_BLAS "Use matrix multiplication in Eigen" OFF)
# Opt-in switch for half precision (float16) support on armv8.2-a CPUs.
option(WITH_ARM_FP16 "Use half precision support on armv8.2-a cpu" OFF)
# CMAKE_BUILD_TYPE
if(NOT CMAKE_BUILD_TYPE)
......
......@@ -24,6 +24,11 @@ if(WITH_DOUBLE)
add_definitions(-DPADDLE_TYPE_DOUBLE)
endif(WITH_DOUBLE)
# With WITH_ARM_FP16 enabled, define PADDLE_ARM_FP16 for the sources and pass
# the -march flag that targets armv8.2-a with the fp16 and simd extensions.
if(WITH_ARM_FP16)
add_definitions(-DPADDLE_ARM_FP16)
add_definitions("-march=armv8.2-a+fp16+simd")
endif(WITH_ARM_FP16)
# With WITH_TESTING enabled, define PADDLE_WITH_TESTING for the sources.
if(WITH_TESTING)
add_definitions(-DPADDLE_WITH_TESTING)
endif(WITH_TESTING)
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <stdint.h>
#ifdef PADDLE_WITH_CUDA
#include <cuda.h>
#endif // PADDLE_WITH_CUDA
#include "unsupported/Eigen/CXX11/Tensor"
#include "paddle/platform/hostdevice.h"
// Compiler version encoded as major * 10 + minor (e.g. 62 for GCC 6.2).
// The value is 0 when the corresponding compiler is not in use.
#ifdef __GNUC__
#define PADDLE_GNUC_VER (__GNUC__ * 10 + __GNUC_MINOR__)
#else
#define PADDLE_GNUC_VER 0
#endif // __GNUC__
#ifdef __clang__
#define PADDLE_CLANG_VER (__clang_major__ * 10 + __clang_minor__)
#else
#define PADDLE_CLANG_VER 0
#endif // __clang__
// PADDLE_CUDA_FP16: set when compiling with nvcc (__CUDACC__) against a CUDA
// toolkit whose CUDA_VERSION is at least 7050; cuda_fp16.h is pulled in then.
#if defined(__CUDACC__) && CUDA_VERSION >= 7050
#define PADDLE_CUDA_FP16
#include <cuda_fp16.h>
#endif
// PADDLE_ARM: set for both 32-bit and 64-bit ARM targets.
#if defined(__arm__) || defined(__aarch64__)
#define PADDLE_ARM
#endif
// PADDLE_NEON: set when the compiler reports NEON support.
#if defined(__ARM_NEON) || defined(__ARM_NEON__)
#define PADDLE_NEON
#include <arm_neon.h>
#endif
// PADDLE_WITH_NATIVE_FP16: requires NEON, the PADDLE_ARM_FP16 build
// definition, and a compiler version of at least 62 (GCC) or 37 (Clang) in
// the encoding defined above.
#if defined(PADDLE_NEON) && defined(PADDLE_ARM_FP16) && \
(PADDLE_GNUC_VER >= 62 || PADDLE_CLANG_VER >= 37)
#define PADDLE_WITH_NATIVE_FP16
#endif
// On non-ARM targets include the x86 intrinsics header.
#ifndef PADDLE_ARM
#include <immintrin.h>
#endif // PADDLE_ARM
// Alignment attribute in GCC/Clang __attribute__ syntax.
#define PADDLE_ALIGN(x) __attribute__((aligned(x)))
namespace paddle {
// Use PADDLE_ALIGN(2) to ensure that each float16 is allocated and aligned
// on at least a 2-byte boundary. This leads to efficient memory access of
// the float16 struct and also makes float16 compatible with the CUDA half,
// ARM float16_t, and Eigen::half data types.
// 16-bit floating point value stored as a raw bit pattern in `x`.
//
// The struct converts explicitly to and from CUDA half, Eigen::half, ARM
// float16_t, float, bool and (by way of float) the other arithmetic types.
// The float <-> float16 conversion is selected at compile time: CUDA
// intrinsics in device code, NEON intrinsics when PADDLE_NEON is defined,
// F16C intrinsics when __F16C__ is defined, and a portable bit-manipulation
// routine otherwise.
struct PADDLE_ALIGN(2) float16 {
 public:
  // Raw half-precision bit pattern (for example 0x3c00 encodes 1.0).
  uint16_t x;

  // Constructors

  // Default construction yields positive zero.
  HOSTDEVICE inline float16() : x(0) {}

  HOSTDEVICE inline float16(const float16& h) : x(h.x) {}

#ifdef PADDLE_CUDA_FP16
  // Copies the bits of a CUDA half.
  HOSTDEVICE inline explicit float16(const half& h) {
#if CUDA_VERSION >= 9000
    // `h` is const, so the target pointer type has to be const as well:
    // reinterpret_cast is not allowed to cast away constness.
    x = reinterpret_cast<const __half_raw*>(&h)->x;
#else
    x = h.x;
#endif  // CUDA_VERSION >= 9000
  }
#endif  // PADDLE_CUDA_FP16

  // Copies the bits of an Eigen::half.
  HOSTDEVICE inline explicit float16(const Eigen::half& h) : x(h.x) {}

#ifdef PADDLE_WITH_NATIVE_FP16
  // __fp16 is a native half precision data type for arm cpu,
  // float16_t is an alias for __fp16
  HOSTDEVICE inline explicit float16(const float16_t& h) {
    x = *reinterpret_cast<const uint16_t*>(&h);
  }
#endif

  // Converts a single-precision value to half precision.
  HOSTDEVICE inline explicit float16(float val) {
#if defined(PADDLE_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300
    half tmp = __float2half(val);
    x = *reinterpret_cast<uint16_t*>(&tmp);
#elif defined(PADDLE_NEON)
    float32x4_t tmp = vld1q_dup_f32(&val);
    float16_t res = vget_lane_f16(vcvt_f16_f32(tmp), 0);
    x = *reinterpret_cast<uint16_t*>(&res);
#elif defined(__F16C__)
    x = _cvtss_sh(val, 0);
#else
    // Conversion routine adapted from
    // http://stackoverflow.com/questions/1659440/32-bit-to-16-bit-floating-point-conversion
    Bits v, s;
    v.f = val;
    uint32_t sign = v.si & sigN;
    v.si ^= sign;
    sign >>= shiftSign;  // logical shift
    s.si = mulN;
    s.si = s.f * v.f;  // correct subnormals
    v.si ^= (s.si ^ v.si) & -(minN > v.si);
    v.si ^= (infN ^ v.si) & -((infN > v.si) & (v.si > maxN));
    v.si ^= (nanN ^ v.si) & -((nanN > v.si) & (v.si > infN));
    v.ui >>= shift;  // logical shift
    v.si ^= ((v.si - maxD) ^ v.si) & -(v.si > maxC);
    v.si ^= ((v.si - minD) ^ v.si) & -(v.si > subC);
    x = v.ui | sign;
#endif
  }

  // true maps to 1.0 (0x3c00), false to positive zero.
  HOSTDEVICE inline explicit float16(bool b) : x(b ? 0x3c00 : 0) {}

  // Any other type is first converted to float and then to float16.
  template <class T>
  HOSTDEVICE inline explicit float16(const T& val)
      : x(float16(static_cast<float>(val)).x) {}

  // Assignment operators

  HOSTDEVICE inline float16& operator=(const float16& rhs) {
    x = rhs.x;
    return *this;
  }

#ifdef PADDLE_CUDA_FP16
  HOSTDEVICE inline float16& operator=(const half& rhs) {
#if CUDA_VERSION >= 9000
    // Same const-correct cast as in the half constructor.
    x = reinterpret_cast<const __half_raw*>(&rhs)->x;
#else
    x = rhs.x;
#endif
    return *this;
  }
#endif

  HOSTDEVICE inline float16& operator=(const Eigen::half& rhs) {
    x = rhs.x;
    return *this;
  }

#ifdef PADDLE_WITH_NATIVE_FP16
  HOSTDEVICE inline float16& operator=(const float16_t& rhs) {
    x = *reinterpret_cast<const uint16_t*>(&rhs);
    return *this;
  }
#endif

  HOSTDEVICE inline float16& operator=(bool b) {
    x = b ? 0x3c00 : 0;
    return *this;
  }

  // The arithmetic overloads below all go through the matching constructor.
  HOSTDEVICE inline float16& operator=(int8_t val) {
    x = float16(val).x;
    return *this;
  }

  HOSTDEVICE inline float16& operator=(uint8_t val) {
    x = float16(val).x;
    return *this;
  }

  HOSTDEVICE inline float16& operator=(int16_t val) {
    x = float16(val).x;
    return *this;
  }

  HOSTDEVICE inline float16& operator=(uint16_t val) {
    x = float16(val).x;
    return *this;
  }

  HOSTDEVICE inline float16& operator=(int32_t val) {
    x = float16(val).x;
    return *this;
  }

  HOSTDEVICE inline float16& operator=(uint32_t val) {
    x = float16(val).x;
    return *this;
  }

  HOSTDEVICE inline float16& operator=(int64_t val) {
    x = float16(val).x;
    return *this;
  }

  HOSTDEVICE inline float16& operator=(uint64_t val) {
    x = float16(val).x;
    return *this;
  }

  HOSTDEVICE inline float16& operator=(float val) {
    x = float16(val).x;
    return *this;
  }

  HOSTDEVICE inline float16& operator=(double val) {
    x = float16(val).x;
    return *this;
  }

  // Conversion operators

#ifdef PADDLE_CUDA_FP16
  HOSTDEVICE inline explicit operator half() const {
#if CUDA_VERSION >= 9000
    __half_raw h;
    h.x = x;
    return half(h);
#else
    half h;
    h.x = x;
    return h;
#endif  // CUDA_VERSION >= 9000
  }
#endif  // PADDLE_CUDA_FP16

  HOSTDEVICE inline explicit operator Eigen::half() const {
    Eigen::half h;
    h.x = x;
    return h;
  }

#ifdef PADDLE_WITH_NATIVE_FP16
  HOSTDEVICE inline explicit operator float16_t() const {
    return *reinterpret_cast<const float16_t*>(this);
  }
#endif

  // Converts the stored half-precision value to single precision.
  HOSTDEVICE inline explicit operator float() const {
#if defined(PADDLE_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300
    half tmp = *reinterpret_cast<const half*>(this);
    return __half2float(tmp);
#elif defined(PADDLE_NEON)
    float16x4_t res = vld1_dup_f16(reinterpret_cast<const float16_t*>(this));
    return vgetq_lane_f32(vcvt_f32_f16(res), 0);
#elif defined(__F16C__)
    return _cvtsh_ss(this->x);
#else
    // Conversion routine adapted from
    // http://stackoverflow.com/questions/1659440/32-bit-to-16-bit-floating-point-conversion
    Bits v;
    v.ui = this->x;
    int32_t sign = v.si & sigC;
    v.si ^= sign;
    sign <<= shiftSign;
    v.si ^= ((v.si + minD) ^ v.si) & -(v.si > subC);
    v.si ^= ((v.si + maxD) ^ v.si) & -(v.si > maxC);
    Bits s;
    s.si = mulC;
    s.f *= v.si;
    int32_t mask = -(norC > v.si);
    v.si <<= shift;
    v.si ^= (s.si ^ v.si) & mask;
    v.si |= sign;
    return v.f;
#endif
  }

  // Both +0 and -0 convert to false: the sign bit is masked out.
  HOSTDEVICE inline explicit operator bool() const { return (x & 0x7fff) != 0; }

  // The remaining conversions go through float.
  HOSTDEVICE inline explicit operator int8_t() const {
    return static_cast<int8_t>(float(*this));
  }

  HOSTDEVICE inline explicit operator uint8_t() const {
    return static_cast<uint8_t>(float(*this));
  }

  HOSTDEVICE inline explicit operator int16_t() const {
    return static_cast<int16_t>(float(*this));
  }

  HOSTDEVICE inline explicit operator uint16_t() const {
    return static_cast<uint16_t>(float(*this));
  }

  HOSTDEVICE inline explicit operator int32_t() const {
    return static_cast<int32_t>(float(*this));
  }

  HOSTDEVICE inline explicit operator uint32_t() const {
    return static_cast<uint32_t>(float(*this));
  }

  HOSTDEVICE inline explicit operator int64_t() const {
    return static_cast<int64_t>(float(*this));
  }

  HOSTDEVICE inline explicit operator uint64_t() const {
    return static_cast<uint64_t>(float(*this));
  }

  HOSTDEVICE inline explicit operator double() const {
    return static_cast<double>(float(*this));
  }

 private:
  // Views the same 32 bits as a float, a signed or an unsigned integer.
  union Bits {
    float f;
    int32_t si;
    uint32_t ui;
  };

  // Constants used by the portable (software) conversion routines.
  static const int shift = 13;
  static const int shiftSign = 16;

  static const int32_t infN = 0x7F800000;
  static const int32_t maxN = 0x477FE000;  // max flt16 as flt32
  static const int32_t minN = 0x38800000;  // min flt16 normal as flt32
  static const int32_t sigN = 0x80000000;  // sign bit

  static constexpr int32_t infC = infN >> shift;
  static constexpr int32_t nanN = (infC + 1)
                                  << shift;  // minimum flt16 nan as float32
  static constexpr int32_t maxC = maxN >> shift;
  static constexpr int32_t minC = minN >> shift;
  static constexpr int32_t sigC = sigN >> shiftSign;

  static const int32_t mulN = 0x52000000;  // (1 << 23) / minN
  static const int32_t mulC = 0x33800000;  // minN / (1 << (23 - shift))
  static const int32_t subC = 0x003FF;     // max flt32 subnormal downshifted
  static const int32_t norC = 0x00400;     // min flt32 normal downshifted

  static constexpr int32_t maxD = infC - maxC - 1;
  static constexpr int32_t minD = minC - subC - 1;
};
// Arithmetic operators on GPU
// CUDA 9.0 provides built-in arithmetic operators for half while
// CUDA 7.5 and 8.0 do not. The arithmetic operators defined here are
// for users to write similar CUDA code in CUDA 7.5 and 8.0 as in
// CUDA 9.0 regarding the half data type.
#if defined(PADDLE_CUDA_FP16) && CUDA_VERSION < 9000
// Adds two CUDA half values. Device code compiled for __CUDA_ARCH__ >= 530
// uses the __hadd intrinsic; otherwise both operands are widened to float
// through float16, added, and the sum is narrowed back to half.
DEVICE inline half operator+(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hadd(a, b);
#else
  const float lhs = float(float16(a));
  const float rhs = float(float16(b));
  const float sum = lhs + rhs;
  return half(float16(sum));
#endif
}
// Subtracts b from a. Device code compiled for __CUDA_ARCH__ >= 530 uses the
// __hsub intrinsic; otherwise the subtraction is carried out in float.
DEVICE inline half operator-(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hsub(a, b);
#else
  const float lhs = float(float16(a));
  const float rhs = float(float16(b));
  const float difference = lhs - rhs;
  return half(float16(difference));
#endif
}
// Multiplies two CUDA half values. Device code compiled for
// __CUDA_ARCH__ >= 530 uses the __hmul intrinsic; otherwise the product is
// computed in float.
DEVICE inline half operator*(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hmul(a, b);
#else
  const float lhs = float(float16(a));
  const float rhs = float(float16(b));
  const float product = lhs * rhs;
  return half(float16(product));
#endif
}
// Divides a by b. The division itself is always carried out in float; device
// code compiled for __CUDA_ARCH__ >= 300 converts with the CUDA intrinsics,
// any other compilation pass converts through float16.
DEVICE inline half operator/(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300
  const float numerator = __half2float(a);
  const float denominator = __half2float(b);
  const float quotient = numerator / denominator;
  return __float2half(quotient);
#else
  const float numerator = float(float16(a));
  const float denominator = float(float16(b));
  const float quotient = numerator / denominator;
  return half(float16(quotient));
#endif
}
// Unary minus. Device code compiled for __CUDA_ARCH__ >= 530 uses the __hneg
// intrinsic; otherwise the value is negated in float.
DEVICE inline half operator-(const half& a) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hneg(a);
#else
  const float value = float(float16(a));
  const float negated = -value;
  return half(float16(negated));
#endif
}
// In-place addition, expressed through the binary operator+ for half.
DEVICE inline half& operator+=(half& a, const half& b) {
  const half sum = a + b;
  a = sum;
  return a;
}
// In-place subtraction, expressed through the binary operator- for half.
DEVICE inline half& operator-=(half& a, const half& b) {
  const half difference = a - b;
  a = difference;
  return a;
}
// In-place multiplication, expressed through operator* for half.
DEVICE inline half& operator*=(half& a, const half& b) {
  const half product = a * b;
  a = product;
  return a;
}
// In-place division, expressed through operator/ for half.
DEVICE inline half& operator/=(half& a, const half& b) {
  const half quotient = a / b;
  a = quotient;
  return a;
}
// Equality of two half values. Device code compiled for __CUDA_ARCH__ >= 530
// uses the __heq intrinsic; otherwise the float values are compared.
DEVICE inline bool operator==(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __heq(a, b);
#else
  const float lhs = float(float16(a));
  const float rhs = float(float16(b));
  return lhs == rhs;
#endif
}
// Inequality of two half values. Device code compiled for
// __CUDA_ARCH__ >= 530 uses the __hne intrinsic; otherwise the float values
// are compared.
DEVICE inline bool operator!=(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hne(a, b);
#else
  const float lhs = float(float16(a));
  const float rhs = float(float16(b));
  return lhs != rhs;
#endif
}
// a < b for half values. Device code compiled for __CUDA_ARCH__ >= 530 uses
// the __hlt intrinsic; otherwise the float values are compared.
DEVICE inline bool operator<(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hlt(a, b);
#else
  const float lhs = float(float16(a));
  const float rhs = float(float16(b));
  return lhs < rhs;
#endif
}
// a <= b for half values. Device code compiled for __CUDA_ARCH__ >= 530 uses
// the __hle intrinsic; otherwise the float values are compared.
DEVICE inline bool operator<=(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hle(a, b);
#else
  const float lhs = float(float16(a));
  const float rhs = float(float16(b));
  return lhs <= rhs;
#endif
}
// a > b for half values. Device code compiled for __CUDA_ARCH__ >= 530 uses
// the __hgt intrinsic; otherwise the float values are compared.
DEVICE inline bool operator>(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hgt(a, b);
#else
  const float lhs = float(float16(a));
  const float rhs = float(float16(b));
  return lhs > rhs;
#endif
}
// a >= b for half values. Device code compiled for __CUDA_ARCH__ >= 530 uses
// the __hge intrinsic; otherwise the float values are compared.
DEVICE inline bool operator>=(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hge(a, b);
#else
  const float lhs = float(float16(a));
  const float rhs = float(float16(b));
  return lhs >= rhs;
#endif
}
#endif // PADDLE_CUDA_FP16
// Arithmetic operators on ARMv8.2-A CPU
#if defined(PADDLE_WITH_NATIVE_FP16)
// Adds two float16 values with the scalar half-precision `fadd` instruction.
// Each operand is loaded from memory into lane 0 of v0 / v1 and the result is
// stored from v0 into res.x.
HOST inline float16 operator+(const float16& a, const float16& b) {
float16 res;
asm volatile(
"ld1 {v0.h}[0], [%[a_ptr]]\n"
"ld1 {v1.h}[0], [%[b_ptr]]\n"
"fadd h0, h0, h1\n"
"st1 {v0.h}[0], [%[res_ptr]]\n"
: // outputs
: // inputs
[a_ptr] "r"(&(a.x)),
[b_ptr] "r"(&(b.x)),
[res_ptr] "r"(&(res.x))
: // clobbers
// "memory" is listed because the result is written through res_ptr, which
// is passed as an input operand.
"memory", "v0", "v1");
return res;
}
// Computes a - b with the scalar half-precision `fsub` instruction.
// a is loaded into lane 0 of v0, b into lane 0 of v1; the result is stored
// from v0 into res.x.
HOST inline float16 operator-(const float16& a, const float16& b) {
float16 res;
asm volatile(
"ld1 {v0.h}[0], [%[a_ptr]]\n"
"ld1 {v1.h}[0], [%[b_ptr]]\n"
"fsub h0, h0, h1\n"
"st1 {v0.h}[0], [%[res_ptr]]\n"
: // outputs
: // inputs
[a_ptr] "r"(&(a.x)),
[b_ptr] "r"(&(b.x)),
[res_ptr] "r"(&(res.x))
: // clobbers
// "memory" is listed because the result is written through res_ptr.
"memory", "v0", "v1");
return res;
}
// Multiplies two float16 values with the scalar half-precision `fmul`
// instruction. Operands are loaded into lane 0 of v0 / v1; the result is
// stored from v0 into res.x.
HOST inline float16 operator*(const float16& a, const float16& b) {
float16 res;
asm volatile(
"ld1 {v0.h}[0], [%[a_ptr]]\n"
"ld1 {v1.h}[0], [%[b_ptr]]\n"
"fmul h0, h0, h1\n"
"st1 {v0.h}[0], [%[res_ptr]]\n"
: // outputs
: // inputs
[a_ptr] "r"(&(a.x)),
[b_ptr] "r"(&(b.x)),
[res_ptr] "r"(&(res.x))
: // clobbers
// "memory" is listed because the result is written through res_ptr.
"memory", "v0", "v1");
return res;
}
// Computes a / b with the scalar half-precision `fdiv` instruction.
// a is loaded into lane 0 of v0, b into lane 0 of v1; the result is stored
// from v0 into res.x.
HOST inline float16 operator/(const float16& a, const float16& b) {
float16 res;
asm volatile(
"ld1 {v0.h}[0], [%[a_ptr]]\n"
"ld1 {v1.h}[0], [%[b_ptr]]\n"
"fdiv h0, h0, h1\n"
"st1 {v0.h}[0], [%[res_ptr]]\n"
: // outputs
: // inputs
[a_ptr] "r"(&(a.x)),
[b_ptr] "r"(&(b.x)),
[res_ptr] "r"(&(res.x))
: // clobbers
// "memory" is listed because the result is written through res_ptr.
"memory", "v0", "v1");
return res;
}
// Unary minus implemented with the scalar half-precision `fneg` instruction.
// The operand is loaded into lane 0 of v0 and the result stored into res.x.
HOST inline float16 operator-(const float16& a) {
float16 res;
asm volatile(
"ld1 {v0.h}[0], [%[a_ptr]]\n"
"fneg h0, h0\n"
"st1 {v0.h}[0], [%[res_ptr]]\n"
: // outputs
: // inputs
[a_ptr] "r"(&(a.x)),
[res_ptr] "r"(&(res.x))
: // clobbers
// "memory" is listed because the result is written through res_ptr.
"memory", "v0");
return res;
}
// In-place addition, expressed through the binary operator+ for float16.
HOST inline float16& operator+=(float16& a, const float16& b) {
  const float16 sum = a + b;
  a = sum;
  return a;
}
// In-place subtraction, expressed through the binary operator- for float16.
HOST inline float16& operator-=(float16& a, const float16& b) {
  const float16 difference = a - b;
  a = difference;
  return a;
}
// In-place multiplication, expressed through operator* for float16.
HOST inline float16& operator*=(float16& a, const float16& b) {
  const float16 product = a * b;
  a = product;
  return a;
}
// In-place division, expressed through operator/ for float16.
HOST inline float16& operator/=(float16& a, const float16& b) {
  const float16 quotient = a / b;
  a = quotient;
  return a;
}
// Equality comparison using the scalar half-precision `fcmeq` instruction.
// The 16-bit comparison result is stored into `res`; any nonzero value is
// reported as true.
HOST inline bool operator==(const float16& a, const float16& b) {
uint16_t res;
asm volatile(
"ld1 {v0.h}[0], [%[a_ptr]]\n"
"ld1 {v1.h}[0], [%[b_ptr]]\n"
"fcmeq h0, h0, h1\n"
"st1 {v0.h}[0], [%[res_ptr]]\n"
: // outputs
: // inputs
[a_ptr] "r"(&(a.x)),
[b_ptr] "r"(&(b.x)),
[res_ptr] "r"(&res)
: // clobbers
// "memory" is listed because `res` is written through res_ptr.
"memory", "v0", "v1");
return (res & 0xffff) != 0;
}
// Inequality is defined as the negation of operator== for float16.
HOST inline bool operator!=(const float16& a, const float16& b) {
  const bool equal = (a == b);
  return !equal;
}
// a < b, evaluated as b > a: note that a is loaded into v1 and b into v0
// before `fcmgt h0, h0, h1` is executed. Any nonzero stored result is
// reported as true.
HOST inline bool operator<(const float16& a, const float16& b) {
uint16_t res;
asm volatile(
"ld1 {v1.h}[0], [%[a_ptr]]\n"
"ld1 {v0.h}[0], [%[b_ptr]]\n"
"fcmgt h0, h0, h1\n"
"st1 {v0.h}[0], [%[res_ptr]]\n"
: // outputs
: // inputs
[a_ptr] "r"(&(a.x)),
[b_ptr] "r"(&(b.x)),
[res_ptr] "r"(&res)
: // clobbers
// "memory" is listed because `res` is written through res_ptr.
"memory", "v0", "v1");
return (res & 0xffff) != 0;
}
// a <= b, evaluated as b >= a: note that a is loaded into v1 and b into v0
// before `fcmge h0, h0, h1` is executed. Any nonzero stored result is
// reported as true.
HOST inline bool operator<=(const float16& a, const float16& b) {
uint16_t res;
asm volatile(
"ld1 {v1.h}[0], [%[a_ptr]]\n"
"ld1 {v0.h}[0], [%[b_ptr]]\n"
"fcmge h0, h0, h1\n"
"st1 {v0.h}[0], [%[res_ptr]]\n"
: // outputs
: // inputs
[a_ptr] "r"(&(a.x)),
[b_ptr] "r"(&(b.x)),
[res_ptr] "r"(&res)
: // clobbers
// "memory" is listed because `res` is written through res_ptr.
"memory", "v0", "v1");
return (res & 0xffff) != 0;
}
// a > b using the scalar half-precision `fcmgt` instruction, with a loaded
// into v0 and b into v1. Any nonzero stored result is reported as true.
HOST inline bool operator>(const float16& a, const float16& b) {
uint16_t res;
asm volatile(
"ld1 {v0.h}[0], [%[a_ptr]]\n"
"ld1 {v1.h}[0], [%[b_ptr]]\n"
"fcmgt h0, h0, h1\n"
"st1 {v0.h}[0], [%[res_ptr]]\n"
: // outputs
: // inputs
[a_ptr] "r"(&(a.x)),
[b_ptr] "r"(&(b.x)),
[res_ptr] "r"(&res)
: // clobbers
// "memory" is listed because `res` is written through res_ptr.
"memory", "v0", "v1");
return (res & 0xffff) != 0;
}
// a >= b using the scalar half-precision `fcmge` instruction, with a loaded
// into v0 and b into v1. Any nonzero stored result is reported as true.
HOST inline bool operator>=(const float16& a, const float16& b) {
uint16_t res;
asm volatile(
"ld1 {v0.h}[0], [%[a_ptr]]\n"
"ld1 {v1.h}[0], [%[b_ptr]]\n"
"fcmge h0, h0, h1\n"
"st1 {v0.h}[0], [%[res_ptr]]\n"
: // outputs
: // inputs
[a_ptr] "r"(&(a.x)),
[b_ptr] "r"(&(b.x)),
[res_ptr] "r"(&res)
: // clobbers
// "memory" is listed because `res` is written through res_ptr.
"memory", "v0", "v1");
return (res & 0xffff) != 0;
}
// Arithmetic operators, software emulated on other CPU
#else
// Software-emulated addition: both operands are widened to float, added, and
// the sum is converted back to float16.
HOSTDEVICE inline float16 operator+(const float16& a, const float16& b) {
  const float sum = static_cast<float>(a) + static_cast<float>(b);
  return float16(sum);
}
// Software-emulated subtraction carried out in float.
HOSTDEVICE inline float16 operator-(const float16& a, const float16& b) {
  const float difference = static_cast<float>(a) - static_cast<float>(b);
  return float16(difference);
}
// Software-emulated multiplication carried out in float.
HOSTDEVICE inline float16 operator*(const float16& a, const float16& b) {
  const float product = static_cast<float>(a) * static_cast<float>(b);
  return float16(product);
}
// Software-emulated division carried out in float.
HOSTDEVICE inline float16 operator/(const float16& a, const float16& b) {
  const float quotient = static_cast<float>(a) / static_cast<float>(b);
  return float16(quotient);
}
// Unary minus: toggles the sign bit (0x8000) of the stored bit pattern
// without going through float.
HOSTDEVICE inline float16 operator-(const float16& a) {
  float16 negated(a);
  negated.x ^= 0x8000;
  return negated;
}
// In-place addition carried out in float.
HOSTDEVICE inline float16& operator+=(float16& a, const float16& b) {
  const float sum = static_cast<float>(a) + static_cast<float>(b);
  a = float16(sum);
  return a;
}
// In-place subtraction carried out in float.
HOSTDEVICE inline float16& operator-=(float16& a, const float16& b) {
  const float difference = static_cast<float>(a) - static_cast<float>(b);
  a = float16(difference);
  return a;
}
// In-place multiplication carried out in float.
HOSTDEVICE inline float16& operator*=(float16& a, const float16& b) {
  const float product = static_cast<float>(a) * static_cast<float>(b);
  a = float16(product);
  return a;
}
// In-place division carried out in float.
HOSTDEVICE inline float16& operator/=(float16& a, const float16& b) {
  const float quotient = static_cast<float>(a) / static_cast<float>(b);
  a = float16(quotient);
  return a;
}
// Equality is decided on the float values of both operands.
HOSTDEVICE inline bool operator==(const float16& a, const float16& b) {
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return lhs == rhs;
}
// Inequality is decided on the float values of both operands.
HOSTDEVICE inline bool operator!=(const float16& a, const float16& b) {
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return lhs != rhs;
}
// a < b, decided on the float values of both operands.
HOSTDEVICE inline bool operator<(const float16& a, const float16& b) {
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return lhs < rhs;
}
// a <= b, decided on the float values of both operands.
HOSTDEVICE inline bool operator<=(const float16& a, const float16& b) {
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return lhs <= rhs;
}
// a > b, decided on the float values of both operands.
HOSTDEVICE inline bool operator>(const float16& a, const float16& b) {
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return lhs > rhs;
}
// a >= b, decided on the float values of both operands.
HOSTDEVICE inline bool operator>=(const float16& a, const float16& b) {
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return lhs >= rhs;
}
#endif
} // namespace paddle
......@@ -22,6 +22,7 @@ if(WITH_GPU)
link_paddle_test(test_Tensor)
CUDA_ADD_EXECUTABLE(test_lazyAssign test_lazyAssign.cu)
link_paddle_test(test_lazyAssign)
nv_test(test_float16_gpu SRCS test_float16.cu)
else()
compile_cu_as_cpp(test_Tensor.cu)
add_unittest(test_Tensor test_Tensor.cu)
......@@ -33,3 +34,4 @@ add_simple_unittest(test_FPException)
add_simple_unittest(test_GpuProfiler)
add_simple_unittest(test_BaseMatrix)
add_simple_unittest(test_Matrix)
# CPU unit test target for float16, built from test_float16.cpp.
cc_test(test_float16 SRCS test_float16.cpp)
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/math/float16.h"
#include <gtest/gtest.h>
namespace paddle {
// Checks construction, assignment and conversion of float16 on the CPU by
// comparing the stored bit pattern `x` against known half-precision encodings.
TEST(float16, conversion_cpu) {
  // Inputs and expected bit patterns shared by the Eigen::half and float checks.
  struct FloatCase {
    float value;
    uint16_t bits;
  };
  const FloatCase float_cases[] = {
      {1.0f, 0x3c00},  {0.5f, 0x3800},     {0.33333f, 0x3555}, {0.0f, 0x0000},
      {-0.0f, 0x8000}, {65504.0f, 0x7bff}, {65536.0f, 0x7c00},
  };

  // Explicit conversion from Eigen::half
  for (const FloatCase& c : float_cases) {
    EXPECT_EQ(float16(Eigen::half(c.value)).x, c.bits);
  }

  // Conversion from float
  for (const FloatCase& c : float_cases) {
    EXPECT_EQ(float16(c.value).x, c.bits);
  }

  // Conversion from double
  struct DoubleCase {
    double value;
    uint16_t bits;
  };
  const DoubleCase double_cases[] = {
      {1.0, 0x3c00},  {0.5, 0x3800},     {0.33333, 0x3555}, {0.0, 0x0000},
      {-0.0, 0x8000}, {65504.0, 0x7bff}, {65536.0, 0x7c00},
  };
  for (const DoubleCase& c : double_cases) {
    EXPECT_EQ(float16(c.value).x, c.bits);
  }

  // Conversion from int
  struct IntCase {
    int value;
    uint16_t bits;
  };
  const IntCase int_cases[] = {
      {-1, 0xbc00}, {0, 0x0000}, {1, 0x3c00}, {2, 0x4000}, {3, 0x4200},
  };
  for (const IntCase& c : int_cases) {
    EXPECT_EQ(float16(c.value).x, c.bits);
  }

  // Conversion from bool
  EXPECT_EQ(float16(true).x, 0x3c00);
  EXPECT_EQ(float16(false).x, 0x0000);

  // Default constructor
  float16 default_constructed;
  EXPECT_EQ(default_constructed.x, 0x0000);

  // Assignment operator, one source type at a time
  float16 target;
  target = default_constructed;
  EXPECT_EQ(target.x, 0x0000);
  target = Eigen::half(1.0f);
  EXPECT_EQ(target.x, 0x3c00);
  target = 0.5f;
  EXPECT_EQ(target.x, 0x3800);
  target = 0.33333;
  EXPECT_EQ(target.x, 0x3555);
  target = -1;
  EXPECT_EQ(target.x, 0xbc00);
  target = true;
  EXPECT_EQ(target.x, 0x3c00);

  // Conversion operator
  EXPECT_EQ(Eigen::half(float16(1.0f)).x, 0x3c00);
  EXPECT_EQ(static_cast<float>(float16(0.5f)), 0.5f);
  EXPECT_NEAR(static_cast<double>(float16(0.33333)), 0.33333, 0.0001);
  EXPECT_EQ(static_cast<int>(float16(-1)), -1);
  EXPECT_EQ(static_cast<bool>(float16(true)), true);
}
// Checks the float16 arithmetic operators on the CPU. Results that are exactly
// representable are compared with EXPECT_EQ, the others with a tolerance.
TEST(float16, arithmetic_cpu) {
  // Addition
  EXPECT_EQ(static_cast<float>(float16(1) + float16(1)), 2);
  EXPECT_EQ(static_cast<float>(float16(5) + float16(-5)), 0);
  EXPECT_NEAR(static_cast<float>(float16(0.33333f) + float16(0.66667f)), 1.0f,
              0.001);

  // Subtraction
  EXPECT_EQ(static_cast<float>(float16(3) - float16(5)), -2);
  EXPECT_NEAR(static_cast<float>(float16(0.66667f) - float16(0.33333f)),
              0.33334f, 0.001);

  // Multiplication
  EXPECT_NEAR(static_cast<float>(float16(3.3f) * float16(2.0f)), 6.6f, 0.01);
  EXPECT_NEAR(static_cast<float>(float16(-2.1f) * float16(-3.0f)), 6.3f, 0.01);

  // Division
  EXPECT_NEAR(static_cast<float>(float16(2.0f) / float16(3.0f)), 0.66667f,
              0.001);
  EXPECT_EQ(static_cast<float>(float16(1.0f) / float16(2.0f)), 0.5f);

  // Unary minus
  EXPECT_EQ(static_cast<float>(-float16(512.0f)), -512.0f);
  EXPECT_EQ(static_cast<float>(-float16(-512.0f)), 512.0f);
}
// Checks the float16 comparison operators on the CPU, including the rule that
// positive and negative zero compare equal.
TEST(float16, comparison_cpu) {
  const float16 half_val(0.5f);
  const float16 one(1.0f);
  const float16 two(2.0f);
  const float16 neg_half(-0.5f);
  const float16 neg_one(-1.0f);
  const float16 neg_two(-2.0f);
  const float16 pos_zero(0.0f);
  const float16 neg_zero(-0.0f);

  // == and !=
  EXPECT_TRUE(one == float16(1.0f));
  EXPECT_FALSE(neg_one == neg_half);
  EXPECT_TRUE(one != half_val);
  EXPECT_FALSE(neg_one != float16(-1.0f));

  // < and <=
  EXPECT_TRUE(one < two);
  EXPECT_FALSE(neg_one < float16(-1.0f));
  EXPECT_TRUE(one <= float16(1.0f));

  // > and >=
  EXPECT_TRUE(two > one);
  EXPECT_FALSE(neg_two > float16(-2.0f));
  EXPECT_TRUE(two >= float16(2.0f));

  // Signed zeros are equal and neither is ordered before the other.
  EXPECT_TRUE(pos_zero == neg_zero);
  EXPECT_TRUE(pos_zero <= neg_zero);
  EXPECT_TRUE(pos_zero >= neg_zero);
  EXPECT_FALSE(pos_zero < neg_zero);
  EXPECT_FALSE(neg_zero < pos_zero);
  EXPECT_FALSE(pos_zero > neg_zero);
  EXPECT_FALSE(neg_zero > pos_zero);
}
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/math/float16.h"
#include <gtest/gtest.h>
#include "paddle/utils/Logging.h"
// Defines a __global__ kernel named `op_type` that applies the binary
// operator `sign` to element 0 of its two half inputs and writes the half
// result to element 0 of the output. Only index 0 is touched, so the kernel
// is meant to be launched with a single thread.
#define ARITHMETIC_KERNEL(op_type, sign)                                 \
  __global__ void op_type(const half* lhs, const half* rhs, half* dst) { \
    dst[0] = lhs[0] sign rhs[0];                                         \
  }
// Defines a __global__ kernel named `op_type` that applies the compound
// assignment operator `sign` to element 0 of `target`, using element 0 of
// `operand` as the right-hand side. Only index 0 is touched.
#define COMPOUND_KERNEL(op_type, sign)                      \
  __global__ void op_type(half* target, const half* operand) { \
    target[0] sign operand[0];                              \
  }
// Defines a __global__ kernel named `op_type` that compares element 0 of its
// two half inputs with the operator `sign` and writes the bool result to
// element 0 of the output. Only index 0 is touched.
#define COMPARISON_KERNEL(op_type, sign)                                 \
  __global__ void op_type(const half* lhs, const half* rhs, bool* dst) { \
    dst[0] = lhs[0] sign rhs[0];                                         \
  }
// Defines the host helper Test<op_type>(v_in1, v_in2, v_out). It converts both
// inputs to half, copies them to the device, runs the `op_type` kernel with a
// single thread, copies the result back and expects it to equal v_out.
// Every CUDA runtime call is checked and the launch is followed by
// cudaGetLastError(), so a CUDA failure is reported as such and not as a
// wrong numeric result. EXPECT_* is used so that the cleanup code always runs.
#define ARITHMETIC_KERNEL_LAUNCH(op_type)                                   \
  void Test##op_type(float v_in1, float v_in2, float v_out) {               \
    LOG(INFO) << "Test " << #op_type << " on GPU!";                         \
    half *in1, *in2, *out;                                                  \
    half *d_in1 = nullptr, *d_in2 = nullptr, *d_out = nullptr;              \
    int size = sizeof(half);                                                \
    EXPECT_EQ(cudaSuccess, cudaMalloc((void**)&d_in1, size));               \
    EXPECT_EQ(cudaSuccess, cudaMalloc((void**)&d_in2, size));               \
    EXPECT_EQ(cudaSuccess, cudaMalloc((void**)&d_out, size));               \
    in1 = (half*)malloc(size);                                              \
    in2 = (half*)malloc(size);                                              \
    out = (half*)malloc(size);                                              \
    in1[0] = half(float16(v_in1));                                          \
    in2[0] = half(float16(v_in2));                                          \
    EXPECT_EQ(cudaSuccess,                                                  \
              cudaMemcpy(d_in1, in1, size, cudaMemcpyHostToDevice));        \
    EXPECT_EQ(cudaSuccess,                                                  \
              cudaMemcpy(d_in2, in2, size, cudaMemcpyHostToDevice));        \
    op_type<<<1, 1>>>(d_in1, d_in2, d_out);                                 \
    EXPECT_EQ(cudaSuccess, cudaGetLastError());                             \
    EXPECT_EQ(cudaSuccess,                                                  \
              cudaMemcpy(out, d_out, size, cudaMemcpyDeviceToHost));        \
    EXPECT_EQ(float(float16(out[0])), v_out);                               \
    free(in1);                                                              \
    free(in2);                                                              \
    free(out);                                                              \
    EXPECT_EQ(cudaSuccess, cudaFree(d_in1));                                \
    EXPECT_EQ(cudaSuccess, cudaFree(d_in2));                                \
    EXPECT_EQ(cudaSuccess, cudaFree(d_out));                                \
  }
// Defines the host helper Test<op_type>(v_in1, v_in2, v_out) for compound
// assignment kernels. It copies both half inputs to the device, runs the
// `op_type` kernel with a single thread, copies the first operand back and
// expects it to equal v_out.
// Every CUDA runtime call is checked and the launch is followed by
// cudaGetLastError(). EXPECT_* is used so that the cleanup code always runs.
#define COMPOUND_KERNEL_LAUNCH(op_type)                                     \
  void Test##op_type(float v_in1, float v_in2, float v_out) {               \
    LOG(INFO) << "Test " << #op_type << " on GPU!";                         \
    half *in1, *in2;                                                        \
    half *d_in1 = nullptr, *d_in2 = nullptr;                                \
    int size = sizeof(half);                                                \
    EXPECT_EQ(cudaSuccess, cudaMalloc((void**)&d_in1, size));               \
    EXPECT_EQ(cudaSuccess, cudaMalloc((void**)&d_in2, size));               \
    in1 = (half*)malloc(size);                                              \
    in2 = (half*)malloc(size);                                              \
    in1[0] = half(float16(v_in1));                                          \
    in2[0] = half(float16(v_in2));                                          \
    EXPECT_EQ(cudaSuccess,                                                  \
              cudaMemcpy(d_in1, in1, size, cudaMemcpyHostToDevice));        \
    EXPECT_EQ(cudaSuccess,                                                  \
              cudaMemcpy(d_in2, in2, size, cudaMemcpyHostToDevice));        \
    op_type<<<1, 1>>>(d_in1, d_in2);                                        \
    EXPECT_EQ(cudaSuccess, cudaGetLastError());                             \
    EXPECT_EQ(cudaSuccess,                                                  \
              cudaMemcpy(in1, d_in1, size, cudaMemcpyDeviceToHost));        \
    EXPECT_EQ(float(float16(in1[0])), v_out);                               \
    free(in1);                                                              \
    free(in2);                                                              \
    EXPECT_EQ(cudaSuccess, cudaFree(d_in1));                                \
    EXPECT_EQ(cudaSuccess, cudaFree(d_in2));                                \
  }
// Defines the host helper Test<op_type>(v_in1, v_in2, v_out) for comparison
// kernels. It copies both half inputs to the device, runs the `op_type` kernel
// with a single thread, copies the bool result back and expects it to equal
// v_out.
// Every CUDA runtime call is checked and the launch is followed by
// cudaGetLastError(). The bool buffers are sized with sizeof(bool) instead of
// a hard-coded 1. EXPECT_* is used so that the cleanup code always runs.
#define COMPARISON_KERNEL_LAUNCH(op_type)                                   \
  void Test##op_type(float v_in1, float v_in2, bool v_out) {                \
    LOG(INFO) << "Test " << #op_type << " on GPU!";                         \
    half *in1, *in2;                                                        \
    half *d_in1 = nullptr, *d_in2 = nullptr;                                \
    bool *out, *d_out = nullptr;                                            \
    int size = sizeof(half);                                                \
    EXPECT_EQ(cudaSuccess, cudaMalloc((void**)&d_in1, size));               \
    EXPECT_EQ(cudaSuccess, cudaMalloc((void**)&d_in2, size));               \
    EXPECT_EQ(cudaSuccess, cudaMalloc((void**)&d_out, sizeof(bool)));       \
    in1 = (half*)malloc(size);                                              \
    in2 = (half*)malloc(size);                                              \
    out = (bool*)malloc(sizeof(bool));                                      \
    in1[0] = half(float16(v_in1));                                          \
    in2[0] = half(float16(v_in2));                                          \
    EXPECT_EQ(cudaSuccess,                                                  \
              cudaMemcpy(d_in1, in1, size, cudaMemcpyHostToDevice));        \
    EXPECT_EQ(cudaSuccess,                                                  \
              cudaMemcpy(d_in2, in2, size, cudaMemcpyHostToDevice));        \
    op_type<<<1, 1>>>(d_in1, d_in2, d_out);                                 \
    EXPECT_EQ(cudaSuccess, cudaGetLastError());                             \
    EXPECT_EQ(cudaSuccess, cudaMemcpy(out, d_out, sizeof(bool),             \
                                      cudaMemcpyDeviceToHost));             \
    EXPECT_EQ(out[0], v_out);                                               \
    free(in1);                                                              \
    free(in2);                                                              \
    free(out);                                                              \
    EXPECT_EQ(cudaSuccess, cudaFree(d_in1));                                \
    EXPECT_EQ(cudaSuccess, cudaFree(d_in2));                                \
    EXPECT_EQ(cudaSuccess, cudaFree(d_out));                                \
  }
#ifdef PADDLE_CUDA_FP16
namespace paddle {
#if CUDA_VERSION < 9000
// Instantiate the Add/Sub/Mul/Div kernels for the binary half operators.
ARITHMETIC_KERNEL(Add, +)
ARITHMETIC_KERNEL(Sub, -)
ARITHMETIC_KERNEL(Mul, *)
ARITHMETIC_KERNEL(Div, /)
// Instantiate the matching host helpers TestAdd/TestSub/TestMul/TestDiv.
ARITHMETIC_KERNEL_LAUNCH(Add)
ARITHMETIC_KERNEL_LAUNCH(Sub)
ARITHMETIC_KERNEL_LAUNCH(Mul)
ARITHMETIC_KERNEL_LAUNCH(Div)
// Unary minus kernel: negates element 0 of `in` in place. Only index 0 is
// touched, so the kernel is meant to be launched with a single thread.
__global__ void Neg(half* in) {
  const half negated = -in[0];
  in[0] = negated;
}
// Runs the Neg kernel on a single half value converted from v_in and expects
// the negated value, read back from the device, to equal v_out.
// Every CUDA runtime call is checked and the launch is followed by
// cudaGetLastError(), so a CUDA failure is reported as such and not as a
// wrong numeric result. EXPECT_* is used so that the cleanup code always runs.
void TestNeg(float v_in, float v_out) {
  LOG(INFO) << "Test Neg on GPU!";
  half* in;
  half* d_in = nullptr;
  int size = sizeof(half);
  EXPECT_EQ(cudaSuccess, cudaMalloc((void**)&d_in, size));
  in = (half*)malloc(size);
  in[0] = half(float16(v_in));
  EXPECT_EQ(cudaSuccess, cudaMemcpy(d_in, in, size, cudaMemcpyHostToDevice));
  Neg<<<1, 1>>>(d_in);
  // Launch errors are not returned by the launch itself.
  EXPECT_EQ(cudaSuccess, cudaGetLastError());
  EXPECT_EQ(cudaSuccess, cudaMemcpy(in, d_in, size, cudaMemcpyDeviceToHost));
  EXPECT_EQ(float(float16(in[0])), v_out);
  free(in);
  EXPECT_EQ(cudaSuccess, cudaFree(d_in));
}
// Instantiate the compound assignment kernels and their host helpers
// (TestAddAssign, TestSubAssign, TestMulAssign, TestDivAssign).
COMPOUND_KERNEL(AddAssign, +=)
COMPOUND_KERNEL(SubAssign, -=)
COMPOUND_KERNEL(MulAssign, *=)
COMPOUND_KERNEL(DivAssign, /=)
COMPOUND_KERNEL_LAUNCH(AddAssign)
COMPOUND_KERNEL_LAUNCH(SubAssign)
COMPOUND_KERNEL_LAUNCH(MulAssign)
COMPOUND_KERNEL_LAUNCH(DivAssign)
// Instantiate the comparison kernels and their host helpers
// (TestEqual, TestNotEqual, TestLess, TestLessEqual, TestGreater,
// TestGreaterEqual).
COMPARISON_KERNEL(Equal, ==)
COMPARISON_KERNEL(NotEqual, !=)
COMPARISON_KERNEL(Less, <)
COMPARISON_KERNEL(LessEqual, <=)
COMPARISON_KERNEL(Greater, >)
COMPARISON_KERNEL(GreaterEqual, >=)
COMPARISON_KERNEL_LAUNCH(Equal)
COMPARISON_KERNEL_LAUNCH(NotEqual)
COMPARISON_KERNEL_LAUNCH(Less)
COMPARISON_KERNEL_LAUNCH(LessEqual)
COMPARISON_KERNEL_LAUNCH(Greater)
COMPARISON_KERNEL_LAUNCH(GreaterEqual)
// Exercises the half arithmetic operators in device code. Each helper takes
// (lhs, rhs, expected), or (input, expected) for the unary minus.
TEST(float16, arithmetic_on_gpu) {
  TestAdd(1.0f, 2.0f, 3.0f);
  TestSub(2.0f, 1.0f, 1.0f);
  TestMul(2.0f, 3.0f, 6.0f);
  TestDiv(6.0f, 2.0f, 3.0f);
  TestNeg(1.0f, -1.0f);
}
// Exercises the half compound assignment operators in device code. Each
// helper takes (target, operand, expected value of target afterwards).
TEST(float16, compound_on_gpu) {
  TestAddAssign(1.0f, 2.0f, 3.0f);
  TestSubAssign(2.0f, 1.0f, 1.0f);
  TestMulAssign(2.0f, 3.0f, 6.0f);
  TestDivAssign(6.0f, 2.0f, 3.0f);
}
// Exercises the half comparison operators in device code, with one true and
// one false case per operator. Each helper takes (lhs, rhs, expected).
TEST(float16, comparision_on_gpu) {
  // ==
  TestEqual(1.0f, 1.0f, true);
  TestEqual(1.0f, 2.0f, false);
  // !=
  TestNotEqual(2.0f, 3.0f, true);
  TestNotEqual(2.0f, 2.0f, false);
  // <
  TestLess(3.0f, 4.0f, true);
  TestLess(3.0f, 3.0f, false);
  // <=
  TestLessEqual(3.0f, 3.0f, true);
  TestLessEqual(3.0f, 2.0f, false);
  // >
  TestGreater(4.0f, 3.0f, true);
  TestGreater(4.0f, 4.0f, false);
  // >=
  TestGreaterEqual(4.0f, 4.0f, true);
  TestGreaterEqual(4.0f, 5.0f, false);
}
#endif // CUDA_VERSION
// Checks that a float16 -> half -> float16 round trip preserves the bit
// pattern and that a half can be assigned to a float16.
TEST(float16, conversion_on_gpu) {
  // Explicit conversion to and from cuda half
  const float inputs[] = {1.0f,  0.5f,     0.33333f, 0.0f,
                          -0.0f, 65504.0f, 65536.0f};
  const uint16_t expected_bits[] = {0x3c00, 0x3800, 0x3555, 0x0000,
                                    0x8000, 0x7bff, 0x7c00};
  const size_t case_count = sizeof(inputs) / sizeof(inputs[0]);
  for (size_t i = 0; i < case_count; ++i) {
    EXPECT_EQ(float16(half(float16(inputs[i]))).x, expected_bits[i]);
  }

  // Assignment operator
  float16 assigned;
  assigned = half(float16(1.0f));
  EXPECT_EQ(assigned.x, 0x3c00);
}
} // namespace paddle
#endif // PADDLE_CUDA_FP16
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册