Commit e877cdb8
Authored Nov 13, 2017 by Kexin Zhao

add float16 arithmetic on arm cpu

Parent: 9d8b3059
Showing 1 changed file, with 389 additions and 90 deletions:
paddle/math/float16.h (+389, -90)
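Before the diff itself, here is a minimal usage sketch of what the new arithmetic enables. This is an editorial illustration, not part of the commit; it assumes the header is on the include path. On an ARMv8.2-A CPU with native fp16, a + b below resolves to vaddh_f16; on a CUDA device to __hadd; elsewhere to the software-emulated operators this diff adds.

#include "paddle/math/float16.h"
#include <cstdio>

int main() {
  paddle::float16 a(1.0f);
  paddle::float16 b(2.0f);
  paddle::float16 c = a + b;  // backend chosen at compile time
  std::printf("%f\n", static_cast<float>(c));  // prints 3.000000
  return 0;
}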
...
...
@@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
// need to define PADDLE_ARM_FP16
#pragma once
#include <cstdint>
...
...
@@ -24,6 +26,18 @@ limitations under the License. */
#include "Eigen/src/Core/arch/CUDA/Half.h"
#endif
#ifdef __GNUC__
#define PADDLE_GNUC_VER (__GNUC__ * 10 + __GNUC_MINOR__)
#else
#define PADDLE_GNUC_VER 0
#endif // __GNUC__
#ifdef __clang__
#define PADDLE_CLANG_VER (__clang_major__ * 10 + __clang_minor__)
#else
#define PADDLE_CLANG_VER 0
#endif // __clang__
#ifdef __CUDACC__
#define PADDLE_HOSTDEVICE __host__ __device__
#if CUDA_VERSION >= 7050
...
...
@@ -48,6 +62,7 @@ limitations under the License. */
#if defined(__ARM_NEON) || defined(__ARM_NEON__)
#define PADDLE_NEON
#include <arm_neon.h>
#endif
#if defined(PADDLE_NEON) && defined(PADDLE_ARM_32)
...
...
@@ -58,26 +73,16 @@ limitations under the License. */
#define PADDLE_NEON_64
#endif
#ifdef PADDLE_ARM
#ifdef __F16C__
#undef __F16C__
#endif  // __F16C__
#else
#include <immintrin.h>
#endif  // PADDLE_ARM

#define PADDLE_ALIGN(x) __attribute__((aligned(x)))

// https://github.com/pytorch/pytorch/blob/master/torch/lib/ATen/Half.h
template <typename To, typename From>
To convert(From f) {
  return static_cast<To>(f);
}

namespace paddle {

struct float16;
...
...
@@ -86,13 +91,12 @@ namespace fp16_impl {
// convert from float to half precision in round-to-nearest-even mode
PADDLE_HOSTDEVICE inline float16 float_to_half_rn(float f);
PADDLE_HOSTDEVICE inline float half_to_float(float16 h);
PADDLE_HOSTDEVICE inline float16 uint16_to_half(uint16_t x);
}  // namespace fp16_impl

// Use PADDLE_ALIGN(2) to ensure that each float16 will be allocated
// and aligned at least on a 2-byte boundary, which leads to efficient
// memory access of float16 struct and also makes float16 compatible
// with CUDA half, ARM float16_t, and Eigen::half data types.
struct PADDLE_ALIGN(2) float16 {
  uint16_t x;
...
...
@@ -103,7 +107,7 @@ struct PADDLE_ALIGN(2) float16 {
  PADDLE_HOSTDEVICE inline float16(const float16& h) : x(h.x) {}

#ifdef PADDLE_CUDA_FP16
  PADDLE_HOSTDEVICE inline float16(const half& h) {
#if CUDA_VERSION >= 9000
    x = reinterpret_cast<__half_raw*>(&h)->x;
#else
...
...
@@ -111,40 +115,72 @@ struct PADDLE_ALIGN(2) float16 {
#endif // CUDA_VERSION >= 9000
}
#endif // PADDLE_CUDA_FP16
  /*
  #ifdef PADDLE_CUDA_FP16
  #if CUDA_VERSION < 9000
  PADDLE_HOSTDEVICE inline float16(const half& h) : x(h.x) {}
  #else
  PADDLE_HOSTDEVICE inline float16(const half& h)
      : x(*reinterpret_cast<uint16_t*>(&h)) {}
  #endif  // CUDA_VERSION < 9000
  #endif  // PADDLE_CUDA_FP16
  */
#ifdef USE_EIGEN
  PADDLE_HOSTDEVICE inline float16(const Eigen::half& h) : x(h.x) {}
#endif  // USE_EIGEN
#ifdef PADDLE_NEON
  // __fp16 is a native half precision data type for arm cpu,
  // float16_t is an alias for __fp16 in arm_fp16.h,
  // which is included in arm_neon.h.
  // According to gcc, __fp16 can only be used as an argument to fp16
  // intrinsic defined in arm_neon.h or as a storage type. It cannot
  // be used as a formal function argument.
  // TODO(kexinzhao): test it on RPI
  PADDLE_HOSTDEVICE inline float16(const float16_t* h) {
    x = *reinterpret_cast<uint16_t*>(h);
  }
#endif
  PADDLE_HOSTDEVICE inline explicit float16(bool b) : x(b ? 0x3c00 : 0) {}
  PADDLE_HOSTDEVICE inline explicit float16(int8_t val) {
    float16 res = fp16_impl::float_to_half_rn(static_cast<float>(val));
    x = res.x;
  }

  PADDLE_HOSTDEVICE inline explicit float16(uint8_t val) {
    float16 res = fp16_impl::float_to_half_rn(static_cast<float>(val));
    x = res.x;
  }

  PADDLE_HOSTDEVICE inline explicit float16(int16_t val) {
    float16 res = fp16_impl::float_to_half_rn(static_cast<float>(val));
    x = res.x;
  }

  PADDLE_HOSTDEVICE inline explicit float16(uint16_t val) {
    float16 res = fp16_impl::float_to_half_rn(static_cast<float>(val));
    x = res.x;
  }

  PADDLE_HOSTDEVICE inline explicit float16(int32_t val) {
    float16 res = fp16_impl::float_to_half_rn(static_cast<float>(val));
    x = res.x;
  }

  PADDLE_HOSTDEVICE inline explicit float16(uint32_t val) {
    float16 res = fp16_impl::float_to_half_rn(static_cast<float>(val));
    x = res.x;
  }

  PADDLE_HOSTDEVICE inline explicit float16(int64_t val) {
    float16 res = fp16_impl::float_to_half_rn(static_cast<float>(val));
    x = res.x;
  }

  PADDLE_HOSTDEVICE inline explicit float16(uint64_t val) {
    float16 res = fp16_impl::float_to_half_rn(static_cast<float>(val));
    x = res.x;
  }

  PADDLE_HOSTDEVICE inline explicit float16(float val) {
    float16 res = fp16_impl::float_to_half_rn(val);
    x = res.x;
  }
  PADDLE_HOSTDEVICE inline explicit float16(double val) {
    float16 res = fp16_impl::float_to_half_rn(static_cast<float>(val));
    x = res.x;
  }
...
...
@@ -155,7 +191,7 @@ struct PADDLE_ALIGN(2) float16 {
}
#ifdef PADDLE_CUDA_FP16
  PADDLE_HOSTDEVICE inline float16& operator=(const half& rhs) {
#if CUDA_VERSION >= 9000
    x = reinterpret_cast<__half_raw*>(&rhs)->x;
#else
...
...
@@ -172,27 +208,80 @@ struct PADDLE_ALIGN(2) float16 {
}
#endif // USE_EIGEN
#ifdef PADDLE_NEON
  PADDLE_HOSTDEVICE inline float16& operator=(const float16_t* rhs) {
    x = *reinterpret_cast<uint16_t*>(rhs);
    return *this;
  }
#endif
  PADDLE_HOSTDEVICE inline float16& operator=(bool b) {
    x = b ? 0x3c00 : 0;
    return *this;
  }

  PADDLE_HOSTDEVICE inline float16& operator=(int8_t val) {
    float16 res = fp16_impl::float_to_half_rn(static_cast<float>(val));
    x = res.x;
    return *this;
  }

  PADDLE_HOSTDEVICE inline float16& operator=(uint8_t val) {
    float16 res = fp16_impl::float_to_half_rn(static_cast<float>(val));
    x = res.x;
    return *this;
  }

  PADDLE_HOSTDEVICE inline float16& operator=(int16_t val) {
    float16 res = fp16_impl::float_to_half_rn(static_cast<float>(val));
    x = res.x;
    return *this;
  }

  PADDLE_HOSTDEVICE inline float16& operator=(uint16_t val) {
    float16 res = fp16_impl::float_to_half_rn(static_cast<float>(val));
    x = res.x;
    return *this;
  }

  PADDLE_HOSTDEVICE inline float16& operator=(int32_t val) {
    float16 res = fp16_impl::float_to_half_rn(static_cast<float>(val));
    x = res.x;
    return *this;
  }

  PADDLE_HOSTDEVICE inline float16& operator=(uint32_t val) {
    float16 res = fp16_impl::float_to_half_rn(static_cast<float>(val));
    x = res.x;
    return *this;
  }

  PADDLE_HOSTDEVICE inline float16& operator=(int64_t val) {
    float16 res = fp16_impl::float_to_half_rn(static_cast<float>(val));
    x = res.x;
    return *this;
  }

  PADDLE_HOSTDEVICE inline float16& operator=(uint64_t val) {
    float16 res = fp16_impl::float_to_half_rn(static_cast<float>(val));
    x = res.x;
    return *this;
  }

  PADDLE_HOSTDEVICE inline float16& operator=(float val) {
    float16 res = fp16_impl::float_to_half_rn(val);
    x = res.x;
    return *this;
  }

  PADDLE_HOSTDEVICE inline float16& operator=(double val) {
    float16 res = fp16_impl::float_to_half_rn(static_cast<float>(val));
    x = res.x;
    return *this;
  }
#ifdef PADDLE_CUDA_FP16
  PADDLE_HOSTDEVICE inline operator half() const {
#if CUDA_VERSION >= 9000
    __half_raw h;
    h.x = x;
...
...
@@ -206,82 +295,270 @@ struct PADDLE_ALIGN(2) float16 {
#endif // PADDLE_CUDA_FP16
#ifdef USE_EIGEN
  PADDLE_HOSTDEVICE inline operator Eigen::half() const {
    Eigen::half h;
    h.x = x;
    return h;
  }
#endif  // USE_EIGEN
#ifdef PADDLE_NEON
  // check whether it works or not
  PADDLE_HOSTDEVICE inline operator float16_t() const {
    float16 h = *this;
    return *reinterpret_cast<float16_t*>(&h);
  }
#endif
  PADDLE_HOSTDEVICE inline explicit operator bool() const {
    return (x & 0x7fff) != 0;
  }
  PADDLE_HOSTDEVICE inline explicit operator int8_t() const {
    return static_cast<int8_t>(fp16_impl::half_to_float(*this));
  }

  PADDLE_HOSTDEVICE inline explicit operator uint8_t() const {
    return static_cast<uint8_t>(fp16_impl::half_to_float(*this));
  }

  PADDLE_HOSTDEVICE inline explicit operator int16_t() const {
    return static_cast<int16_t>(fp16_impl::half_to_float(*this));
  }

  PADDLE_HOSTDEVICE inline explicit operator uint16_t() const {
    return static_cast<uint16_t>(fp16_impl::half_to_float(*this));
  }

  PADDLE_HOSTDEVICE inline explicit operator int32_t() const {
    return static_cast<int32_t>(fp16_impl::half_to_float(*this));
  }

  PADDLE_HOSTDEVICE inline explicit operator uint32_t() const {
    return static_cast<uint32_t>(fp16_impl::half_to_float(*this));
  }

  PADDLE_HOSTDEVICE inline explicit operator int64_t() const {
    return static_cast<int64_t>(fp16_impl::half_to_float(*this));
  }

  PADDLE_HOSTDEVICE inline explicit operator uint64_t() const {
    return static_cast<uint64_t>(fp16_impl::half_to_float(*this));
  }

  PADDLE_HOSTDEVICE inline explicit operator float() const {
    return fp16_impl::half_to_float(*this);
  }

  PADDLE_HOSTDEVICE inline explicit operator double() const {
    return static_cast<double>(fp16_impl::half_to_float(*this));
  }
};
// arithmetic operators
#if defined(PADDLE_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
__device__ inline float16 operator+(const float16& a, const float16& b) {
  return float16(__hadd(half(a), half(b)));
}

__device__ inline float16 operator-(const float16& a, const float16& b) {
  return float16(__hsub(half(a), half(b)));
}

__device__ inline float16 operator*(const float16& a, const float16& b) {
  return float16(__hmul(half(a), half(b)));
}
__device__ inline float16 operator/(const float16& a, const float16& b) {
  // TODO(kexinzhao): check the cuda version that starts to support __hdiv
  // intrinsic
  float num = __half2float(half(a));
  float denom = __half2float(half(b));
  return float16(num / denom);
}
__device__ inline float16 operator-(const float16& a) {
  return float16(__hneg(half(a)));
}
__device__ inline float16& operator+=(float16& a, const float16& b) {
  a = a + b;
  return a;
}

__device__ inline float16& operator-=(float16& a, const float16& b) {
  a = a - b;
  return a;
}

__device__ inline float16& operator*=(float16& a, const float16& b) {
  a = a * b;
  return a;
}

__device__ inline float16& operator/=(float16& a, const float16& b) {
  a = a / b;
  return a;
}
__device__ inline bool operator==(const float16& a, const float16& b) {
  return __heq(half(a), half(b));
}

__device__ inline bool operator!=(const float16& a, const float16& b) {
  return __hne(half(a), half(b));
}

__device__ inline bool operator<(const float16& a, const float16& b) {
  return __hlt(half(a), half(b));
}

__device__ inline bool operator<=(const float16& a, const float16& b) {
  return __hle(half(a), half(b));
}

__device__ inline bool operator>(const float16& a, const float16& b) {
  return __hgt(half(a), half(b));
}

__device__ inline bool operator>=(const float16& a, const float16& b) {
  return __hge(half(a), half(b));
}
// On ARMv8.2-A CPU
#elif (PADDLE_GNUC_VER >= 71 || PADDLE_CLANG_VER >= 39) && \
defined(PADDLE_NEON_64) && defined(PADDLE_ARM_FP16)
__host__ inline float16 operator+(const float16& a, const float16& b) {
  return float16(vaddh_f16(float16_t(a), float16_t(b)));
}

__host__ inline float16 operator-(const float16& a, const float16& b) {
  return float16(vsubh_f16(float16_t(a), float16_t(b)));
}

__host__ inline float16 operator*(const float16& a, const float16& b) {
  return float16(vmulh_f16(float16_t(a), float16_t(b)));
}

__host__ inline float16 operator/(const float16& a, const float16& b) {
  return float16(vdivh_f16(float16_t(a), float16_t(b)));
}

__host__ inline float16 operator-(const float16& a) {
  return float16(vnegh_f16(float16_t(a)));
}
__host__ inline float16& operator+=(float16& a, const float16& b) {
  a = a + b;
  return a;
}

__host__ inline float16& operator-=(float16& a, const float16& b) {
  a = a - b;
  return a;
}

__host__ inline float16& operator*=(float16& a, const float16& b) {
  a = a * b;
  return a;
}

__host__ inline float16& operator/=(float16& a, const float16& b) {
  a = a / b;
  return a;
}
__host__ inline bool operator==(const float16& a, const float16& b) {
  return static_cast<bool>(vceqh_f16(float16_t(a), float16_t(b)));
}

__host__ inline bool operator!=(const float16& a, const float16& b) {
  return !(a == b);
}
// compare only available in NEON_64
__host__ inline bool operator<(const float16& a, const float16& b) {
  return static_cast<bool>(vclth_f16(float16_t(a), float16_t(b)));
}

__host__ inline bool operator<=(const float16& a, const float16& b) {
  return static_cast<bool>(vcleh_f16(float16_t(a), float16_t(b)));
}

__host__ inline bool operator>(const float16& a, const float16& b) {
  return static_cast<bool>(vcgth_f16(float16_t(a), float16_t(b)));
}

__host__ inline bool operator>=(const float16& a, const float16& b) {
  return static_cast<bool>(vcgeh_f16(float16_t(a), float16_t(b)));
}
#else // software emulation on other cpu
PADDLE_HOSTDEVICE inline float16 operator+(const float16& a, const float16& b) {
  return float16(float(a) + float(b));
}

PADDLE_HOSTDEVICE inline float16 operator-(const float16& a, const float16& b) {
  return float16(float(a) - float(b));
}

PADDLE_HOSTDEVICE inline float16 operator*(const float16& a, const float16& b) {
  return float16(float(a) * float(b));
}

PADDLE_HOSTDEVICE inline float16 operator/(const float16& a, const float16& b) {
  return float16(float(a) / float(b));
}

PADDLE_HOSTDEVICE inline float16 operator-(const float16& a) {
  float16 res;
  res.x = a.x ^ 0x8000;
  return res;
}
PADDLE_HOSTDEVICE inline float16& operator+=(float16& a, const float16& b) {
  a = float16(float(a) + float(b));
  return a;
}

PADDLE_HOSTDEVICE inline float16& operator-=(float16& a, const float16& b) {
  a = float16(float(a) - float(b));
  return a;
}

PADDLE_HOSTDEVICE inline float16& operator*=(float16& a, const float16& b) {
  a = float16(float(a) * float(b));
  return a;
}

PADDLE_HOSTDEVICE inline float16& operator/=(float16& a, const float16& b) {
  a = float16(float(a) / float(b));
  return a;
}
PADDLE_HOSTDEVICE inline bool operator==(const float16& a, const float16& b) {
  return float(a) == float(b);
}

PADDLE_HOSTDEVICE inline bool operator!=(const float16& a, const float16& b) {
  return float(a) != float(b);
}

PADDLE_HOSTDEVICE inline bool operator<(const float16& a, const float16& b) {
  return float(a) < float(b);
}

PADDLE_HOSTDEVICE inline bool operator<=(const float16& a, const float16& b) {
  return float(a) <= float(b);
}

PADDLE_HOSTDEVICE inline bool operator>(const float16& a, const float16& b) {
  return float(a) > float(b);
}

PADDLE_HOSTDEVICE inline bool operator>=(const float16& a, const float16& b) {
  return float(a) >= float(b);
}
#endif
...
...
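One detail of the software fallback above is worth noting: unary minus (res.x = a.x ^ 0x8000) only flips the IEEE sign bit, so it negates every representable value exactly, including infinities. A quick check, as an editorial sketch rather than part of the commit, assuming the header is on the include path:

#include "paddle/math/float16.h"
#include <cassert>

int main() {
  paddle::float16 a(2.5f);
  paddle::float16 b = -a;                     // flips only bit 15
  assert(static_cast<float>(b) == -2.5f);
  assert(static_cast<float>(a + b) == 0.0f);  // exact in half precision
  return 0;
}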
@@ -320,16 +597,11 @@ PADDLE_HOSTDEVICE inline float16 float_to_half_rn(float f) {
  half tmp = __float2half(f);
  return *reinterpret_cast<float16*>(&(tmp));

#elif defined(PADDLE_NEON_64)  // test on RPI
  float16 res;
  asm volatile(
      "ld1 {v0.s}[0], [%[float_ptr]]\n"
      "fcvt h0, s0\n"
      "st1 {v0.h}[0], [%[half_ptr]]\n"
      :  // outputs
      :  // inputs
...
...
@@ -339,6 +611,25 @@ PADDLE_HOSTDEVICE inline float16 float_to_half_rn(float f) {
"memory"
,
"v0"
);
return
res
;
#elif defined(PADDLE_NEON_32) // test on RPI
  float16 res;
  asm volatile(
      "vld1.32 {d0[0]}, [%[float_ptr]]\n"
      "vcvt.f16.f32 d0, q0\n"
      "vst1.16 {d0[0]}, [%[half_ptr]]\n"
      :  // outputs
      :  // inputs
      [float_ptr] "r"(&f),
      [half_ptr] "r"(&(res.x))
      :  // clobbers
      "memory", "d0");
  return res;

#elif defined(__F16C__)
  float16 res;
  res.x = _cvtss_sh(f, 0);
  return res;
#else
// Conversion routine adapted from
// http://stackoverflow.com/questions/1659440/32-bit-to-16-bit-floating-point-conversion
...
...
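The #else branch above falls back to a bit-manipulation routine adapted from the linked Stack Overflow answer, whose body the diff elides. For reference, a self-contained sketch of such a round-to-nearest-even float-to-half conversion (my reconstruction, not the elided code) could look like this:

#include <cstdint>
#include <cstring>

// Scalar float -> IEEE binary16 bits, rounding to nearest even.
inline uint16_t float_to_half_rne_sketch(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  uint32_t sign = (bits >> 16) & 0x8000u;
  bits &= 0x7fffffffu;  // drop the sign bit

  if (bits >= 0x47800000u) {  // too large for half: Inf or NaN
    uint16_t payload = (bits > 0x7f800000u) ? 0x0200u : 0u;  // quiet NaN bit
    return static_cast<uint16_t>(sign | 0x7c00u | payload);
  }
  if (bits < 0x38800000u) {  // maps to a subnormal half (or zero)
    uint32_t shift = 113u - (bits >> 23);
    if (shift >= 24u) return static_cast<uint16_t>(sign);  // underflows to 0
    uint64_t mant = (bits & 0x007fffffu) | 0x00800000u;    // implicit 1
    uint32_t half = static_cast<uint32_t>(mant >> (shift + 13));
    uint64_t rem = mant & ((1ull << (shift + 13)) - 1u);
    uint64_t mid = 1ull << (shift + 12);
    if (rem > mid || (rem == mid && (half & 1u))) half++;   // ties to even
    return static_cast<uint16_t>(sign | half);
  }
  // Normal case: rebias the exponent from 127 to 15, keep 10 mantissa bits.
  uint32_t half = (bits - 0x38000000u) >> 13;
  uint32_t rem = bits & 0x1fffu;  // the 13 bits rounded away
  if (rem > 0x1000u || (rem == 0x1000u && (half & 1u))) half++;
  return static_cast<uint16_t>(sign | half);
}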
@@ -367,10 +658,7 @@ PADDLE_HOSTDEVICE inline float half_to_float(float16 h) {
  half tmp = *reinterpret_cast<half*>(&h);
  return __half2float(tmp);

#elif defined(PADDLE_NEON_64)
  float res;
  asm volatile(
      "ld1 {v0.h}[0], [%[half_ptr]]\n"
...
...
@@ -384,6 +672,23 @@ PADDLE_HOSTDEVICE inline float half_to_float(float16 h) {
"memory"
,
"v0"
);
return
res
;
#elif defined(PADDLE_NEON_32)
  float res;
  asm volatile(
      "vld1.16 {d0[0]}, [%[half_ptr]]\n"
      "vcvt.f32.f16 q0, d0\n"
      "vst1.32 {d0[0]}, [%[float_ptr]]\n"
      :  // outputs
      :  // inputs
      [half_ptr] "r"(&(h.x)),
      [float_ptr] "r"(&res)
      :  // clobbers
      "memory", "v0");
  return res;

#elif defined(__F16C__)
  return _cvtsh_ss(h.x);
#else
// Conversion routine adapted from
// http://stackoverflow.com/questions/1659440/32-bit-to-16-bit-floating-point-conversion
...
...
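Similarly, the elided #else body of half_to_float is a standard bit-level widening; every half value is exactly representable as a float, so no rounding is involved. A hedged reconstruction (again not the elided code):

#include <cstdint>
#include <cstring>

// IEEE binary16 bits -> float; exact, since every half is a float.
inline float half_to_float_sketch(uint16_t h) {
  uint32_t sign = static_cast<uint32_t>(h & 0x8000u) << 16;
  uint32_t em = h & 0x7fffu;  // exponent and mantissa bits
  uint32_t bits;
  if (em >= 0x7c00u) {  // Inf or NaN: widen the payload
    bits = sign | 0x7f800000u | ((em & 0x03ffu) << 13);
  } else if (em >= 0x0400u) {  // normal: rebias exponent 15 -> 127
    bits = sign | ((em + (112u << 10)) << 13);
  } else if (em != 0) {  // subnormal half becomes a normal float
    uint32_t mant = em;
    uint32_t e = 0;
    do {  // renormalize: shift until the implicit bit appears
      mant <<= 1;
      e++;
    } while ((mant & 0x0400u) == 0);
    bits = sign | ((113u - e) << 23) | ((mant & 0x03ffu) << 13);
  } else {
    bits = sign;  // signed zero
  }
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}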
@@ -406,12 +711,6 @@ PADDLE_HOSTDEVICE inline float half_to_float(float16 h) {
#endif
}
PADDLE_HOSTDEVICE inline float16 uint16_to_half(uint16_t x) {
  float16 res;
  res.x = x;
  return res;
}

}  // namespace fp16_impl
}  // namespace paddle
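To close, a small round-trip sketch through the helpers declared in this file (illustrative, not part of the commit; assumes a host build where PADDLE_HOSTDEVICE expands to nothing):

#include "paddle/math/float16.h"
#include <cstdio>

int main() {
  // float -> half (round to nearest even) -> float
  paddle::float16 h = paddle::fp16_impl::float_to_half_rn(3.14159f);
  float f = paddle::fp16_impl::half_to_float(h);
  std::printf("%f\n", f);  // ~3.140625: half keeps about 3 decimal digits
  return 0;
}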