Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
7d303bdc
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
7d303bdc
编写于
1月 30, 2018
作者:
C
caoying03
提交者:
ying
1月 30, 2018
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix the bug that dropout always use a fixed seed.
上级
308f6022
变更
5
显示空白变更内容
内联
并排
Showing
5 changed file
with
58 addition
and
10 deletion
+58
-10
paddle/operators/dropout_op.cc
paddle/operators/dropout_op.cc
+7
-0
paddle/operators/dropout_op.cu
paddle/operators/dropout_op.cu
+5
-1
paddle/operators/dropout_op.h
paddle/operators/dropout_op.h
+7
-1
python/paddle/v2/fluid/layers/nn.py
python/paddle/v2/fluid/layers/nn.py
+35
-4
python/paddle/v2/fluid/tests/test_dropout_op.py
python/paddle/v2/fluid/tests/test_dropout_op.py
+4
-4
未找到文件。
paddle/operators/dropout_op.cc
浏览文件 @
7d303bdc
...
...
@@ -51,6 +51,13 @@ class DropoutOpMaker : public framework::OpProtoAndCheckerMaker {
"'dropout_prob' must be between 0.0 and 1.0."
);
});
AddAttr
<
bool
>
(
"is_test"
,
"True if in test phase."
).
SetDefault
(
false
);
AddAttr
<
bool
>
(
"fix_seed"
,
"A flag indicating whether to use a fixed seed to generate "
"random mask. NOTE: DO NOT set this flag to true in "
"training. Setting this flag to true is only useful in "
"unittest or for debug that always the same output units "
"will be dropped."
)
.
SetDefault
(
false
);
AddAttr
<
int
>
(
"seed"
,
"Dropout random seed."
).
SetDefault
(
0
);
AddComment
(
R"DOC(
...
...
paddle/operators/dropout_op.cu
浏览文件 @
7d303bdc
...
...
@@ -62,7 +62,11 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
auto
*
mask
=
context
.
Output
<
Tensor
>
(
"Mask"
);
auto
*
mask_data
=
mask
->
mutable_data
<
T
>
(
context
.
GetPlace
());
int
size
=
framework
::
product
(
mask
->
dims
());
int
seed
=
context
.
Attr
<
int
>
(
"seed"
);
std
::
random_device
rnd
;
int
seed
=
context
.
Attr
<
bool
>
(
"fix_seed"
)
?
context
.
Attr
<
int
>
(
"seed"
)
:
rnd
();
thrust
::
counting_iterator
<
unsigned
int
>
index_sequence_begin
(
0
);
thrust
::
transform
(
index_sequence_begin
,
index_sequence_begin
+
size
,
thrust
::
device_ptr
<
T
>
(
mask_data
),
...
...
paddle/operators/dropout_op.h
浏览文件 @
7d303bdc
...
...
@@ -38,9 +38,15 @@ class CPUDropoutKernel : public framework::OpKernel<T> {
if
(
!
context
.
Attr
<
bool
>
(
"is_test"
))
{
auto
*
mask
=
context
.
Output
<
Tensor
>
(
"Mask"
);
auto
*
mask_data
=
mask
->
mutable_data
<
T
>
(
context
.
GetPlace
());
int
seed
=
context
.
Attr
<
int
>
(
"seed"
);
// NOTE: fixed seed should only be used in unittest or for debug.
// Guarantee to use random seed in training.
std
::
random_device
rnd
;
std
::
minstd_rand
engine
;
int
seed
=
context
.
Attr
<
bool
>
(
"fix_seed"
)
?
context
.
Attr
<
int
>
(
"seed"
)
:
rnd
();
engine
.
seed
(
seed
);
std
::
uniform_real_distribution
<
float
>
dist
(
0
,
1
);
size_t
size
=
framework
::
product
(
mask
->
dims
());
for
(
size_t
i
=
0
;
i
<
size
;
++
i
)
{
...
...
python/paddle/v2/fluid/layers/nn.py
浏览文件 @
7d303bdc
...
...
@@ -847,7 +847,35 @@ def cos_sim(X, Y, **kwargs):
return
out
def
dropout
(
x
,
dropout_prob
,
is_test
=
False
,
seed
=
0
,
**
kwargs
):
def
dropout
(
x
,
dropout_prob
,
is_test
=
False
,
seed
=
None
,
**
kwargs
):
"""
Computes dropout.
Drop or keep each element of `x` independently. Dropout is a regularization
technique for reducing overfitting by preventing neuron co-adaption during
training. The dropout operator randomly set (according to the given dropout
probability) the outputs of some units to zero, while others are remain
unchanged.
Args:
x(variable): The input tensor.
dropout_prob(float): Probability of setting units to zero.
is_test(bool): A flag indicating whether it is in test phrase or not.
seed(int): A Python integer used to create random seeds. If this
parameter is set to None, a random seed is used.
NOTE: If an integer seed is given, always the same output
units will be dropped. DO NOT use a fixed seed in training.
Returns:
Variable: A tensor variable.
Examples:
.. code-block:: python
x = fluid.layers.data(name="data", shape=[32, 32], dtype="float32")
droped = fluid.layers.dropout(input=x, dropout_rate=0.5)
"""
helper
=
LayerHelper
(
'dropout'
,
**
kwargs
)
out
=
helper
.
create_tmp_variable
(
dtype
=
x
.
dtype
)
mask
=
helper
.
create_tmp_variable
(
dtype
=
x
.
dtype
,
stop_gradient
=
True
)
...
...
@@ -856,9 +884,12 @@ def dropout(x, dropout_prob, is_test=False, seed=0, **kwargs):
inputs
=
{
'X'
:
[
x
]},
outputs
=
{
'Out'
:
[
out
],
'Mask'
:
[
mask
]},
attrs
=
{
'dropout_prob'
:
dropout_prob
,
attrs
=
{
'dropout_prob'
:
dropout_prob
,
'is_test'
:
is_test
,
'seed'
:
seed
})
'fix_seed'
:
seed
is
not
None
,
'seed'
:
seed
if
seed
is
not
None
else
0
})
return
out
...
...
python/paddle/v2/fluid/tests/test_dropout_op.py
浏览文件 @
7d303bdc
...
...
@@ -21,7 +21,7 @@ class TestDropoutOp(OpTest):
def
setUp
(
self
):
self
.
op_type
=
"dropout"
self
.
inputs
=
{
'X'
:
np
.
random
.
random
((
32
,
64
)).
astype
(
"float32"
)}
self
.
attrs
=
{
'dropout_prob'
:
0.0
,
'is_test'
:
False
}
self
.
attrs
=
{
'dropout_prob'
:
0.0
,
'
fix_seed'
:
True
,
'
is_test'
:
False
}
self
.
outputs
=
{
'Out'
:
self
.
inputs
[
'X'
],
'Mask'
:
np
.
ones
((
32
,
64
)).
astype
(
'float32'
)
...
...
@@ -38,7 +38,7 @@ class TestDropoutOp2(TestDropoutOp):
def
setUp
(
self
):
self
.
op_type
=
"dropout"
self
.
inputs
=
{
'X'
:
np
.
random
.
random
((
32
,
64
)).
astype
(
"float32"
)}
self
.
attrs
=
{
'dropout_prob'
:
1.0
,
'is_test'
:
False
}
self
.
attrs
=
{
'dropout_prob'
:
1.0
,
'
fix_seed'
:
True
,
'
is_test'
:
False
}
self
.
outputs
=
{
'Out'
:
np
.
zeros
((
32
,
64
)).
astype
(
'float32'
),
'Mask'
:
np
.
zeros
((
32
,
64
)).
astype
(
'float32'
)
...
...
@@ -49,7 +49,7 @@ class TestDropoutOp3(TestDropoutOp):
def
setUp
(
self
):
self
.
op_type
=
"dropout"
self
.
inputs
=
{
'X'
:
np
.
random
.
random
((
32
,
64
,
2
)).
astype
(
"float32"
)}
self
.
attrs
=
{
'dropout_prob'
:
0.0
,
'is_test'
:
False
}
self
.
attrs
=
{
'dropout_prob'
:
0.0
,
'
fix_seed'
:
True
,
'
is_test'
:
False
}
self
.
outputs
=
{
'Out'
:
self
.
inputs
[
'X'
],
'Mask'
:
np
.
ones
((
32
,
64
,
2
)).
astype
(
'float32'
)
...
...
@@ -60,7 +60,7 @@ class TestDropoutOp4(OpTest):
def
setUp
(
self
):
self
.
op_type
=
"dropout"
self
.
inputs
=
{
'X'
:
np
.
random
.
random
((
32
,
64
)).
astype
(
"float32"
)}
self
.
attrs
=
{
'dropout_prob'
:
0.35
,
'is_test'
:
True
}
self
.
attrs
=
{
'dropout_prob'
:
0.35
,
'
fix_seed'
:
True
,
'
is_test'
:
True
}
self
.
outputs
=
{
'Out'
:
self
.
inputs
[
'X'
]
*
(
1.0
-
self
.
attrs
[
'dropout_prob'
])
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录