Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
32645b52
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
32645b52
编写于
9月 16, 2017
作者:
X
Xinghai Sun
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Move dropout gpu kernel to dropout_op.cu.
上级
05326629
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
64 addition
and
58 deletion
+64
-58
paddle/operators/dropout_op.cu
paddle/operators/dropout_op.cu
+64
-0
paddle/operators/dropout_op.h
paddle/operators/dropout_op.h
+0
-58
未找到文件。
paddle/operators/dropout_op.cu
浏览文件 @
32645b52
...
...
@@ -13,8 +13,72 @@
limitations under the License. */
#define EIGEN_USE_GPU
#include <thrust/device_ptr.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/random.h>
#include <thrust/transform.h>
#include "paddle/operators/dropout_op.h"
namespace
paddle
{
namespace
operators
{
template
<
typename
T
>
struct
MaskGenerator
{
float
dropout_prob
;
int
seed
;
__host__
__device__
MaskGenerator
(
float
dropout_prob
,
int
seed
)
:
dropout_prob
(
dropout_prob
),
seed
(
seed
)
{}
__host__
__device__
T
operator
()(
const
unsigned
int
n
)
const
{
thrust
::
minstd_rand
rng
;
rng
.
seed
(
seed
);
thrust
::
uniform_real_distribution
<
T
>
dist
(
0
,
1
);
rng
.
discard
(
n
);
if
(
dist
(
rng
)
<
dropout_prob
)
{
return
static_cast
<
T
>
(
0
);
}
else
{
return
static_cast
<
T
>
(
1
);
}
}
};
// It seems that Eigen::Tensor::setRandom in GPU will SEGFAULT.
// Use std::random and thrust::random(thrust is a std library in CUDA) to
// implement uniform random.
template
<
typename
Place
,
typename
T
>
class
GPUDropoutKernel
:
public
framework
::
OpKernel
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
x
=
context
.
Input
<
Tensor
>
(
"X"
);
auto
*
y
=
context
.
Output
<
Tensor
>
(
"Out"
);
auto
*
mask
=
context
.
Output
<
Tensor
>
(
"Mask"
);
y
->
mutable_data
<
T
>
(
context
.
GetPlace
());
float
dropout_prob
=
context
.
Attr
<
float
>
(
"dropout_prob"
);
int
seed
=
context
.
Attr
<
int
>
(
"seed"
);
thrust
::
counting_iterator
<
unsigned
int
>
index_sequence_begin
(
0
);
int
size
=
framework
::
product
(
mask
->
dims
());
T
*
mask_data
=
mask
->
mutable_data
<
T
>
(
context
.
GetPlace
());
thrust
::
transform
(
index_sequence_begin
,
index_sequence_begin
+
size
,
thrust
::
device_ptr
<
T
>
(
mask_data
),
MaskGenerator
<
T
>
(
dropout_prob
,
seed
));
auto
dims
=
x
->
dims
();
auto
new_dims
=
framework
::
make_ddim
({
dims
[
0
],
size
/
dims
[
0
]});
auto
X
=
EigenMatrix
<
T
>::
From
(
*
x
,
new_dims
);
auto
Y
=
EigenMatrix
<
T
>::
From
(
*
y
,
new_dims
);
auto
M
=
EigenMatrix
<
T
>::
From
(
*
mask
,
new_dims
);
auto
place
=
context
.
GetEigenDevice
<
Place
>
();
Y
.
device
(
place
)
=
X
*
M
;
// TODO(xinghai-sun): add test time logits.
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_GPU_KERNEL
(
dropout
,
ops
::
GPUDropoutKernel
<
paddle
::
platform
::
GPUPlace
,
float
>
);
...
...
paddle/operators/dropout_op.h
浏览文件 @
32645b52
...
...
@@ -13,10 +13,6 @@
limitations under the License. */
#pragma once
#include <thrust/device_ptr.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/random.h>
#include <thrust/transform.h>
#include <random>
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
...
...
@@ -60,60 +56,6 @@ class CPUDropoutKernel : public framework::OpKernel {
}
};
template
<
typename
T
>
struct
MaskGenerator
{
float
dropout_prob
;
int
seed
;
__host__
__device__
MaskGenerator
(
float
dropout_prob
,
int
seed
)
:
dropout_prob
(
dropout_prob
),
seed
(
seed
)
{}
__host__
__device__
T
operator
()(
const
unsigned
int
n
)
const
{
thrust
::
minstd_rand
rng
;
rng
.
seed
(
seed
);
thrust
::
uniform_real_distribution
<
T
>
dist
(
0
,
1
);
rng
.
discard
(
n
);
if
(
dist
(
rng
)
<
dropout_prob
)
{
return
static_cast
<
T
>
(
0
);
}
else
{
return
static_cast
<
T
>
(
1
);
}
}
};
// It seems that Eigen::Tensor::setRandom in GPU will SEGFAULT.
// Use std::random and thrust::random(thrust is a std library in CUDA) to
// implement uniform random.
template
<
typename
Place
,
typename
T
>
class
GPUDropoutKernel
:
public
framework
::
OpKernel
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
x
=
context
.
Input
<
Tensor
>
(
"X"
);
auto
*
y
=
context
.
Output
<
Tensor
>
(
"Out"
);
auto
*
mask
=
context
.
Output
<
Tensor
>
(
"Mask"
);
y
->
mutable_data
<
T
>
(
context
.
GetPlace
());
float
dropout_prob
=
context
.
Attr
<
float
>
(
"dropout_prob"
);
int
seed
=
context
.
Attr
<
int
>
(
"seed"
);
thrust
::
counting_iterator
<
unsigned
int
>
index_sequence_begin
(
0
);
int
size
=
framework
::
product
(
mask
->
dims
());
T
*
mask_data
=
mask
->
mutable_data
<
T
>
(
context
.
GetPlace
());
thrust
::
transform
(
index_sequence_begin
,
index_sequence_begin
+
size
,
thrust
::
device_ptr
<
T
>
(
mask_data
),
MaskGenerator
<
T
>
(
dropout_prob
,
seed
));
auto
dims
=
x
->
dims
();
auto
new_dims
=
framework
::
make_ddim
({
dims
[
0
],
size
/
dims
[
0
]});
auto
X
=
EigenMatrix
<
T
>::
From
(
*
x
,
new_dims
);
auto
Y
=
EigenMatrix
<
T
>::
From
(
*
y
,
new_dims
);
auto
M
=
EigenMatrix
<
T
>::
From
(
*
mask
,
new_dims
);
auto
place
=
context
.
GetEigenDevice
<
Place
>
();
Y
.
device
(
place
)
=
X
*
M
;
// TODO: add test time logits.
}
};
template
<
typename
Place
,
typename
T
>
class
DropoutGradKernel
:
public
framework
::
OpKernel
{
public:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录