Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
1e9127f6
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
1e9127f6
编写于
12月 16, 2020
作者:
Z
Zhang Ting
提交者:
GitHub
12月 16, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
improve dropout grad (#29605)
* improve grad perf
上级
eab44e1f
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
77 addition
and
29 deletion
+77
-29
paddle/fluid/operators/dropout_op.cu
paddle/fluid/operators/dropout_op.cu
+12
-26
paddle/fluid/operators/dropout_op.h
paddle/fluid/operators/dropout_op.h
+65
-3
未找到文件。
paddle/fluid/operators/dropout_op.cu
浏览文件 @
1e9127f6
...
...
@@ -27,22 +27,6 @@ limitations under the License. */
namespace
paddle
{
namespace
operators
{
// aligned vector generates vectorized load/store on CUDA
template
<
typename
T
,
int
Size
>
struct
alignas
(
sizeof
(
T
)
*
Size
)
AlignedVector
{
T
val
[
Size
];
};
template
<
typename
T
>
inline
int
VectorizedSize
(
const
T
*
pointer
)
{
uint64_t
address
=
reinterpret_cast
<
uint64_t
>
(
pointer
);
constexpr
int
vec4
=
std
::
alignment_of
<
AlignedVector
<
T
,
4
>>::
value
;
// NOLINT
if
(
address
%
vec4
==
0
)
{
return
4
;
}
return
1
;
}
template
<
typename
T
,
typename
MaskType
>
__global__
void
RandomGenerator
(
const
size_t
n
,
uint64_t
seed
,
const
float
dropout_prob
,
const
T
*
src
,
...
...
@@ -154,12 +138,9 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
return
;
}
int
threads
=
512
;
int
grid
=
(
x_numel
+
threads
-
1
)
/
threads
;
const
auto
&
dev_ctx
=
context
.
cuda_device_context
();
int
blocks_per_sm
=
dev_ctx
.
GetMaxPhysicalThreadCount
()
/
dev_ctx
.
GetSMCount
()
/
threads
;
grid
=
std
::
min
(
dev_ctx
.
GetSMCount
()
*
blocks_per_sm
,
grid
);
platform
::
GpuLaunchConfig
config
=
platform
::
GetGpuLaunchConfig1D
(
dev_ctx
,
size
);
// increment is used to set the args(offset) of curand_init, which defines
// offset in subsequence.
...
...
@@ -171,8 +152,10 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
uint64_t
seed_data
;
uint64_t
increment
;
int
vec_size
=
VectorizedSize
<
T
>
(
x_data
);
auto
offset
=
((
x_numel
-
1
)
/
(
threads
*
grid
*
vec_size
)
+
1
)
*
vec_size
;
auto
offset
=
((
x_numel
-
1
)
/
(
config
.
block_per_grid
.
x
*
config
.
thread_per_block
.
x
*
vec_size
)
+
1
)
*
vec_size
;
int
device_id
=
BOOST_GET_CONST
(
platform
::
CUDAPlace
,
context
.
GetPlace
())
.
GetDeviceId
();
auto
gen_cuda
=
framework
::
GetDefaultCUDAGenerator
(
device_id
);
...
...
@@ -197,12 +180,15 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
increment
=
offset
;
}
if
(
vec_size
==
4
)
{
VectorizedRandomGenerator
<
T
,
uint8_t
,
4
><<<
grid
,
threads
,
0
,
stream
>>>
(
if
(
vec_size
==
4
&&
size
%
4
==
0
)
{
VectorizedRandomGenerator
<
T
,
uint8_t
,
4
><<<
config
.
block_per_grid
,
config
.
thread_per_block
,
0
,
stream
>>>
(
size
,
seed_data
,
dropout_prob
,
x_data
,
mask_data
,
y_data
,
upscale_in_train
,
increment
);
}
else
{
RandomGenerator
<
T
,
uint8_t
><<<
grid
,
threads
,
0
,
stream
>>>
(
RandomGenerator
<
T
,
uint8_t
><<<
config
.
block_per_grid
,
config
.
thread_per_block
,
0
,
stream
>>>
(
size
,
seed_data
,
dropout_prob
,
x_data
,
mask_data
,
y_data
,
upscale_in_train
,
increment
);
}
...
...
paddle/fluid/operators/dropout_op.h
浏览文件 @
1e9127f6
...
...
@@ -17,13 +17,59 @@ limitations under the License. */
#include <random>
#include <string>
#include <algorithm>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/gpu_launch_config.h"
namespace
paddle
{
namespace
operators
{
// aligned vector generates vectorized load/store on CUDA
template
<
typename
T
,
int
Size
>
struct
alignas
(
sizeof
(
T
)
*
Size
)
AlignedVector
{
T
val
[
Size
];
};
template
<
typename
T
>
inline
int
VectorizedSize
(
const
T
*
pointer
)
{
uint64_t
address
=
reinterpret_cast
<
uint64_t
>
(
pointer
);
constexpr
int
vec4
=
std
::
alignment_of
<
AlignedVector
<
T
,
4
>>::
value
;
// NOLINT
if
(
address
%
vec4
==
0
)
{
return
4
;
}
return
1
;
}
#ifdef __NVCC__
template
<
typename
T
,
typename
MaskType
,
int
VecSize
>
__global__
void
DropoutGradCUDAKernel
(
const
T
*
dout
,
const
MaskType
*
mask
,
const
T
factor
,
const
int64_t
size
,
T
*
dx
)
{
int64_t
idx
=
blockDim
.
x
*
blockIdx
.
x
+
threadIdx
.
x
;
using
LoadT
=
AlignedVector
<
T
,
VecSize
>
;
using
MaskLoadT
=
AlignedVector
<
MaskType
,
VecSize
>
;
for
(
int
i
=
idx
*
VecSize
;
i
<
size
;
i
+=
blockDim
.
x
*
gridDim
.
x
*
VecSize
)
{
T
dout_vec
[
VecSize
];
LoadT
*
value
=
reinterpret_cast
<
LoadT
*>
(
&
dout_vec
);
*
value
=
*
reinterpret_cast
<
const
LoadT
*>
(
&
dout
[
i
]);
T
dx_vec
[
VecSize
];
MaskType
mask_vec
[
VecSize
];
#pragma unroll
for
(
int
ii
=
0
;
ii
<
VecSize
;
ii
++
)
{
dx_vec
[
ii
]
=
dout_vec
[
ii
]
*
static_cast
<
T
>
(
mask_vec
[
ii
])
*
factor
;
}
*
(
reinterpret_cast
<
LoadT
*>
(
&
dx
[
i
]))
=
*
reinterpret_cast
<
LoadT
*>
(
&
dx_vec
[
0
]);
}
}
#endif
using
Tensor
=
framework
::
Tensor
;
template
<
typename
T
,
int
MajorType
=
Eigen
::
RowMajor
,
typename
IndexType
=
Eigen
::
DenseIndex
>
...
...
@@ -119,6 +165,7 @@ class DropoutGradKernel : public framework::OpKernel<T> {
auto
*
grad_y
=
context
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
*
mask
=
context
.
Input
<
Tensor
>
(
"Mask"
);
grad_x
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
size
=
grad_x
->
numel
();
auto
M
=
EigenVector
<
uint8_t
>::
Flatten
(
*
mask
);
auto
dX
=
EigenVector
<
T
>::
Flatten
(
*
grad_x
);
...
...
@@ -126,7 +173,6 @@ class DropoutGradKernel : public framework::OpKernel<T> {
auto
&
place
=
*
context
.
template
device_context
<
DeviceContext
>().
eigen_device
();
auto
&
dropout_implementation
=
context
.
Attr
<
std
::
string
>
(
"dropout_implementation"
);
if
(
dropout_implementation
==
"upscale_in_train"
)
{
...
...
@@ -134,8 +180,24 @@ class DropoutGradKernel : public framework::OpKernel<T> {
if
(
dropout_prob
==
1.0
f
)
{
dX
.
device
(
place
)
=
static_cast
<
T
>
(
0
)
*
dY
;
}
else
{
dX
.
device
(
place
)
=
dY
*
M
.
cast
<
T
>
()
/
static_cast
<
T
>
(
1.0
f
-
dropout_prob
);
int
vec_size
=
VectorizedSize
<
T
>
(
grad_y
->
data
<
T
>
());
if
(
platform
::
is_gpu_place
(
context
.
GetPlace
())
&&
vec_size
==
4
&&
size
%
4
==
0
)
{
#ifdef __NVCC__
auto
factor
=
static_cast
<
T
>
(
1.0
f
/
(
1.0
f
-
dropout_prob
));
auto
stream
=
context
.
cuda_device_context
().
stream
();
platform
::
GpuLaunchConfig
config
=
platform
::
GetGpuLaunchConfig1D
(
context
.
cuda_device_context
(),
size
);
DropoutGradCUDAKernel
<
T
,
uint8_t
,
4
><<<
config
.
block_per_grid
,
config
.
thread_per_block
,
0
,
stream
>>>
(
grad_y
->
data
<
T
>
(),
mask
->
data
<
uint8_t
>
(),
factor
,
size
,
grad_x
->
data
<
T
>
());
#endif
}
else
{
dX
.
device
(
place
)
=
dY
*
M
.
cast
<
T
>
()
/
static_cast
<
T
>
(
1.0
f
-
dropout_prob
);
}
}
}
else
{
dX
.
device
(
place
)
=
dY
*
M
.
cast
<
T
>
();
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录