Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
07923ba0
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
07923ba0
编写于
2月 12, 2018
作者:
D
dzhwinter
提交者:
GitHub
2月 12, 2018
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Memory/dropout4 (#8407)
* "merge random generator kernel and mul" * "fix dropout"
上级
91aac572
变更
1
显示空白变更内容
内联
并排
Showing
1 changed file
with
21 addition
and
21 deletion
+21
-21
paddle/fluid/operators/dropout_op.cu
paddle/fluid/operators/dropout_op.cu
+21
-21
未找到文件。
paddle/fluid/operators/dropout_op.cu
浏览文件 @
07923ba0
...
...
@@ -23,24 +23,23 @@ namespace paddle {
namespace
operators
{
template
<
typename
T
,
typename
AttrType
>
struct
MaskGenerator
{
AttrType
dropout_prob
;
int
seed
;
__host__
__device__
MaskGenerator
(
AttrType
dropout_prob
,
int
seed
)
:
dropout_prob
(
dropout_prob
),
seed
(
seed
)
{}
inline
__host__
__device__
T
operator
()(
const
unsigned
int
n
)
const
{
__global__
void
RandomGenerator
(
const
size_t
n
,
const
int
seed
,
const
AttrType
dropout_prob
,
const
T
*
src
,
T
*
mask_data
,
T
*
dst
)
{
thrust
::
minstd_rand
rng
;
rng
.
seed
(
seed
);
thrust
::
uniform_real_distribution
<
AttrType
>
dist
(
0
,
1
);
rng
.
discard
(
n
);
int
idx
=
blockDim
.
x
*
blockIdx
.
x
+
threadIdx
.
x
;
for
(;
idx
<
n
;
idx
+=
blockDim
.
x
*
gridDim
.
x
)
{
if
(
dist
(
rng
)
<
dropout_prob
)
{
return
static_cast
<
T
>
(
0
);
mask_data
[
idx
]
=
static_cast
<
T
>
(
0
);
}
else
{
mask_data
[
idx
]
=
static_cast
<
T
>
(
1
);
}
return
static_cast
<
T
>
(
1
)
;
dst
[
idx
]
=
mask_data
[
idx
]
*
src
[
idx
]
;
}
}
;
}
// It seems that Eigen::Tensor::setRandom in GPU will SEGFAULT.
// Use std::random and thrust::random(thrust is a std library in CUDA) to
...
...
@@ -61,18 +60,19 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
if
(
!
context
.
Attr
<
bool
>
(
"is_test"
))
{
auto
*
mask
=
context
.
Output
<
Tensor
>
(
"Mask"
);
auto
*
mask_data
=
mask
->
mutable_data
<
T
>
(
context
.
GetPlace
());
int
size
=
framework
::
product
(
mask
->
dims
());
size_t
size
=
framework
::
product
(
mask
->
dims
());
auto
*
x_data
=
x
->
data
<
T
>
();
auto
*
y_data
=
y
->
mutable_data
<
T
>
(
context
.
GetPlace
());
std
::
random_device
rnd
;
int
seed
=
context
.
Attr
<
bool
>
(
"fix_seed"
)
?
context
.
Attr
<
int
>
(
"seed"
)
:
rnd
();
thrust
::
counting_iterator
<
unsigned
int
>
index_sequence_begin
(
0
);
thrust
::
transform
(
index_sequence_begin
,
index_sequence_begin
+
size
,
thrust
::
device_ptr
<
T
>
(
mask_data
),
MaskGenerator
<
T
,
AttrType
>
(
dropout_prob
,
seed
));
auto
M
=
EigenMatrix
<
T
>::
Reshape
(
*
mask
,
1
);
Y
.
device
(
place
)
=
X
*
M
;
int
threads
=
512
;
int
grid
=
(
x
->
numel
()
+
threads
-
1
)
/
threads
;
RandomGenerator
<
T
,
AttrType
><<<
grid
,
threads
,
0
,
context
.
cuda_device_context
().
stream
()
>>>
(
size
,
seed
,
dropout_prob
,
x_data
,
mask_data
,
y_data
);
}
else
{
Y
.
device
(
place
)
=
X
*
(
1.0
f
-
dropout_prob
);
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录