Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
18d616ed
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
694
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
18d616ed
编写于
3月 19, 2018
作者:
K
Kexin Zhao
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add float16 arithmetic operators on new GPU
上级
d03dbb97
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
82 addition
and
7 deletion
+82
-7
paddle/fluid/platform/float16.h
paddle/fluid/platform/float16.h
+72
-3
python/paddle/fluid/tests/unittests/test_dropout_op.py
python/paddle/fluid/tests/unittests/test_dropout_op.py
+10
-4
未找到文件。
paddle/fluid/platform/float16.h
浏览文件 @
18d616ed
...
...
@@ -483,8 +483,77 @@ DEVICE inline bool operator>=(const half& a, const half& b) {
#endif // PADDLE_CUDA_FP16
// Arithmetic operators on ARMv8.2-A CPU
#if defined(PADDLE_WITH_NATIVE_FP16)
// Arithmetic operators for float16 on GPU
#if defined(PADDLE_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
DEVICE
inline
float16
operator
+
(
const
float16
&
a
,
const
float16
&
b
)
{
return
float16
(
__hadd
(
half
(
a
),
half
(
b
)));
}
DEVICE
inline
float16
operator
-
(
const
float16
&
a
,
const
float16
&
b
)
{
return
float16
(
__hsub
(
half
(
a
),
half
(
b
)));
}
DEVICE
inline
float16
operator
*
(
const
float16
&
a
,
const
float16
&
b
)
{
return
float16
(
__hmul
(
half
(
a
),
half
(
b
)));
}
DEVICE
inline
float16
operator
/
(
const
float16
&
a
,
const
float16
&
b
)
{
// TODO(kexinzhao): check the cuda version that starts to support __hdiv
float
num
=
__half2float
(
half
(
a
));
float
denom
=
__half2float
(
half
(
b
));
return
float16
(
num
/
denom
);
}
DEVICE
inline
float16
operator
-
(
const
float16
&
a
)
{
return
float16
(
__hneg
(
half
(
a
)));
}
DEVICE
inline
float16
&
operator
+=
(
float16
&
a
,
const
float16
&
b
)
{
a
=
a
+
b
;
return
a
;
}
DEVICE
inline
float16
&
operator
-=
(
float16
&
a
,
const
float16
&
b
)
{
a
=
a
-
b
;
return
a
;
}
DEVICE
inline
float16
&
operator
*=
(
float16
&
a
,
const
float16
&
b
)
{
a
=
a
*
b
;
return
a
;
}
DEVICE
inline
float16
&
operator
/=
(
float16
&
a
,
const
float16
&
b
)
{
a
=
a
/
b
;
return
a
;
}
DEVICE
inline
bool
operator
==
(
const
float16
&
a
,
const
float16
&
b
)
{
return
__heq
(
half
(
a
),
half
(
b
));
}
DEVICE
inline
bool
operator
!=
(
const
float16
&
a
,
const
float16
&
b
)
{
return
__hne
(
half
(
a
),
half
(
b
));
}
DEVICE
inline
bool
operator
<
(
const
float16
&
a
,
const
float16
&
b
)
{
return
__hlt
(
half
(
a
),
half
(
b
));
}
DEVICE
inline
bool
operator
<=
(
const
float16
&
a
,
const
float16
&
b
)
{
return
__hle
(
half
(
a
),
half
(
b
));
}
DEVICE
inline
bool
operator
>
(
const
float16
&
a
,
const
float16
&
b
)
{
return
__hgt
(
half
(
a
),
half
(
b
));
}
DEVICE
inline
bool
operator
>=
(
const
float16
&
a
,
const
float16
&
b
)
{
return
__hge
(
half
(
a
),
half
(
b
));
}
// Arithmetic operators for float16 on ARMv8.2-A CPU
#elif defined(PADDLE_WITH_NATIVE_FP16)
HOST
inline
float16
operator
+
(
const
float16
&
a
,
const
float16
&
b
)
{
float16
res
;
asm
volatile
(
...
...
@@ -668,7 +737,7 @@ HOST inline bool operator>=(const float16& a, const float16& b) {
return
(
res
&
0xffff
)
!=
0
;
}
// Arithmetic operators
, software emulated on other C
PU
// Arithmetic operators
for float16, software emulated on other CPU/G
PU
#else
HOSTDEVICE
inline
float16
operator
+
(
const
float16
&
a
,
const
float16
&
b
)
{
return
float16
(
float
(
a
)
+
float
(
b
));
...
...
python/paddle/fluid/tests/unittests/test_dropout_op.py
浏览文件 @
18d616ed
...
...
@@ -86,10 +86,13 @@ class TestDropoutOp5(OpTest):
class
TestFP16DropoutOp1
(
OpTest
):
def
setUp
(
self
):
x
=
np
.
random
.
random
((
32
,
64
)).
astype
(
"float16"
)
prob
=
0.35
out
=
x
*
(
1.0
-
prob
)
self
.
op_type
=
"dropout"
self
.
inputs
=
{
'X'
:
OpTest
.
np_dtype_to_fluid_dtype
(
x
)}
self
.
attrs
=
{
'dropout_prob'
:
0.35
,
'fix_seed'
:
True
,
'is_test'
:
True
}
self
.
outputs
=
{
'Out'
:
x
*
(
1.0
-
self
.
attrs
[
'dropout_prob'
])
}
self
.
attrs
=
{
'dropout_prob'
:
prob
,
'fix_seed'
:
True
,
'is_test'
:
True
}
self
.
outputs
=
{
'Out'
:
out
}
def
test_check_output
(
self
):
if
core
.
is_compiled_with_cuda
()
and
core
.
op_support_gpu
(
"dropout"
):
...
...
@@ -99,10 +102,13 @@ class TestFP16DropoutOp1(OpTest):
class
TestFP16DropoutOp2
(
OpTest
):
def
setUp
(
self
):
x
=
np
.
random
.
random
((
32
,
64
,
3
)).
astype
(
"float16"
)
prob
=
0.75
out
=
x
*
(
1.0
-
prob
)
self
.
op_type
=
"dropout"
self
.
inputs
=
{
'X'
:
OpTest
.
np_dtype_to_fluid_dtype
(
x
)}
self
.
attrs
=
{
'dropout_prob'
:
0.75
,
'is_test'
:
True
}
self
.
outputs
=
{
'Out'
:
x
*
(
1.0
-
self
.
attrs
[
'dropout_prob'
])
}
self
.
attrs
=
{
'dropout_prob'
:
prob
,
'is_test'
:
True
}
self
.
outputs
=
{
'Out'
:
out
}
def
test_check_output
(
self
):
if
core
.
is_compiled_with_cuda
()
and
core
.
op_support_gpu
(
"dropout"
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录