Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
b1a18552
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
b1a18552
编写于
9月 03, 2017
作者:
X
Xinghai Sun
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fixed SEGFAULT of dropout operator in GPU.
上级
9a44f3d6
变更
6
显示空白变更内容
内联
并排
Showing
6 changed file
with
134 addition
and
21 deletion
+134
-21
paddle/operators/dropout_op.cc
paddle/operators/dropout_op.cc
+4
-2
paddle/operators/dropout_op.cu
paddle/operators/dropout_op.cu
+2
-2
paddle/operators/dropout_op.h
paddle/operators/dropout_op.h
+81
-13
python/paddle/v2/framework/tests/CMakeLists.txt
python/paddle/v2/framework/tests/CMakeLists.txt
+1
-0
python/paddle/v2/framework/tests/op_test_util.py
python/paddle/v2/framework/tests/op_test_util.py
+4
-4
python/paddle/v2/framework/tests/test_dropout_op.py
python/paddle/v2/framework/tests/test_dropout_op.py
+42
-0
未找到文件。
paddle/operators/dropout_op.cc
浏览文件 @
b1a18552
...
...
@@ -37,6 +37,8 @@ class DropoutOpMaker : public framework::OpProtoAndCheckerMaker {
DropoutOpMaker
(
framework
::
OpProto
*
proto
,
framework
::
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddAttr
<
float
>
(
"dropout_prob"
,
"Dropout probability."
).
SetDefault
(
.5
f
);
AddAttr
<
int
>
(
"seed"
,
"Dropout random seed."
).
SetDefault
(
0
);
AddInput
(
"X"
,
"The input of dropout op."
);
AddOutput
(
"Out"
,
"The output of dropout op."
);
AddOutput
(
"Mask"
,
"The dropout mask."
).
AsIntermediate
();
...
...
@@ -75,7 +77,7 @@ class DropoutOpGrad : public framework::OperatorWithKernel {
namespace
ops
=
paddle
::
operators
;
REGISTER_OP
(
dropout
,
ops
::
DropoutOp
,
ops
::
DropoutOpMaker
,
dropout_grad
,
ops
::
DropoutOpGrad
);
REGISTER_OP_CPU_KERNEL
(
dropout
,
ops
::
DropoutKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
);
REGISTER_OP_CPU_KERNEL
(
dropout
,
ops
::
CPU
DropoutKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
);
REGISTER_OP_CPU_KERNEL
(
dropout_grad
,
ops
::
DropoutGradKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
);
paddle/operators/dropout_op.cu
浏览文件 @
b1a18552
...
...
@@ -16,7 +16,7 @@
#include "paddle/operators/dropout_op.h"
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_GPU_KERNEL
(
dropout
,
ops
::
DropoutKernel
<
paddle
::
platform
::
GPUPlace
,
float
>
);
REGISTER_OP_GPU_KERNEL
(
dropout
,
ops
::
GPU
DropoutKernel
<
paddle
::
platform
::
GPUPlace
,
float
>
);
REGISTER_OP_GPU_KERNEL
(
dropout_grad
,
ops
::
DropoutGradKernel
<
paddle
::
platform
::
GPUPlace
,
float
>
);
paddle/operators/dropout_op.h
浏览文件 @
b1a18552
...
...
@@ -13,6 +13,11 @@
limitations under the License. */
#pragma once
#include <thrust/device_ptr.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/random.h>
#include <thrust/transform.h>
#include <random>
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
...
...
@@ -25,25 +30,85 @@ template <typename T, int MajorType = Eigen::RowMajor,
using
EigenMatrix
=
framework
::
EigenMatrix
<
T
,
MajorType
,
IndexType
>
;
template
<
typename
Place
,
typename
T
>
class
DropoutKernel
:
public
framework
::
OpKernel
{
class
CPUDropoutKernel
:
public
framework
::
OpKernel
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
x
=
context
.
Input
<
Tensor
>
(
"X"
);
auto
*
y
=
context
.
Output
<
Tensor
>
(
"Out"
);
auto
*
mask
=
context
.
Output
<
Tensor
>
(
"Mask"
);
T
*
mask_data
=
mask
->
mutable_data
<
T
>
(
context
.
GetPlace
());
T
*
y_data
=
y
->
mutable_data
<
T
>
(
context
.
GetPlace
());
const
T
*
x_data
=
x
->
data
<
T
>
();
float
dropout_prob
=
context
.
op_
.
GetAttr
<
float
>
(
"dropout_prob"
);
int
seed
=
context
.
op_
.
GetAttr
<
int
>
(
"seed"
);
std
::
minstd_rand
engine
;
engine
.
seed
(
seed
);
std
::
uniform_real_distribution
<
T
>
dist
(
0
,
1
);
size_t
size
=
framework
::
product
(
mask
->
dims
());
for
(
size_t
i
=
0
;
i
<
size
;
++
i
)
{
if
(
dist
(
engine
)
<
dropout_prob
)
{
mask_data
[
i
]
=
0
;
y_data
[
i
]
=
0
;
}
else
{
mask_data
[
i
]
=
1
;
y_data
[
i
]
=
(
1
-
dropout_prob
)
*
x_data
[
i
];
}
}
}
};
template
<
typename
T
>
struct
MaskGenerator
{
float
dropout_prob_
;
int
seed_
;
__host__
__device__
MaskGenerator
(
float
dropout_prob
,
int
seed
)
:
dropout_prob_
(
dropout_prob
),
seed_
(
seed
)
{}
__host__
__device__
T
operator
()(
const
unsigned
int
n
)
const
{
thrust
::
minstd_rand
rng
;
rng
.
seed
(
seed_
);
thrust
::
uniform_real_distribution
<
T
>
dist
(
0
,
1
);
rng
.
discard
(
n
);
if
(
dist
(
rng
)
<
dropout_prob_
)
{
return
static_cast
<
T
>
(
0
);
}
else
{
return
static_cast
<
T
>
(
1
);
}
}
};
// It seems that Eigen::Tensor::setRandom in GPU will SEGFAULT.
// Use std::random and thrust::random(thrust is a std library in CUDA) to
// implement uniform random.
template
<
typename
Place
,
typename
T
>
class
GPUDropoutKernel
:
public
framework
::
OpKernel
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
x
=
context
.
Input
<
Tensor
>
(
"X"
);
auto
*
y
=
context
.
Output
<
Tensor
>
(
"Out"
);
auto
*
mask
=
context
.
Output
<
Tensor
>
(
"Mask"
);
mask
->
mutable_data
<
T
>
(
context
.
GetPlace
());
y
->
mutable_data
<
T
>
(
context
.
GetPlace
());
float
dropout_prob
=
context
.
op_
.
GetAttr
<
float
>
(
"dropout_prob"
);
int
seed
=
context
.
op_
.
GetAttr
<
int
>
(
"seed"
);
thrust
::
counting_iterator
<
unsigned
int
>
index_sequence_begin
(
0
);
int
size
=
framework
::
product
(
mask
->
dims
());
T
*
mask_data
=
mask
->
mutable_data
<
T
>
(
context
.
GetPlace
());
thrust
::
transform
(
index_sequence_begin
,
index_sequence_begin
+
size
,
thrust
::
device_ptr
<
T
>
(
mask_data
),
MaskGenerator
<
T
>
(
dropout_prob
,
seed
));
auto
dims
=
x
->
dims
();
auto
X
=
EigenMatrix
<
T
>::
From
(
*
x
);
auto
Y
=
EigenMatrix
<
T
>::
From
(
*
y
);
auto
M
=
EigenMatrix
<
T
>::
From
(
*
mask
);
auto
new_dims
=
framework
::
make_ddim
({
dims
[
0
],
size
/
dims
[
0
]});
auto
X
=
EigenMatrix
<
T
>::
From
(
*
x
,
new_dims
);
auto
Y
=
EigenMatrix
<
T
>::
From
(
*
y
,
new_dims
);
auto
M
=
EigenMatrix
<
T
>::
From
(
*
mask
,
new_dims
);
auto
place
=
context
.
GetEigenDevice
<
Place
>
();
M
.
device
(
place
).
setRandom
<
UniformRandomGenerator
>
();
float
dropout_prob
=
context
.
op_
.
GetAttr
<
float
>
(
"dropout_prob"
);
M
.
device
(
place
)
=
(
M
>
dropout_prob
).
cast
<
float
>
();
Y
.
device
(
place
)
=
X
*
Y
;
Y
.
device
(
place
)
=
X
*
M
*
(
1
-
dropout_prob
);
}
};
...
...
@@ -57,12 +122,15 @@ class DropoutGradKernel : public framework::OpKernel {
grad_x
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
dims
=
grad_x
->
dims
();
auto
M
=
EigenMatrix
<
T
>::
From
(
*
mask
);
auto
dX
=
EigenMatrix
<
T
>::
From
(
*
grad_x
);
auto
dY
=
EigenMatrix
<
T
>::
From
(
*
grad_y
);
int
size
=
static_cast
<
int
>
(
framework
::
product
(
dims
));
auto
new_dims
=
framework
::
make_ddim
({
dims
[
0
],
size
/
dims
[
0
]});
auto
M
=
EigenMatrix
<
T
>::
From
(
*
mask
,
new_dims
);
auto
dX
=
EigenMatrix
<
T
>::
From
(
*
grad_x
,
new_dims
);
auto
dY
=
EigenMatrix
<
T
>::
From
(
*
grad_y
,
new_dims
);
auto
place
=
context
.
GetEigenDevice
<
Place
>
();
dX
.
device
(
place
)
=
dY
*
M
;
float
dropout_prob
=
context
.
op_
.
GetAttr
<
float
>
(
"dropout_prob"
);
dX
.
device
(
place
)
=
dY
*
M
*
(
1
-
dropout_prob
);
}
};
...
...
python/paddle/v2/framework/tests/CMakeLists.txt
浏览文件 @
b1a18552
...
...
@@ -4,6 +4,7 @@ py_test(test_scope SRCS test_scope.py)
py_test
(
test_tensor SRCS test_tensor.py
)
py_test
(
test_mul_op SRCS test_mul_op.py
)
py_test
(
test_dropout_op SRCS test_dropout_op.py
)
py_test
(
test_mean_op SRCS test_mean_op.py
)
...
...
python/paddle/v2/framework/tests/op_test_util.py
浏览文件 @
b1a18552
python/paddle/v2/framework/tests/test_dropout_op.py
0 → 100644
浏览文件 @
b1a18552
import
unittest
import
numpy
as
np
from
gradient_checker
import
GradientChecker
,
create_op
from
op_test_util
import
OpTestMeta
class
TestDropoutOpProbZero
(
unittest
.
TestCase
):
__metaclass__
=
OpTestMeta
def
setUp
(
self
):
self
.
type
=
"dropout"
self
.
inputs
=
{
'X'
:
np
.
random
.
random
((
32
,
64
)).
astype
(
"float32"
)}
self
.
attrs
=
{
'dropout_prob'
:
0.0
}
self
.
outputs
=
{
'Out'
:
self
.
inputs
[
'X'
],
'Mask'
:
np
.
ones
((
32
,
64
))}
class
TestDropoutOpAllProbOne
(
unittest
.
TestCase
):
__metaclass__
=
OpTestMeta
def
setUp
(
self
):
self
.
type
=
"dropout"
self
.
inputs
=
{
'X'
:
np
.
random
.
random
((
32
,
64
)).
astype
(
"float32"
)}
self
.
attrs
=
{
'dropout_prob'
:
1.0
}
self
.
outputs
=
{
'Out'
:
np
.
zeros
((
32
,
64
)),
'Mask'
:
np
.
zeros
((
32
,
64
))}
class
DropoutGradOpTest
(
GradientChecker
):
def
test_dropout_2d
(
self
):
op
=
create_op
(
"dropout"
)
inputs
=
{
'X'
:
np
.
random
.
random
((
10
,
5
)).
astype
(
"float32"
)}
self
.
compare_grad
(
op
,
inputs
)
self
.
check_grad
(
op
,
inputs
,
set
([
"X"
]),
"Out"
)
def
test_dropout_3d
(
self
):
op
=
create_op
(
"dropout"
)
inputs
=
{
'X'
:
np
.
random
.
random
((
10
,
5
,
4
)).
astype
(
"float32"
)}
self
.
compare_grad
(
op
,
inputs
)
self
.
check_grad
(
op
,
inputs
,
set
([
"X"
]),
"Out"
)
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录