Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
585d12a3
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
695
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
585d12a3
编写于
9月 19, 2017
作者:
X
Xinghai Sun
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add is_training attr and testing phrase compuation to dropout operator.
Change type of dropout_prob to template typename.
上级
32645b52
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
103 addition
and
56 deletion
+103
-56
paddle/operators/dropout_op.cc
paddle/operators/dropout_op.cc
+17
-6
paddle/operators/dropout_op.cu
paddle/operators/dropout_op.cu
+24
-22
paddle/operators/dropout_op.h
paddle/operators/dropout_op.h
+31
-25
python/paddle/v2/framework/tests/test_dropout_op.py
python/paddle/v2/framework/tests/test_dropout_op.py
+31
-3
未找到文件。
paddle/operators/dropout_op.cc
浏览文件 @
585d12a3
...
...
@@ -30,6 +30,10 @@ class DropoutOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
InputVar
(
"X"
),
"Input(X) must not be null."
);
PADDLE_ENFORCE_GE
(
ctx
.
Attr
<
float
>
(
"dropout_prob"
),
0
);
PADDLE_ENFORCE_LE
(
ctx
.
Attr
<
float
>
(
"dropout_prob"
),
1
);
// TODO(xinghai-sun): remove this check after swtiching to bool
PADDLE_ENFORCE
(
ctx
.
Attr
<
int
>
(
"is_training"
)
==
0
||
ctx
.
Attr
<
int
>
(
"is_training"
)
==
1
);
// resize
auto
dims
=
ctx
.
Input
<
Tensor
>
(
"X"
)
->
dims
();
ctx
.
Output
<
LoDTensor
>
(
"Out"
)
->
Resize
(
dims
);
...
...
@@ -37,13 +41,16 @@ class DropoutOp : public framework::OperatorWithKernel {
}
};
template
<
typename
AttrType
>
class
DropoutOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
DropoutOpMaker
(
framework
::
OpProto
*
proto
,
framework
::
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddAttr
<
float
>
(
"dropout_prob"
,
"Probability for dropping out units
."
)
AddAttr
<
AttrType
>
(
"dropout_prob"
,
"Probability of setting units to zero
."
)
.
SetDefault
(
.5
f
);
// TODO(xinghai-sun): use bool for is_training after bool is supported.
AddAttr
<
int
>
(
"is_training"
,
"Whether in training phase."
).
SetDefault
(
1
);
AddAttr
<
int
>
(
"seed"
,
"Dropout random seed."
).
SetDefault
(
0
);
AddInput
(
"X"
,
"The input of dropout op."
);
AddOutput
(
"Out"
,
"The output of dropout op."
);
...
...
@@ -61,6 +68,7 @@ being set to their inputs.
}
};
template
<
typename
AttrType
>
class
DropoutOpGrad
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
...
...
@@ -72,8 +80,11 @@ class DropoutOpGrad : public framework::OperatorWithKernel {
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
InputVar
(
"Mask"
),
"Mask must not be null."
);
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
InputVar
(
framework
::
GradVarName
(
"Out"
)),
"Input(Out@GRAD) must not be null."
);
PADDLE_ENFORCE_GE
(
ctx
.
Attr
<
float
>
(
"dropout_prob"
),
0
);
PADDLE_ENFORCE_LE
(
ctx
.
Attr
<
float
>
(
"dropout_prob"
),
1
);
PADDLE_ENFORCE_GE
(
ctx
.
Attr
<
AttrType
>
(
"dropout_prob"
),
0
);
PADDLE_ENFORCE_LE
(
ctx
.
Attr
<
AttrType
>
(
"dropout_prob"
),
1
);
// TODO(xinghai-sun): remove this check after swtiching to bool
PADDLE_ENFORCE
(
ctx
.
Attr
<
int
>
(
"is_training"
)
==
0
||
ctx
.
Attr
<
int
>
(
"is_training"
)
==
1
);
auto
x_dims
=
ctx
.
Input
<
Tensor
>
(
"X"
)
->
dims
();
auto
mask_dims
=
ctx
.
Input
<
Tensor
>
(
"Mask"
)
->
dims
();
auto
out_dims
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Out"
))
->
dims
();
...
...
@@ -91,9 +102,9 @@ class DropoutOpGrad : public framework::OperatorWithKernel {
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP
(
dropout
,
ops
::
DropoutOp
,
ops
::
DropoutOpMaker
,
dropout_grad
,
ops
::
DropoutOpGrad
);
REGISTER_OP
(
dropout
,
ops
::
DropoutOp
,
ops
::
DropoutOpMaker
<
float
>
,
dropout_grad
,
ops
::
DropoutOpGrad
<
float
>
);
REGISTER_OP_CPU_KERNEL
(
dropout
,
ops
::
CPUDropoutKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
);
dropout
,
ops
::
CPUDropoutKernel
<
paddle
::
platform
::
CPUPlace
,
float
,
float
>
);
REGISTER_OP_CPU_KERNEL
(
dropout_grad
,
ops
::
DropoutGradKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
);
paddle/operators/dropout_op.cu
浏览文件 @
585d12a3
...
...
@@ -22,18 +22,18 @@
namespace
paddle
{
namespace
operators
{
template
<
typename
T
>
template
<
typename
T
,
typename
AttrType
>
struct
MaskGenerator
{
float
dropout_prob
;
AttrType
dropout_prob
;
int
seed
;
__host__
__device__
MaskGenerator
(
float
dropout_prob
,
int
seed
)
__host__
__device__
MaskGenerator
(
AttrType
dropout_prob
,
int
seed
)
:
dropout_prob
(
dropout_prob
),
seed
(
seed
)
{}
__host__
__device__
T
operator
()(
const
unsigned
int
n
)
const
{
thrust
::
minstd_rand
rng
;
rng
.
seed
(
seed
);
thrust
::
uniform_real_distribution
<
T
>
dist
(
0
,
1
);
thrust
::
uniform_real_distribution
<
AttrType
>
dist
(
0
,
1
);
rng
.
discard
(
n
);
if
(
dist
(
rng
)
<
dropout_prob
)
{
return
static_cast
<
T
>
(
0
);
...
...
@@ -46,33 +46,35 @@ struct MaskGenerator {
// It seems that Eigen::Tensor::setRandom in GPU will SEGFAULT.
// Use std::random and thrust::random(thrust is a std library in CUDA) to
// implement uniform random.
template
<
typename
Place
,
typename
T
>
template
<
typename
Place
,
typename
T
,
typename
AttrType
>
class
GPUDropoutKernel
:
public
framework
::
OpKernel
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
x
=
context
.
Input
<
Tensor
>
(
"X"
);
auto
*
y
=
context
.
Output
<
Tensor
>
(
"Out"
);
auto
*
mask
=
context
.
Output
<
Tensor
>
(
"Mask"
);
y
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
*
mask
=
context
.
Output
<
Tensor
>
(
"Mask"
);
auto
*
mask_data
=
mask
->
mutable_data
<
T
>
(
context
.
GetPlace
());
float
dropout_prob
=
context
.
Attr
<
float
>
(
"dropout_prob"
);
int
seed
=
context
.
Attr
<
int
>
(
"seed"
);
thrust
::
counting_iterator
<
unsigned
int
>
index_sequence_begin
(
0
);
int
size
=
framework
::
product
(
mask
->
dims
());
T
*
mask_data
=
mask
->
mutable_data
<
T
>
(
context
.
GetPlace
());
thrust
::
transform
(
index_sequence_begin
,
index_sequence_begin
+
size
,
thrust
::
device_ptr
<
T
>
(
mask_data
),
MaskGenerator
<
T
>
(
dropout_prob
,
seed
));
AttrType
dropout_prob
=
context
.
Attr
<
AttrType
>
(
"dropout_prob"
);
auto
dims
=
x
->
dims
();
auto
new_dims
=
framework
::
make_ddim
({
dims
[
0
],
size
/
dims
[
0
]});
auto
X
=
EigenMatrix
<
T
>::
From
(
*
x
,
new_dims
);
auto
Y
=
EigenMatrix
<
T
>::
From
(
*
y
,
new_dims
);
auto
M
=
EigenMatrix
<
T
>::
From
(
*
mask
,
new_dims
);
auto
X
=
EigenMatrix
<
T
>::
Reshape
(
*
x
,
1
);
auto
Y
=
EigenMatrix
<
T
>::
Reshape
(
*
y
,
1
);
auto
M
=
EigenMatrix
<
T
>::
Reshape
(
*
mask
,
1
);
auto
place
=
context
.
GetEigenDevice
<
Place
>
();
Y
.
device
(
place
)
=
X
*
M
;
// TODO(xinghai-sun): add test time logits.
int
size
=
framework
::
product
(
mask
->
dims
());
if
(
context
.
Attr
<
int
>
(
"is_training"
)
==
1
)
{
int
seed
=
context
.
Attr
<
int
>
(
"seed"
);
thrust
::
counting_iterator
<
unsigned
int
>
index_sequence_begin
(
0
);
thrust
::
transform
(
index_sequence_begin
,
index_sequence_begin
+
size
,
thrust
::
device_ptr
<
T
>
(
mask_data
),
MaskGenerator
<
T
,
AttrType
>
(
dropout_prob
,
seed
));
Y
.
device
(
place
)
=
X
*
M
;
}
else
{
cudaMemset
(
mask_data
,
0
,
sizeof
(
T
)
*
size
);
Y
.
device
(
place
)
=
X
*
dropout_prob
;
}
}
};
...
...
@@ -81,6 +83,6 @@ class GPUDropoutKernel : public framework::OpKernel {
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_GPU_KERNEL
(
dropout
,
ops
::
GPUDropoutKernel
<
paddle
::
platform
::
GPUPlace
,
float
>
);
dropout
,
ops
::
GPUDropoutKernel
<
paddle
::
platform
::
GPUPlace
,
float
,
float
>
);
REGISTER_OP_GPU_KERNEL
(
dropout_grad
,
ops
::
DropoutGradKernel
<
paddle
::
platform
::
GPUPlace
,
float
>
);
paddle/operators/dropout_op.h
浏览文件 @
585d12a3
...
...
@@ -25,34 +25,42 @@ template <typename T, int MajorType = Eigen::RowMajor,
typename
IndexType
=
Eigen
::
DenseIndex
>
using
EigenMatrix
=
framework
::
EigenMatrix
<
T
,
MajorType
,
IndexType
>
;
template
<
typename
Place
,
typename
T
>
template
<
typename
Place
,
typename
T
,
typename
AttrType
>
class
CPUDropoutKernel
:
public
framework
::
OpKernel
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
x
=
context
.
Input
<
Tensor
>
(
"X"
);
auto
*
y
=
context
.
Output
<
Tensor
>
(
"Out"
);
auto
*
mask
=
context
.
Output
<
Tensor
>
(
"Mask"
);
T
*
mask_data
=
mask
->
mutable_data
<
T
>
(
context
.
GetPlace
());
T
*
y_data
=
y
->
mutable_data
<
T
>
(
context
.
GetPlace
());
const
T
*
x_data
=
x
->
data
<
T
>
();
auto
*
mask_data
=
mask
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
*
y_data
=
y
->
mutable_data
<
T
>
(
context
.
GetPlace
());
const
auto
*
x_data
=
x
->
data
<
T
>
();
float
dropout_prob
=
context
.
Attr
<
float
>
(
"dropout_prob"
);
int
seed
=
context
.
Attr
<
int
>
(
"seed"
);
AttrType
dropout_prob
=
context
.
Attr
<
AttrType
>
(
"dropout_prob"
);
std
::
minstd_rand
engine
;
engine
.
seed
(
seed
);
std
::
uniform_real_distribution
<
T
>
dist
(
0
,
1
);
size_t
size
=
framework
::
product
(
mask
->
dims
());
for
(
size_t
i
=
0
;
i
<
size
;
++
i
)
{
if
(
dist
(
engine
)
<
dropout_prob
)
{
mask_data
[
i
]
=
0
;
y_data
[
i
]
=
0
;
}
else
{
mask_data
[
i
]
=
1
;
y_data
[
i
]
=
x_data
[
i
];
if
(
context
.
Attr
<
int
>
(
"is_training"
)
==
1
)
{
int
seed
=
context
.
Attr
<
int
>
(
"seed"
);
std
::
minstd_rand
engine
;
engine
.
seed
(
seed
);
std
::
uniform_real_distribution
<
AttrType
>
dist
(
0
,
1
);
size_t
size
=
framework
::
product
(
mask
->
dims
());
for
(
size_t
i
=
0
;
i
<
size
;
++
i
)
{
if
(
dist
(
engine
)
<
dropout_prob
)
{
mask_data
[
i
]
=
0
;
y_data
[
i
]
=
0
;
}
else
{
mask_data
[
i
]
=
1
;
y_data
[
i
]
=
x_data
[
i
];
}
}
}
else
{
size_t
size
=
framework
::
product
(
mask
->
dims
());
memset
(
mask_data
,
0
,
sizeof
(
T
)
*
size
);
auto
X
=
EigenMatrix
<
T
>::
Reshape
(
*
x
,
1
);
auto
Y
=
EigenMatrix
<
T
>::
Reshape
(
*
y
,
1
);
auto
place
=
context
.
GetEigenDevice
<
Place
>
();
Y
.
device
(
place
)
=
X
*
dropout_prob
;
}
// TODO: add test phase logits.
}
};
...
...
@@ -60,21 +68,19 @@ template <typename Place, typename T>
class
DropoutGradKernel
:
public
framework
::
OpKernel
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
PADDLE_ENFORCE_EQ
(
context
.
Attr
<
int
>
(
"is_training"
),
1
,
"Only callable when is_training is true"
);
auto
*
grad_x
=
context
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
*
grad_y
=
context
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
*
mask
=
context
.
Input
<
Tensor
>
(
"Mask"
);
grad_x
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
dims
=
grad_x
->
dims
();
int
size
=
static_cast
<
int
>
(
framework
::
product
(
dims
));
auto
new_dims
=
framework
::
make_ddim
({
dims
[
0
],
size
/
dims
[
0
]});
auto
M
=
EigenMatrix
<
T
>::
From
(
*
mask
,
new_dims
);
auto
dX
=
EigenMatrix
<
T
>::
From
(
*
grad_x
,
new_dims
);
auto
dY
=
EigenMatrix
<
T
>::
From
(
*
grad_y
,
new_dims
);
auto
M
=
EigenMatrix
<
T
>::
Reshape
(
*
mask
,
1
);
auto
dX
=
EigenMatrix
<
T
>::
Reshape
(
*
grad_x
,
1
);
auto
dY
=
EigenMatrix
<
T
>::
Reshape
(
*
grad_y
,
1
);
auto
place
=
context
.
GetEigenDevice
<
Place
>
();
dX
.
device
(
place
)
=
dY
*
M
;
// TODO: add test time logits.
}
};
...
...
python/paddle/v2/framework/tests/test_dropout_op.py
浏览文件 @
585d12a3
...
...
@@ -7,7 +7,7 @@ class TestDropoutOp(OpTest):
def
setUp
(
self
):
self
.
op_type
=
"dropout"
self
.
inputs
=
{
'X'
:
np
.
random
.
random
((
32
,
64
)).
astype
(
"float32"
)}
self
.
attrs
=
{
'dropout_prob'
:
0.0
}
self
.
attrs
=
{
'dropout_prob'
:
0.0
,
'is_training'
:
1
}
self
.
outputs
=
{
'Out'
:
self
.
inputs
[
'X'
],
'Mask'
:
np
.
ones
((
32
,
64
))}
def
test_check_output
(
self
):
...
...
@@ -21,7 +21,7 @@ class TestDropoutOp2(TestDropoutOp):
def
setUp
(
self
):
self
.
op_type
=
"dropout"
self
.
inputs
=
{
'X'
:
np
.
random
.
random
((
32
,
64
)).
astype
(
"float32"
)}
self
.
attrs
=
{
'dropout_prob'
:
1.0
}
self
.
attrs
=
{
'dropout_prob'
:
1.0
,
'is_training'
:
1
}
self
.
outputs
=
{
'Out'
:
np
.
zeros
((
32
,
64
)),
'Mask'
:
np
.
zeros
((
32
,
64
))}
...
...
@@ -29,9 +29,37 @@ class TestDropoutOp3(TestDropoutOp):
def
setUp
(
self
):
self
.
op_type
=
"dropout"
self
.
inputs
=
{
'X'
:
np
.
random
.
random
((
32
,
64
,
2
)).
astype
(
"float32"
)}
self
.
attrs
=
{
'dropout_prob'
:
0.0
}
self
.
attrs
=
{
'dropout_prob'
:
0.0
,
'is_training'
:
1
}
self
.
outputs
=
{
'Out'
:
self
.
inputs
[
'X'
],
'Mask'
:
np
.
ones
((
32
,
64
,
2
))}
class
TestDropoutOp4
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
"dropout"
self
.
inputs
=
{
'X'
:
np
.
random
.
random
((
32
,
64
)).
astype
(
"float32"
)}
self
.
attrs
=
{
'dropout_prob'
:
0.35
,
'is_training'
:
0
}
self
.
outputs
=
{
'Out'
:
self
.
inputs
[
'X'
]
*
self
.
attrs
[
'dropout_prob'
],
'Mask'
:
np
.
zeros
((
32
,
64
))
}
def
test_check_output
(
self
):
self
.
check_output
()
class
TestDropoutOp5
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
"dropout"
self
.
inputs
=
{
'X'
:
np
.
random
.
random
((
32
,
64
,
3
)).
astype
(
"float32"
)}
self
.
attrs
=
{
'dropout_prob'
:
0.75
,
'is_training'
:
0
}
self
.
outputs
=
{
'Out'
:
self
.
inputs
[
'X'
]
*
self
.
attrs
[
'dropout_prob'
],
'Mask'
:
np
.
zeros
((
32
,
64
,
3
))
}
def
test_check_output
(
self
):
self
.
check_output
()
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录