Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
f2bbbb2b
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
f2bbbb2b
编写于
3月 19, 2018
作者:
K
Kexin Zhao
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix arithmetic operator
上级
18d616ed
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
68 addition
and
33 deletion
+68
-33
paddle/fluid/platform/float16.h
paddle/fluid/platform/float16.h
+68
-33
未找到文件。
paddle/fluid/platform/float16.h
浏览文件 @
f2bbbb2b
...
...
@@ -484,72 +484,107 @@ DEVICE inline bool operator>=(const half& a, const half& b) {
#endif // PADDLE_CUDA_FP16
// Arithmetic operators for float16 on GPU
#if defined(PADDLE_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
DEVICE
inline
float16
operator
+
(
const
float16
&
a
,
const
float16
&
b
)
{
#if defined(PADDLE_CUDA_FP16)
HOSTDEVICE
inline
float16
operator
+
(
const
float16
&
a
,
const
float16
&
b
)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
return
float16
(
__hadd
(
half
(
a
),
half
(
b
)));
#else
return
float16
(
float
(
a
)
+
float
(
b
));
}
DEVICE
inline
float16
operator
-
(
const
float16
&
a
,
const
float16
&
b
)
{
HOSTDEVICE
inline
float16
operator
-
(
const
float16
&
a
,
const
float16
&
b
)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
return
float16
(
__hsub
(
half
(
a
),
half
(
b
)));
#else
return
float16
(
float
(
a
)
-
float
(
b
));
}
DEVICE
inline
float16
operator
*
(
const
float16
&
a
,
const
float16
&
b
)
{
HOSTDEVICE
inline
float16
operator
*
(
const
float16
&
a
,
const
float16
&
b
)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
return
float16
(
__hmul
(
half
(
a
),
half
(
b
)));
#else
return
float16
(
float
(
a
)
*
float
(
b
));
}
DEVICE
inline
float16
operator
/
(
const
float16
&
a
,
const
float16
&
b
)
{
// TODO(kexinzhao): check the cuda version that starts to support __hdiv
HOSTDEVICE
inline
float16
operator
/
(
const
float16
&
a
,
const
float16
&
b
)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300
// TODO(kexinzhao): check which cuda version starts to support __hdiv
float
num
=
__half2float
(
half
(
a
));
float
denom
=
__half2float
(
half
(
b
));
return
float16
(
num
/
denom
);
#else
return
float16
(
float
(
a
)
/
float
(
b
));
}
DEVICE
inline
float16
operator
-
(
const
float16
&
a
)
{
HOSTDEVICE
inline
float16
operator
-
(
const
float16
&
a
)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
return
float16
(
__hneg
(
half
(
a
)));
#else
float16
res
;
res
.
x
=
a
.
x
^
0x8000
;
return
res
;
}
DEVICE
inline
float16
&
operator
+=
(
float16
&
a
,
const
float16
&
b
)
{
HOST
DEVICE
inline
float16
&
operator
+=
(
float16
&
a
,
const
float16
&
b
)
{
a
=
a
+
b
;
return
a
;
}
DEVICE
inline
float16
&
operator
-=
(
float16
&
a
,
const
float16
&
b
)
{
HOST
DEVICE
inline
float16
&
operator
-=
(
float16
&
a
,
const
float16
&
b
)
{
a
=
a
-
b
;
return
a
;
}
DEVICE
inline
float16
&
operator
*=
(
float16
&
a
,
const
float16
&
b
)
{
HOST
DEVICE
inline
float16
&
operator
*=
(
float16
&
a
,
const
float16
&
b
)
{
a
=
a
*
b
;
return
a
;
}
DEVICE
inline
float16
&
operator
/=
(
float16
&
a
,
const
float16
&
b
)
{
HOST
DEVICE
inline
float16
&
operator
/=
(
float16
&
a
,
const
float16
&
b
)
{
a
=
a
/
b
;
return
a
;
}
DEVICE
inline
bool
operator
==
(
const
float16
&
a
,
const
float16
&
b
)
{
HOSTDEVICE
inline
bool
operator
==
(
const
float16
&
a
,
const
float16
&
b
)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
return
__heq
(
half
(
a
),
half
(
b
));
#else
return
float
(
a
)
==
float
(
b
);
}
DEVICE
inline
bool
operator
!=
(
const
float16
&
a
,
const
float16
&
b
)
{
HOSTDEVICE
inline
bool
operator
!=
(
const
float16
&
a
,
const
float16
&
b
)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
return
__hne
(
half
(
a
),
half
(
b
));
#else
return
float
(
a
)
!=
float
(
b
);
}
DEVICE
inline
bool
operator
<
(
const
float16
&
a
,
const
float16
&
b
)
{
HOSTDEVICE
inline
bool
operator
<
(
const
float16
&
a
,
const
float16
&
b
)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
return
__hlt
(
half
(
a
),
half
(
b
));
#else
return
float
(
a
)
<
float
(
b
);
}
DEVICE
inline
bool
operator
<=
(
const
float16
&
a
,
const
float16
&
b
)
{
HOSTDEVICE
inline
bool
operator
<=
(
const
float16
&
a
,
const
float16
&
b
)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
return
__hle
(
half
(
a
),
half
(
b
));
#else
return
float
(
a
)
<=
float
(
b
);
}
DEVICE
inline
bool
operator
>
(
const
float16
&
a
,
const
float16
&
b
)
{
HOSTDEVICE
inline
bool
operator
>
(
const
float16
&
a
,
const
float16
&
b
)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
return
__hgt
(
half
(
a
),
half
(
b
));
#else
return
float
(
a
)
>
float
(
b
);
}
DEVICE
inline
bool
operator
>=
(
const
float16
&
a
,
const
float16
&
b
)
{
HOSTDEVICE
inline
bool
operator
>=
(
const
float16
&
a
,
const
float16
&
b
)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
return
__hge
(
half
(
a
),
half
(
b
));
#else
return
float
(
a
)
>=
float
(
b
);
}
// Arithmetic operators for float16 on ARMv8.2-A CPU
...
...
@@ -737,71 +772,71 @@ HOST inline bool operator>=(const float16& a, const float16& b) {
return
(
res
&
0xffff
)
!=
0
;
}
// Arithmetic operators for float16, software emulated on other CPU
/GPU
// Arithmetic operators for float16, software emulated on other CPU
#else
HOST
DEVICE
inline
float16
operator
+
(
const
float16
&
a
,
const
float16
&
b
)
{
HOST
inline
float16
operator
+
(
const
float16
&
a
,
const
float16
&
b
)
{
return
float16
(
float
(
a
)
+
float
(
b
));
}
HOST
DEVICE
inline
float16
operator
-
(
const
float16
&
a
,
const
float16
&
b
)
{
HOST
inline
float16
operator
-
(
const
float16
&
a
,
const
float16
&
b
)
{
return
float16
(
float
(
a
)
-
float
(
b
));
}
HOST
DEVICE
inline
float16
operator
*
(
const
float16
&
a
,
const
float16
&
b
)
{
HOST
inline
float16
operator
*
(
const
float16
&
a
,
const
float16
&
b
)
{
return
float16
(
float
(
a
)
*
float
(
b
));
}
HOST
DEVICE
inline
float16
operator
/
(
const
float16
&
a
,
const
float16
&
b
)
{
HOST
inline
float16
operator
/
(
const
float16
&
a
,
const
float16
&
b
)
{
return
float16
(
float
(
a
)
/
float
(
b
));
}
HOST
DEVICE
inline
float16
operator
-
(
const
float16
&
a
)
{
HOST
inline
float16
operator
-
(
const
float16
&
a
)
{
float16
res
;
res
.
x
=
a
.
x
^
0x8000
;
return
res
;
}
HOST
DEVICE
inline
float16
&
operator
+=
(
float16
&
a
,
const
float16
&
b
)
{
HOST
inline
float16
&
operator
+=
(
float16
&
a
,
const
float16
&
b
)
{
a
=
float16
(
float
(
a
)
+
float
(
b
));
return
a
;
}
HOST
DEVICE
inline
float16
&
operator
-=
(
float16
&
a
,
const
float16
&
b
)
{
HOST
inline
float16
&
operator
-=
(
float16
&
a
,
const
float16
&
b
)
{
a
=
float16
(
float
(
a
)
-
float
(
b
));
return
a
;
}
HOST
DEVICE
inline
float16
&
operator
*=
(
float16
&
a
,
const
float16
&
b
)
{
HOST
inline
float16
&
operator
*=
(
float16
&
a
,
const
float16
&
b
)
{
a
=
float16
(
float
(
a
)
*
float
(
b
));
return
a
;
}
HOST
DEVICE
inline
float16
&
operator
/=
(
float16
&
a
,
const
float16
&
b
)
{
HOST
inline
float16
&
operator
/=
(
float16
&
a
,
const
float16
&
b
)
{
a
=
float16
(
float
(
a
)
/
float
(
b
));
return
a
;
}
HOST
DEVICE
inline
bool
operator
==
(
const
float16
&
a
,
const
float16
&
b
)
{
HOST
inline
bool
operator
==
(
const
float16
&
a
,
const
float16
&
b
)
{
return
float
(
a
)
==
float
(
b
);
}
HOST
DEVICE
inline
bool
operator
!=
(
const
float16
&
a
,
const
float16
&
b
)
{
HOST
inline
bool
operator
!=
(
const
float16
&
a
,
const
float16
&
b
)
{
return
float
(
a
)
!=
float
(
b
);
}
HOST
DEVICE
inline
bool
operator
<
(
const
float16
&
a
,
const
float16
&
b
)
{
HOST
inline
bool
operator
<
(
const
float16
&
a
,
const
float16
&
b
)
{
return
float
(
a
)
<
float
(
b
);
}
HOST
DEVICE
inline
bool
operator
<=
(
const
float16
&
a
,
const
float16
&
b
)
{
HOST
inline
bool
operator
<=
(
const
float16
&
a
,
const
float16
&
b
)
{
return
float
(
a
)
<=
float
(
b
);
}
HOST
DEVICE
inline
bool
operator
>
(
const
float16
&
a
,
const
float16
&
b
)
{
HOST
inline
bool
operator
>
(
const
float16
&
a
,
const
float16
&
b
)
{
return
float
(
a
)
>
float
(
b
);
}
HOST
DEVICE
inline
bool
operator
>=
(
const
float16
&
a
,
const
float16
&
b
)
{
HOST
inline
bool
operator
>=
(
const
float16
&
a
,
const
float16
&
b
)
{
return
float
(
a
)
>=
float
(
b
);
}
#endif
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录