Commit 0f4bf1c9
Authored on Nov 19, 2017 by Kexin Zhao

Add GPU device code for testing

Parent: 734cac1a
Showing 3 changed files with 296 additions and 94 deletions:

- paddle/math/float16.h (+14, -57)
- paddle/math/tests/test_float16.cpp (+89, -13)
- paddle/math/tests/test_float16.cu (+193, -24)

paddle/math/float16.h
@@ -12,8 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-// need to define PADDLE_ARM_FP16
-
 #pragma once
 
 #include <cstdint>
@@ -21,14 +19,7 @@ limitations under the License. */
 #include <ostream>
 
-#include <cuda.h>
-
-#include "paddle/utils/Logging.h"
-
-#define USE_EIGEN
-
-#ifdef USE_EIGEN  // delete this #if macro
 #include "unsupported/Eigen/CXX11/Tensor"
-#endif
 
 #ifdef __GNUC__
 #define PADDLE_GNUC_VER (__GNUC__ * 10 + __GNUC_MINOR__)
@@ -52,27 +43,6 @@ limitations under the License. */
 #define PADDLE_HOSTDEVICE
 #endif  // __CUDACC__
 
-#define STR(x) #x
-#define XSTR(x) STR(x)
-
-#ifndef __CUDACC__
-#pragma message "__CUDACC__ not defined"
-#else
-#pragma message "__CUDACC__ defined"
-#endif
-
-#ifndef CUDA_VERSION
-#pragma message "CUDA_VERSION not defined"
-#else
-#pragma message "CUDA_VERSION defined: " XSTR(CUDA_VERSION)
-#endif
-
-#ifdef __CUDA_ARCH__
-#pragma message "The value of CUDA_ARCH: " XSTR(__CUDA_ARCH__)
-#else
-#pragma message "CUDA ARCH NOT DEFINED!"
-#endif
-
 #ifdef __arm__
 #define PADDLE_ARM_32
 #endif
@@ -113,7 +83,7 @@ namespace paddle {
 struct float16;
 
 namespace fp16_impl {
-// convert from float to half precision in round-to-nearest-even mode
+// Convert from float to half precision in round-to-nearest-even mode
 PADDLE_HOSTDEVICE inline float16 float_to_half_rn(float f);
 PADDLE_HOSTDEVICE inline float half_to_float(float16 h);
 }  // namespace fp16_impl
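
Round-to-nearest-even, which the declaration above commits to, breaks ties toward the candidate whose lowest mantissa bit is zero. A minimal illustrative check of that tie-breaking rule, written against the two declared entry points (a sketch, assuming a host translation unit that includes this header):

#include "paddle/math/float16.h"

#include <cassert>

int main() {
  using paddle::float16;
  using paddle::fp16_impl::float_to_half_rn;
  // Near 2048 the binary16 grid spacing is 2, so 2049.0f lies exactly halfway
  // between 2048 and 2050; round-to-nearest-even picks 2048 (even mantissa).
  assert(float_to_half_rn(2049.0f).x == float16(2048.0f).x);
  return 0;
}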

@@ -125,7 +95,7 @@ PADDLE_HOSTDEVICE inline float half_to_float(float16 h);
 struct PADDLE_ALIGN(2) float16 {
   uint16_t x;
 
-  PADDLE_HOSTDEVICE inline float16() {}
+  PADDLE_HOSTDEVICE inline float16() : x(0) {}
 
   PADDLE_HOSTDEVICE inline float16(const float16& h) : x(h.x) {}
@@ -139,21 +109,15 @@ struct PADDLE_ALIGN(2) float16 {
   }
 #endif  // PADDLE_CUDA_FP16
 
 #ifdef USE_EIGEN
   PADDLE_HOSTDEVICE inline float16(const Eigen::half& h) : x(h.x) {}
 #endif  // USE_EIGEN
 
 #if defined(PADDLE_NEON) && defined(PADDLE_ARM_FP16) && \
     (PADDLE_GNUC_VER >= 61 || PADDLE_CLANG_VER >= 34)
   // __fp16 is a native half precision data type for arm cpu,
   // float16_t is an alias for __fp16 in arm_fp16.h,
   // which is included in arm_neon.h.
-  // According to gcc, __fp16 can only be used as an argument to fp16
-  // intrinsic defined in arm_neon.h or as a storage type. It cannot
-  // be used as a formal function argument.
-  // TODO(kexinzhao): test it on RPI
-  PADDLE_HOSTDEVICE inline float16(const float16_t* h) {
-    x = *reinterpret_cast<uint16_t*>(h);
-  }
+  PADDLE_HOSTDEVICE inline float16(const float16_t& h) {
+    x = *reinterpret_cast<uint16_t*>(&h);
+  }
 #endif
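
As an aside, the NEON constructor above type-puns the referenced float16_t storage as uint16_t to pull out the raw bits. The same 16-bit move can be written with memcpy, which avoids strict-aliasing caveats; a small illustrative sketch (hypothetical helper, not part of the commit):

#include <cstdint>
#include <cstring>

// Copy the raw 16 bits of any half-precision storage type into a uint16_t.
template <typename T>
uint16_t bits_of(const T& h) {
  static_assert(sizeof(T) == sizeof(uint16_t), "expected a 16-bit type");
  uint16_t b;
  std::memcpy(&b, &h, sizeof(b));  // byte-wise copy, no aliasing assumptions
  return b;
}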

@@ -225,17 +189,15 @@ struct PADDLE_ALIGN(2) float16 {
   }
 #endif
 
 #ifdef USE_EIGEN
   PADDLE_HOSTDEVICE inline float16& operator=(const Eigen::half& rhs) {
     x = rhs.x;
     return *this;
   }
 #endif  // USE_EIGEN
 
 #if defined(PADDLE_NEON) && defined(PADDLE_ARM_FP16) && \
     (PADDLE_GNUC_VER >= 61 || PADDLE_CLANG_VER >= 34)
-  PADDLE_HOSTDEVICE inline float16& operator=(const float16_t* rhs) {
-    x = *reinterpret_cast<uint16_t*>(rhs);
+  PADDLE_HOSTDEVICE inline float16& operator=(const float16_t& rhs) {
+    x = *reinterpret_cast<uint16_t*>(&rhs);
     return *this;
   }
 #endif

@@ -319,17 +281,14 @@ struct PADDLE_ALIGN(2) float16 {
   }
 #endif  // PADDLE_CUDA_FP16
 
 #ifdef USE_EIGEN
   PADDLE_HOSTDEVICE inline operator Eigen::half() const {
     Eigen::half h;
     h.x = x;
     return h;
   }
 #endif  // USE_EIGEN
 
 #if defined(PADDLE_NEON) && defined(PADDLE_ARM_FP16) && \
     (PADDLE_GNUC_VER >= 61 || PADDLE_CLANG_VER >= 34)
-  // check whether it works or not
   PADDLE_HOSTDEVICE inline operator float16_t() const {
     float16 h = *this;
     return *reinterpret_cast<float16_t*>(&h);

@@ -381,10 +340,9 @@ struct PADDLE_ALIGN(2) float16 {
   }
 };
 
-// arithmetic operators
+// Arithmetic operators
 #if defined(PADDLE_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
 __device__ inline float16 operator+(const float16& a, const float16& b) {
-  printf("GPU Intrinsic used!");
   return float16(__hadd(half(a), half(b)));
 }
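
__hadd is CUDA's native half-precision add intrinsic, available only in device code on compute capability 5.3 and newer, which is why the intrinsic-backed operators sit behind __CUDA_ARCH__ >= 530. A self-contained sketch of the same dispatch pattern (illustrative; hadd_demo is a hypothetical kernel, not part of the commit):

#include <cuda_fp16.h>

__global__ void hadd_demo(const __half* a, const __half* b, __half* out) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  out[0] = __hadd(a[0], b[0]);  // native fp16 addition
#else
  // Older architectures: widen to float, add, narrow back to half.
  out[0] = __float2half(__half2float(a[0]) + __half2float(b[0]));
#endif
}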

@@ -452,7 +410,7 @@ __device__ inline bool operator>=(const float16& a, const float16& b) {
 }
 
 // On ARMv8.2-A CPU
-#elif defined(PADDLE_NEON_64) && defined(PADDLE_ARM_FP16) && \
+#elif defined(PADDLE_NEON) && defined(PADDLE_ARM_FP16) && \
     (PADDLE_GNUC_VER >= 71 || PADDLE_CLANG_VER >= 39)
 __host__ inline float16 operator+(const float16& a, const float16& b) {
   return float16(vaddh_f16(float16_t(a), float16_t(b)));

@@ -502,7 +460,7 @@ __host__ inline bool operator!=(const float16& a, const float16& b) {
   return !(a == b);
 }
 
-// compare only available in NEON_64
+#ifdef PADDLE_NEON_64
 __host__ inline bool operator<(const float16& a, const float16& b) {
   return static_cast<bool>(vclth_f16(float16_t(a), float16_t(b)));
 }

@@ -518,10 +476,10 @@ __host__ inline bool operator>(const float16& a, const float16& b) {
 __host__ inline bool operator>=(const float16& a, const float16& b) {
   return static_cast<bool>(vcgeh_f16(float16_t(a), float16_t(b)));
 }
+#endif  // PADDLE_NEON_64
 
-#else  // software emulation on other cpu
+#else  // Software emulation on other cpu
 PADDLE_HOSTDEVICE inline float16 operator+(const float16& a, const float16& b) {
-  LOG(INFO) << "CPU emulation used";
   return float16(float(a) + float(b));
 }
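
Every operator on the emulation path has the same shape: widen both halves to float, apply the float operation, and round the result back. One consequence, shown in the sketch below (assuming a build where this fallback branch is the one compiled), is that the emulated sum is bit-identical to an explicit widen-add-narrow:

#include "paddle/math/float16.h"

#include <cassert>

int main() {
  using paddle::float16;
  float16 a(0.33333f), b(0.66667f);
  // operator+ is defined as float16(float(a) + float(b)) on this path,
  // so the two expressions agree bit-for-bit by construction.
  assert((a + b).x == float16(float(a) + float(b)).x);
  return 0;
}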

@@ -624,7 +582,7 @@ PADDLE_HOSTDEVICE inline float16 float_to_half_rn(float f) {
   half tmp = __float2half(f);
   return *reinterpret_cast<float16*>(&tmp);
 
-#elif defined(PADDLE_NEON_64)  // test on RPI
+#elif defined(PADDLE_NEON_64)
   float16 res;
   asm volatile(
       "ld1 {v0.s}[0], [%[float_ptr]]\n"

@@ -638,7 +596,7 @@ PADDLE_HOSTDEVICE inline float16 float_to_half_rn(float f) {
       "memory", "v0");
   return res;
 
-#elif defined(PADDLE_NEON_32)  // test on RPI
+#elif defined(PADDLE_NEON_32)
   float16 res;
   asm volatile(
       "vld1.32 {d0[0]}, [%[float_ptr]]\n"

@@ -689,7 +647,7 @@ PADDLE_HOSTDEVICE inline float half_to_float(float16 h) {
   float res;
   asm volatile(
       "ld1 {v0.h}[0], [%[half_ptr]]\n"
-      "FCVT s0, h0\n"
+      "fcvt s0, h0\n"
       "st1 {v0.s}[0], [%[float_ptr]]\n"
      :  // outputs
      :  // inputs

@@ -739,5 +697,4 @@ PADDLE_HOSTDEVICE inline float half_to_float(float16 h) {
 }
 
 }  // namespace fp16_impl
-
 }  // namespace paddle
paddle/math/tests/test_float16.cpp
@@ -9,22 +9,21 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include <gtest/gtest.h>
 #include "paddle/math/float16.h"
+
+#include <gtest/gtest.h>
 
 namespace paddle {
 
 TEST(float16, conversion_cpu) {
-  LOG(INFO) << "cpu test started!";
-
-  // Conversion to and from Eigen::half
-  EXPECT_EQ(float16(Eigen::half(float16(1.0f))).x, 0x3c00);
-  EXPECT_EQ(float16(Eigen::half(float16(0.5f))).x, 0x3800);
-  EXPECT_EQ(float16(Eigen::half(float16(0.33333f))).x, 0x3555);
-  EXPECT_EQ(float16(Eigen::half(float16(0.0f))).x, 0x0000);
-  EXPECT_EQ(float16(Eigen::half(float16(-0.0f))).x, 0x8000);
-  EXPECT_EQ(float16(Eigen::half(float16(65504.0f))).x, 0x7bff);
-  EXPECT_EQ(float16(Eigen::half(float16(65536.0f))).x, 0x7c00);
+  // Explicit conversion from Eigen::half
+  EXPECT_EQ(float16(Eigen::half(1.0f)).x, 0x3c00);
+  EXPECT_EQ(float16(Eigen::half(0.5f)).x, 0x3800);
+  EXPECT_EQ(float16(Eigen::half(0.33333f)).x, 0x3555);
+  EXPECT_EQ(float16(Eigen::half(0.0f)).x, 0x0000);
+  EXPECT_EQ(float16(Eigen::half(-0.0f)).x, 0x8000);
+  EXPECT_EQ(float16(Eigen::half(65504.0f)).x, 0x7bff);
+  EXPECT_EQ(float16(Eigen::half(65536.0f)).x, 0x7c00);
 
   // Conversion from float
   EXPECT_EQ(float16(1.0f).x, 0x3c00);

@@ -36,14 +35,91 @@ TEST(float16, conversion_cpu) {
   EXPECT_EQ(float16(65536.0f).x, 0x7c00);
 
   // Conversion from double
+  EXPECT_EQ(float16(1.0).x, 0x3c00);
+  EXPECT_EQ(float16(0.5).x, 0x3800);
+  EXPECT_EQ(float16(0.33333).x, 0x3555);
+  EXPECT_EQ(float16(0.0).x, 0x0000);
+  EXPECT_EQ(float16(-0.0).x, 0x8000);
+  EXPECT_EQ(float16(65504.0).x, 0x7bff);
+  EXPECT_EQ(float16(65536.0).x, 0x7c00);
 
   // Conversion from int
+  EXPECT_EQ(float16(-1).x, 0xbc00);
+  EXPECT_EQ(float16(0).x, 0x0000);
+  EXPECT_EQ(float16(1).x, 0x3c00);
+  EXPECT_EQ(float16(2).x, 0x4000);
+  EXPECT_EQ(float16(3).x, 0x4200);
 
   // Conversion from bool
+  EXPECT_EQ(float16(true).x, 0x3c00);
+  EXPECT_EQ(float16(false).x, 0x0000);
+
+  // Implicit conversion to and from Eigen::half
+  Eigen::half tmp = float16(1.0f);
+  float16 v_conv = tmp;
+  EXPECT_EQ(tmp.x, 0x3c00);
+  EXPECT_EQ(v_conv.x, 0x3c00);
+
+  // Default constructor
+  float16 v_def;
+  EXPECT_EQ(v_def.x, 0x0000);
+
+  // Assignment operator
+  float16 v_assign;
+  v_assign = v_def;
+  EXPECT_EQ(v_assign.x, 0x0000);
+  v_assign = Eigen::half(1.0f);
+  EXPECT_EQ(v_assign.x, 0x3c00);
+  v_assign = 0.5f;
+  EXPECT_EQ(v_assign.x, 0x3800);
+  v_assign = 0.33333;
+  EXPECT_EQ(v_assign.x, 0x3555);
+  v_assign = -1;
+  EXPECT_EQ(v_assign.x, 0xbc00);
+  v_assign = true;
+  EXPECT_EQ(v_assign.x, 0x3c00);
+
+  // Conversion operator
+  EXPECT_EQ(Eigen::half(float16(1.0f)).x, 0x3c00);
+  EXPECT_EQ(float(float16(0.5f)), 0.5f);
+  EXPECT_NEAR(double(float16(0.33333)), 0.33333, 0.0001);
+  EXPECT_EQ(int(float16(-1)), -1);
+  EXPECT_EQ(bool(float16(true)), true);
 }
 
-TEST(float16, arithmetic_cpu) { EXPECT_EQ(float(float16(2) + float16(2)), 4); }
+TEST(float16, arithmetic_cpu) {
+  EXPECT_EQ(float(float16(1) + float16(1)), 2);
+  EXPECT_EQ(float(float16(5) + float16(-5)), 0);
+  EXPECT_NEAR(float(float16(0.33333f) + float16(0.66667f)), 1.0f, 0.001);
+  EXPECT_EQ(float(float16(3) - float16(5)), -2);
+  EXPECT_NEAR(float(float16(0.66667f) - float16(0.33333f)), 0.33334f, 0.001);
+  EXPECT_NEAR(float(float16(3.3f) * float16(2.0f)), 6.6f, 0.01);
+  EXPECT_NEAR(float(float16(-2.1f) * float16(-3.0f)), 6.3f, 0.01);
+  EXPECT_NEAR(float(float16(2.0f) / float16(3.0f)), 0.66667f, 0.001);
+  EXPECT_EQ(float(float16(1.0f) / float16(2.0f)), 0.5f);
+  EXPECT_EQ(float(-float16(512.0f)), -512.0f);
+  EXPECT_EQ(float(-float16(-512.0f)), 512.0f);
+}
 
-TEST(float16, comparison_cpu) { EXPECT_TRUE(float16(1.0f) > float16(0.5f)); }
+TEST(float16, comparison_cpu) {
+  EXPECT_TRUE(float16(1.0f) == float16(1.0f));
+  EXPECT_FALSE(float16(-1.0f) == float16(-0.5f));
+  EXPECT_TRUE(float16(1.0f) != float16(0.5f));
+  EXPECT_FALSE(float16(-1.0f) != float16(-1.0f));
+  EXPECT_TRUE(float16(1.0f) < float16(2.0f));
+  EXPECT_FALSE(float16(-1.0f) < float16(-1.0f));
+  EXPECT_TRUE(float16(1.0f) <= float16(1.0f));
+  EXPECT_TRUE(float16(2.0f) > float16(1.0f));
+  EXPECT_FALSE(float16(-2.0f) > float16(-2.0f));
+  EXPECT_TRUE(float16(2.0f) >= float16(2.0f));
+
+  EXPECT_TRUE(float16(0.0f) == float16(-0.0f));
+  EXPECT_TRUE(float16(0.0f) <= float16(-0.0f));
+  EXPECT_TRUE(float16(0.0f) >= float16(-0.0f));
+  EXPECT_FALSE(float16(0.0f) < float16(-0.0f));
+  EXPECT_FALSE(float16(-0.0f) < float16(0.0f));
+  EXPECT_FALSE(float16(0.0f) > float16(-0.0f));
+  EXPECT_FALSE(float16(-0.0f) > float16(0.0f));
 }
 
 }  // namespace paddle
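
The expected constants in these tests are raw IEEE 754 binary16 bit patterns: 1 sign bit, 5 exponent bits (bias 15), and 10 mantissa bits, so 0x3c00 is 1.0, 0xbc00 is -1.0, 0x7bff is the largest finite value 65504.0, and 0x7c00 is +infinity (which is why 65536.0f maps to it). A small decoding sketch, for illustration only:

#include <cstdint>
#include <cstdio>

// Split a binary16 bit pattern into its sign/exponent/mantissa fields.
void print_half_fields(uint16_t h) {
  unsigned sign = (h >> 15) & 0x1;  // 1 sign bit
  unsigned exp = (h >> 10) & 0x1f;  // 5 exponent bits, bias 15
  unsigned mant = h & 0x3ff;        // 10 mantissa bits
  std::printf("0x%04x -> sign=%u exp=%u mant=0x%03x\n", h, sign, exp, mant);
}

int main() {
  print_half_fields(0x3c00);  // 1.0  (exp 15 -> unbiased 0, mant 0)
  print_half_fields(0x3800);  // 0.5  (exp 14 -> unbiased -1)
  print_half_fields(0x7bff);  // 65504.0, largest finite half
  print_half_fields(0x7c00);  // +inf: 65536.0f overflows to this
  return 0;
}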
paddle/math/tests/test_float16.cu
@@ -9,42 +9,211 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include <gtest/gtest.h>
 #include "paddle/math/float16.h"
 
-namespace paddle {
+#include <gtest/gtest.h>
+#include "paddle/utils/Logging.h"
+
+#define ARITHMETIC_KERNEL(op_type, sign) \
+  __global__ void op_type( \
+      const float16* in1, const float16* in2, float16* out) { \
+    out[0] = in1[0] sign in2[0]; \
+  }
+
+#define COMPOUND_KERNEL(op_type, sign) \
+  __global__ void op_type(float16* in1, const float16* in2) { \
+    in1[0] sign in2[0]; \
+  }
+
+#define COMPARISON_KERNEL(op_type, sign) \
+  __global__ void op_type(const float16* in1, const float16* in2, bool* out) { \
+    out[0] = in1[0] sign in2[0]; \
+  }
+
+#define ARITHMETIC_KERNEL_LAUNCH(op_type) \
+  void Test##op_type(float v_in1, float v_in2, float v_out) { \
+    LOG(INFO) << "Test " << #op_type << " on GPU!"; \
+    float16 *in1, *in2, *out; \
+    float16 *d_in1, *d_in2, *d_out; \
+    int size = sizeof(float16); \
+    cudaMalloc((void**)&d_in1, size); \
+    cudaMalloc((void**)&d_in2, size); \
+    cudaMalloc((void**)&d_out, size); \
+    in1 = (float16*)malloc(size); \
+    in2 = (float16*)malloc(size); \
+    out = (float16*)malloc(size); \
+    in1[0] = float16(v_in1); \
+    in2[0] = float16(v_in2); \
+    cudaMemcpy(d_in1, in1, size, cudaMemcpyHostToDevice); \
+    cudaMemcpy(d_in2, in2, size, cudaMemcpyHostToDevice); \
+    op_type<<<1, 1>>>(d_in1, d_in2, d_out); \
+    cudaMemcpy(out, d_out, size, cudaMemcpyDeviceToHost); \
+    EXPECT_EQ(float(out[0]), v_out); \
+    free(in1); \
+    free(in2); \
+    free(out); \
+    cudaFree(d_in1); \
+    cudaFree(d_in2); \
+    cudaFree(d_out); \
+  }
+
+#define COMPOUND_KERNEL_LAUNCH(op_type) \
+  void Test##op_type(float v_in1, float v_in2, float v_out) { \
+    LOG(INFO) << "Test " << #op_type << " on GPU!"; \
+    float16 *in1, *in2; \
+    float16 *d_in1, *d_in2; \
+    int size = sizeof(float16); \
+    cudaMalloc((void**)&d_in1, size); \
+    cudaMalloc((void**)&d_in2, size); \
+    in1 = (float16*)malloc(size); \
+    in2 = (float16*)malloc(size); \
+    in1[0] = float16(v_in1); \
+    in2[0] = float16(v_in2); \
+    cudaMemcpy(d_in1, in1, size, cudaMemcpyHostToDevice); \
+    cudaMemcpy(d_in2, in2, size, cudaMemcpyHostToDevice); \
+    op_type<<<1, 1>>>(d_in1, d_in2); \
+    cudaMemcpy(in1, d_in1, size, cudaMemcpyDeviceToHost); \
+    EXPECT_EQ(float(in1[0]), v_out); \
+    free(in1); \
+    free(in2); \
+    cudaFree(d_in1); \
+    cudaFree(d_in2); \
+  }
+
+#define COMPARISON_KERNEL_LAUNCH(op_type) \
+  void Test##op_type(float v_in1, float v_in2, bool v_out) { \
+    LOG(INFO) << "Test " << #op_type << " on GPU!"; \
+    float16 *in1, *in2; \
+    float16 *d_in1, *d_in2; \
+    bool *out, *d_out; \
+    int size = sizeof(float16); \
+    cudaMalloc((void**)&d_in1, size); \
+    cudaMalloc((void**)&d_in2, size); \
+    cudaMalloc((void**)&d_out, 1); \
+    in1 = (float16*)malloc(size); \
+    in2 = (float16*)malloc(size); \
+    out = (bool*)malloc(1); \
+    in1[0] = float16(v_in1); \
+    in2[0] = float16(v_in2); \
+    cudaMemcpy(d_in1, in1, size, cudaMemcpyHostToDevice); \
+    cudaMemcpy(d_in2, in2, size, cudaMemcpyHostToDevice); \
+    op_type<<<1, 1>>>(d_in1, d_in2, d_out); \
+    cudaMemcpy(out, d_out, 1, cudaMemcpyDeviceToHost); \
+    EXPECT_EQ(out[0], v_out); \
+    free(in1); \
+    free(in2); \
+    free(out); \
+    cudaFree(d_in1); \
+    cudaFree(d_in2); \
+    cudaFree(d_out); \
+  }
+
 #ifdef PADDLE_CUDA_FP16
-TEST(float16, conversion_gpu) {
-  LOG(INFO) << "GPU tests started";
+namespace paddle {
 
-  // Conversion to and from cuda half
-  float16 v1 = half(float16(1.0f));
-  EXPECT_EQ(v1.x, 0x3c00);
+ARITHMETIC_KERNEL(Add, +)
+ARITHMETIC_KERNEL(Sub, -)
+ARITHMETIC_KERNEL(Mul, *)
+ARITHMETIC_KERNEL(Div, /)
 
-  // Conversion to and from Eigen::half
-  float16 v2 = Eigen::half(float16(0.5f));
-  EXPECT_EQ(v2.x, 0x3800);
+ARITHMETIC_KERNEL_LAUNCH(Add)
+ARITHMETIC_KERNEL_LAUNCH(Sub)
+ARITHMETIC_KERNEL_LAUNCH(Mul)
+ARITHMETIC_KERNEL_LAUNCH(Div)
 
-  // Conversion from float
-  EXPECT_EQ(float16(1.0f).x, 0x3c00);
-  EXPECT_EQ(float16(0.5f).x, 0x3800);
-  EXPECT_EQ(float16(0.33333f).x, 0x3555);
-  EXPECT_EQ(float16(0.0f).x, 0x0000);
-  EXPECT_EQ(float16(-0.0f).x, 0x8000);
-  EXPECT_EQ(float16(65504.0f).x, 0x7bff);
-  EXPECT_EQ(float16(65536.0f).x, 0x7c00);
+// Negative sign kernel
+__global__ void Neg(float16* in) { in[0] = -in[0]; }
 
-  // Conversion from double
+void TestNeg(float v_in, float v_out) {
+  LOG(INFO) << "Test Neg on GPU!";
+  float16 *in, *d_in;
+  int size = sizeof(float16);
+  cudaMalloc((void**)&d_in, size);
+  in = (float16*)malloc(size);
+  in[0] = float16(v_in);
+  cudaMemcpy(d_in, in, size, cudaMemcpyHostToDevice);
+  Neg<<<1, 1>>>(d_in);
+  cudaMemcpy(in, d_in, size, cudaMemcpyDeviceToHost);
+  EXPECT_EQ(float(in[0]), v_out);
+  free(in);
+  cudaFree(d_in);
+}
 
-  // Conversion from int
+COMPOUND_KERNEL(AddAssign, +=)
+COMPOUND_KERNEL(SubAssign, -=)
+COMPOUND_KERNEL(MulAssign, *=)
+COMPOUND_KERNEL(DivAssign, /=)
 
-  // Conversion from bool
+COMPOUND_KERNEL_LAUNCH(AddAssign)
+COMPOUND_KERNEL_LAUNCH(SubAssign)
+COMPOUND_KERNEL_LAUNCH(MulAssign)
+COMPOUND_KERNEL_LAUNCH(DivAssign)
+
+COMPARISON_KERNEL(Equal, ==)
+COMPARISON_KERNEL(NotEqual, !=)
+COMPARISON_KERNEL(Less, <)
+COMPARISON_KERNEL(LessEqual, <=)
+COMPARISON_KERNEL(Greater, >)
+COMPARISON_KERNEL(GreaterEqual, >=)
+
+COMPARISON_KERNEL_LAUNCH(Equal)
+COMPARISON_KERNEL_LAUNCH(NotEqual)
+COMPARISON_KERNEL_LAUNCH(Less)
+COMPARISON_KERNEL_LAUNCH(LessEqual)
+COMPARISON_KERNEL_LAUNCH(Greater)
+COMPARISON_KERNEL_LAUNCH(GreaterEqual)
+
+TEST(float16, arithmetic_on_gpu) {
+  TestAdd(1, 2, 3);
+  TestSub(2, 1, 1);
+  TestMul(2, 3, 6);
+  TestDiv(6, 2, 3);
+  TestNeg(1, -1);
 }
-#endif
 
-TEST(float16, arithmetic_gpu) { EXPECT_EQ(float(float16(2) + float16(2)), 4); }
+TEST(float16, compound_on_gpu) {
+  TestAddAssign(1, 2, 3);
+  TestSubAssign(2, 1, 1);
+  TestMulAssign(2, 3, 6);
+  TestDivAssign(6, 2, 3);
+}
 
-TEST(float16, comparison_gpu) { EXPECT_TRUE(float16(1.0f) > float16(0.5f)); }
+TEST(float16, comparision_on_gpu) {
+  TestEqual(1, 1, true);
+  TestEqual(1, 2, false);
+  TestNotEqual(2, 3, true);
+  TestNotEqual(2, 2, false);
+  TestLess(3, 4, true);
+  TestLess(3, 3, false);
+  TestLessEqual(3, 3, true);
+  TestLessEqual(3, 2, false);
+  TestGreater(4, 3, true);
+  TestGreater(4, 4, false);
+  TestGreaterEqual(4, 4, true);
+  TestGreaterEqual(4, 5, false);
+}
+
+TEST(float16, conversion_on_gpu) {
+  // Explicit conversion to and from cuda half
+  EXPECT_EQ(float16(half(float16(1.0f))).x, 0x3c00);
+  EXPECT_EQ(float16(half(float16(0.5f))).x, 0x3800);
+  EXPECT_EQ(float16(half(float16(0.33333f))).x, 0x3555);
+  EXPECT_EQ(float16(half(float16(0.0f))).x, 0x0000);
+  EXPECT_EQ(float16(half(float16(-0.0f))).x, 0x8000);
+  EXPECT_EQ(float16(half(float16(65504.0f))).x, 0x7bff);
+  EXPECT_EQ(float16(half(float16(65536.0f))).x, 0x7c00);
+
+  // Implicit conversion to and from cuda half
+  half tmp = float16(1.0f);
+  float16 val = tmp;
+  EXPECT_EQ(val.x, 0x3c00);
+
+  // Assignment operator
+  float16 v_assign;
+  v_assign = tmp;
+  EXPECT_EQ(v_assign.x, 0x3c00);
+}
 
 }  // namespace paddle
+#endif  // PADDLE_CUDA_FP16
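
For readers tracing the macros above, this is what ARITHMETIC_KERNEL(Add, +) and ARITHMETIC_KERNEL_LAUNCH(Add) expand to after preprocessing (reformatted; a sketch of the generated code, not an addition to the commit):

// The kernel: one thread adds one float16 element.
__global__ void Add(const float16* in1, const float16* in2, float16* out) {
  out[0] = in1[0] + in2[0];
}

// The host-side driver: round-trips a single value through the GPU.
void TestAdd(float v_in1, float v_in2, float v_out) {
  LOG(INFO) << "Test Add on GPU!";
  float16 *in1, *in2, *out;
  float16 *d_in1, *d_in2, *d_out;
  int size = sizeof(float16);
  cudaMalloc((void**)&d_in1, size);
  cudaMalloc((void**)&d_in2, size);
  cudaMalloc((void**)&d_out, size);
  in1 = (float16*)malloc(size);
  in2 = (float16*)malloc(size);
  out = (float16*)malloc(size);
  in1[0] = float16(v_in1);
  in2[0] = float16(v_in2);
  cudaMemcpy(d_in1, in1, size, cudaMemcpyHostToDevice);
  cudaMemcpy(d_in2, in2, size, cudaMemcpyHostToDevice);
  Add<<<1, 1>>>(d_in1, d_in2, d_out);  // single-thread launch
  cudaMemcpy(out, d_out, size, cudaMemcpyDeviceToHost);
  EXPECT_EQ(float(out[0]), v_out);
  free(in1);
  free(in2);
  free(out);
  cudaFree(d_in1);
  cudaFree(d_in2);
  cudaFree(d_out);
}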