Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
e61d7245
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
e61d7245
编写于
3月 19, 2019
作者:
Y
Yibing Liu
提交者:
GitHub
3月 19, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix the bug in fp16 backward kernel (#16266)
test=release/1.3
上级
c56d9026
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
33 addition
and
7 deletion
+33
-7
paddle/fluid/operators/slice_op.cu
paddle/fluid/operators/slice_op.cu
+7
-7
python/paddle/fluid/tests/unittests/test_slice_op.py
python/paddle/fluid/tests/unittests/test_slice_op.py
+26
-0
未找到文件。
paddle/fluid/operators/slice_op.cu
浏览文件 @
e61d7245
...
...
@@ -31,18 +31,18 @@ __global__ void Padding(const paddle::platform::float16* d_out,
paddle
::
platform
::
float16
*
d_in
)
{
int64_t
out_idx
=
threadIdx
.
x
+
blockDim
.
x
*
blockIdx
.
x
;
if
(
out_idx
<
n
)
{
int64_t
out_idx_tmp
=
out_idx
;
int
coords
[
D
]
=
{
0
};
for
(
int
i
=
D
-
1
;
i
>=
0
;
--
i
)
{
coords
[
i
]
=
out_idx
%
out_dims
[
i
];
out_idx
/=
out_dims
[
i
];
coords
[
i
]
=
out_idx
_tmp
%
out_dims
[
i
];
out_idx
_tmp
/=
out_dims
[
i
];
coords
[
i
]
+=
offsets
[
i
];
}
int64_t
in_idx
=
0
;
for
(
int
i
=
0
;
i
<
D
-
1
;
++
i
)
{
in_idx
+=
coords
[
i
]
*
in_dims
[
i
+
1
];
for
(
int
i
=
0
;
i
<
D
;
++
i
)
{
in_idx
=
in_idx
*
in_dims
[
i
]
+
coords
[
i
];
}
in_idx
+=
coords
[
D
-
1
];
d_in
[
in_idx
]
=
d_out
[
out_idx
];
}
...
...
@@ -80,8 +80,8 @@ class SliceGradKernel<paddle::platform::CUDADeviceContext,
set_zero
(
dev_ctx
,
d_in
,
static_cast
<
paddle
::
platform
::
float16
>
(
0
));
int64_t
numel
=
d_out
->
numel
();
dim3
blocks
((
numel
-
1
)
/
PADDLE_CUDA_NUM_THREADS
+
1
,
1
,
1
);
dim3
threads
(
PADDLE_CUDA_NUM_THREADS
,
1
,
1
);
dim3
blocks
((
numel
-
1
)
/
PADDLE_CUDA_NUM_THREADS
+
1
);
dim3
threads
(
PADDLE_CUDA_NUM_THREADS
);
auto
stream
=
ctx
.
cuda_device_context
().
stream
();
auto
out_shape
=
framework
::
vectorize2int
(
out_dims
);
...
...
python/paddle/fluid/tests/unittests/test_slice_op.py
浏览文件 @
e61d7245
...
...
@@ -87,5 +87,31 @@ class TestFP16(TestSliceOp):
place
,
[
'Input'
],
'Out'
,
max_relative_error
=
0.006
)
@
unittest
.
skipIf
(
not
core
.
is_compiled_with_cuda
(),
"core is not compiled with CUDA"
)
class
TestFP16_2
(
TestSliceOp
):
def
config
(
self
):
self
.
dtype
=
"float16"
self
.
input
=
np
.
random
.
random
([
3
,
4
,
5
]).
astype
(
self
.
dtype
)
self
.
starts
=
[
0
]
self
.
ends
=
[
1
]
self
.
axes
=
[
1
]
self
.
out
=
self
.
input
[:,
0
:
1
,
:]
def
test_check_output
(
self
):
place
=
core
.
CUDAPlace
(
0
)
if
core
.
is_float16_supported
(
place
):
self
.
check_output_with_place
(
place
,
atol
=
1e-5
)
def
test_check_grad_normal
(
self
):
place
=
core
.
CUDAPlace
(
0
)
if
core
.
is_float16_supported
(
place
):
self
.
check_grad_with_place
(
place
,
[
'Input'
],
'Out'
,
max_relative_error
=
0.006
,
numeric_grad_delta
=
0.5
)
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录