Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
e0b5691e
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
e0b5691e
编写于
3月 27, 2018
作者:
G
gongweibao
提交者:
GitHub
3月 27, 2018
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add drop_out_op unit test (#9364)
上级
fe142d92
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
99 addition
and
3 deletion
+99
-3
paddle/fluid/operators/CMakeLists.txt
paddle/fluid/operators/CMakeLists.txt
+1
-0
paddle/fluid/operators/dropout_op.cu
paddle/fluid/operators/dropout_op.cu
+2
-3
paddle/fluid/operators/dropout_op_test.cc
paddle/fluid/operators/dropout_op_test.cc
+96
-0
未找到文件。
paddle/fluid/operators/CMakeLists.txt
浏览文件 @
e0b5691e
...
@@ -264,3 +264,4 @@ cc_test(strided_memcpy_test SRCS strided_memcpy_test.cc DEPS tensor paddle_memor
...
@@ -264,3 +264,4 @@ cc_test(strided_memcpy_test SRCS strided_memcpy_test.cc DEPS tensor paddle_memor
cc_test
(
save_load_op_test SRCS save_load_op_test.cc DEPS save_op load_op
)
cc_test
(
save_load_op_test SRCS save_load_op_test.cc DEPS save_op load_op
)
cc_test
(
save_load_combine_op_test SRCS save_load_combine_op_test.cc DEPS save_combine_op load_combine_op
)
cc_test
(
save_load_combine_op_test SRCS save_load_combine_op_test.cc DEPS save_combine_op load_combine_op
)
nv_test
(
nccl_op_test SRCS nccl_op_test.cu.cc DEPS nccl_op gpu_info device_context
)
nv_test
(
nccl_op_test SRCS nccl_op_test.cu.cc DEPS nccl_op gpu_info device_context
)
nv_test
(
dropout_op_test SRCS dropout_op_test.cc DEPS dropout_op tensor
)
paddle/fluid/operators/dropout_op.cu
浏览文件 @
e0b5691e
...
@@ -55,9 +55,6 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
...
@@ -55,9 +55,6 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
y
->
mutable_data
<
T
>
(
context
.
GetPlace
());
y
->
mutable_data
<
T
>
(
context
.
GetPlace
());
float
dropout_prob
=
context
.
Attr
<
float
>
(
"dropout_prob"
);
float
dropout_prob
=
context
.
Attr
<
float
>
(
"dropout_prob"
);
auto
X
=
EigenMatrix
<
T
>::
Reshape
(
*
x
,
1
);
auto
Y
=
EigenMatrix
<
T
>::
Reshape
(
*
y
,
1
);
auto
&
place
=
*
context
.
template
device_context
<
Place
>().
eigen_device
();
auto
&
place
=
*
context
.
template
device_context
<
Place
>().
eigen_device
();
if
(
!
context
.
Attr
<
bool
>
(
"is_test"
))
{
if
(
!
context
.
Attr
<
bool
>
(
"is_test"
))
{
auto
*
mask
=
context
.
Output
<
Tensor
>
(
"Mask"
);
auto
*
mask
=
context
.
Output
<
Tensor
>
(
"Mask"
);
...
@@ -76,6 +73,8 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
...
@@ -76,6 +73,8 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
T
><<<
grid
,
threads
,
0
,
context
.
cuda_device_context
().
stream
()
>>>
(
T
><<<
grid
,
threads
,
0
,
context
.
cuda_device_context
().
stream
()
>>>
(
size
,
seed
,
dropout_prob
,
x_data
,
mask_data
,
y_data
);
size
,
seed
,
dropout_prob
,
x_data
,
mask_data
,
y_data
);
}
else
{
}
else
{
auto
X
=
EigenMatrix
<
T
>::
Reshape
(
*
x
,
1
);
auto
Y
=
EigenMatrix
<
T
>::
Reshape
(
*
y
,
1
);
Y
.
device
(
place
)
=
X
*
static_cast
<
T
>
(
1.0
f
-
dropout_prob
);
Y
.
device
(
place
)
=
X
*
static_cast
<
T
>
(
1.0
f
-
dropout_prob
);
}
}
}
}
...
...
paddle/fluid/operators/dropout_op_test.cc
0 → 100644
浏览文件 @
e0b5691e
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <unistd.h>
#include <string>
#include <thread>
#include "gtest/gtest.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/operators/dropout_op.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/string/printf.h"
namespace
f
=
paddle
::
framework
;
namespace
p
=
paddle
::
platform
;
namespace
m
=
paddle
::
operators
::
math
;
USE_OP
(
dropout
);
void
Compare
(
f
::
Scope
&
scope
,
p
::
DeviceContext
&
ctx
)
{
// init
auto
var
=
scope
.
Var
(
"X"
);
auto
tensor
=
var
->
GetMutable
<
f
::
LoDTensor
>
();
tensor
->
Resize
({
10
,
10
});
std
::
vector
<
float
>
init
;
for
(
int64_t
i
=
0
;
i
<
10
*
10
;
++
i
)
{
init
.
push_back
(
1.0
);
}
TensorFromVector
(
init
,
ctx
,
tensor
);
auto
place
=
ctx
.
GetPlace
();
auto
out_var
=
scope
.
Var
(
"Out"
);
auto
out_tensor
=
out_var
->
GetMutable
<
f
::
LoDTensor
>
();
out_tensor
->
Resize
({
10
,
10
});
out_tensor
->
mutable_data
<
float
>
(
place
);
// allocate
auto
mask_var
=
scope
.
Var
(
"Mask"
);
auto
mask_tensor
=
mask_var
->
GetMutable
<
f
::
LoDTensor
>
();
mask_tensor
->
Resize
({
10
,
10
});
mask_tensor
->
mutable_data
<
float
>
(
place
);
// allocate
// run
f
::
AttributeMap
attrs
;
float
dropout_prob
=
0.5
;
attrs
.
insert
({
"fix_seed"
,
1
});
attrs
.
insert
({
"seed"
,
3
});
attrs
.
insert
({
"dropout_prob"
,
dropout_prob
});
auto
dropout_op
=
f
::
OpRegistry
::
CreateOp
(
"dropout"
,
{{
"X"
,
{
"X"
}}},
{{
"Out"
,
{
"Out"
}},
{
"Mask"
,
{
"Mask"
}}},
attrs
);
dropout_op
->
Run
(
scope
,
place
);
std
::
vector
<
float
>
out_vec
;
TensorToVector
(
*
out_tensor
,
ctx
,
&
out_vec
);
std
::
vector
<
float
>
std_out
=
{
0
,
0
,
1
,
1
,
1
,
1
,
1
,
0
,
1
,
0
,
0
,
1
,
1
,
0
,
1
,
1
,
1
,
1
,
0
,
1
,
1
,
0
,
1
,
1
,
1
,
1
,
0
,
1
,
1
,
1
,
1
,
0
,
1
,
1
,
0
,
0
,
0
,
1
,
1
,
0
,
1
,
0
,
1
,
1
,
0
,
0
,
0
,
1
,
1
,
0
,
0
,
1
,
1
,
1
,
0
,
1
,
0
,
0
,
1
,
1
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
1
,
0
,
1
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
1
,
1
,
0
,
1
,
1
,
0
,
1
,
1
,
0
,
1
,
0
,
1
,
1
,
1
,
1
,
1
,
0
,
0
,
1
,
1
};
EXPECT_EQ
(
out_vec
.
size
(),
std_out
.
size
());
for
(
uint32_t
i
=
0
;
i
<
out_vec
.
size
();
i
++
)
{
EXPECT_EQ
(
out_vec
[
i
],
std_out
[
i
]);
}
}
TEST
(
Dropout
,
CPUDense
)
{
f
::
Scope
scope
;
p
::
CPUPlace
place
;
p
::
CPUDeviceContext
ctx
(
place
);
Compare
(
scope
,
ctx
);
}
TEST
(
Dropout
,
GPUDense
)
{
f
::
Scope
scope
;
p
::
CUDAPlace
place
;
p
::
CUDADeviceContext
ctx
(
place
);
Compare
(
scope
,
ctx
);
}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录