Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
1c01d1cc
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
1c01d1cc
编写于
3月 25, 2022
作者:
zhouweiwei2014
提交者:
GitHub
3月 25, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
change CUDA implementation of dropout OP (#40874)
上级
236a3bc5
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
183 addition
and
46 deletion
+183
-46
paddle/fluid/operators/dropout_impl.cu.h
paddle/fluid/operators/dropout_impl.cu.h
+126
-46
python/paddle/fluid/tests/unittests/test_dropout_op.py
python/paddle/fluid/tests/unittests/test_dropout_op.py
+57
-0
未找到文件。
paddle/fluid/operators/dropout_impl.cu.h
浏览文件 @
1c01d1cc
...
...
@@ -37,8 +37,12 @@ limitations under the License. */
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/kernels/funcs/distribution_helper.h"
#include "paddle/phi/kernels/funcs/functors.h"
DECLARE_bool
(
use_curand
);
namespace
paddle
{
namespace
operators
{
template
<
typename
T1
,
typename
T2
=
T1
,
typename
OutT
=
T1
>
struct
DstMaskGenerator
{
const
float
dropout_prob_
;
...
...
@@ -71,13 +75,45 @@ struct DstMaskGenerator {
}
};
template
<
typename
T1
,
typename
T2
=
T1
,
typename
OutT
=
T1
>
struct
DstMaskFunctor
{
const
float
retain_prob_
;
const
bool
is_upscale_in_train_
;
using
MT
=
typename
details
::
MPTypeTrait
<
T1
>::
Type
;
MT
factor
;
HOSTDEVICE
inline
DstMaskFunctor
(
const
float
retain_prob
,
const
bool
is_upscale_in_train
)
:
retain_prob_
(
retain_prob
),
is_upscale_in_train_
(
is_upscale_in_train
)
{
factor
=
static_cast
<
MT
>
(
1.0
f
/
retain_prob_
);
}
HOSTDEVICE
inline
void
operator
()(
OutT
*
dst
,
const
T1
*
src_val
,
const
T2
*
rand
,
int
num
)
const
{
static
constexpr
int
kCount
=
phi
::
funcs
::
uniform_distribution
<
T2
>::
kReturnsCount
;
// 0 ~ kCount -1 is dist , kCount ~ 2 * kCount - 1 is mask
#pragma unroll
for
(
int
i
=
0
;
i
<
kCount
;
i
++
)
{
if
(
rand
[
i
]
<
retain_prob_
)
{
dst
[
i
]
=
is_upscale_in_train_
?
static_cast
<
T1
>
(
static_cast
<
MT
>
(
src_val
[
i
])
*
factor
)
:
static_cast
<
T1
>
(
src_val
[
i
]);
dst
[
i
+
kCount
]
=
static_cast
<
T1
>
(
1
);
}
else
{
dst
[
i
]
=
static_cast
<
T1
>
(
0
);
dst
[
i
+
kCount
]
=
dst
[
i
];
}
}
}
};
template
<
typename
T
,
typename
MaskType
>
__global__
void
VectorizedRandomGenerator
(
const
size_t
n
,
uint64_t
seed
,
const
float
dropout_prob
,
const
T
*
src
,
MaskType
*
mask
,
T
*
dst
,
bool
is_upscale_in_train
,
uint64_t
increment
,
size_t
main_offset
)
{
size_t
main_offset
,
bool
use_curand
)
{
size_t
idx
=
static_cast
<
size_t
>
(
BLOCK_ID_X
*
BLOCK_NUM_X
);
static
constexpr
int
kCount
=
phi
::
funcs
::
uniform_distribution
<
float
>::
kReturnsCount
;
...
...
@@ -97,37 +133,78 @@ __global__ void VectorizedRandomGenerator(const size_t n, uint64_t seed,
using
Rand
=
phi
::
funcs
::
uniform_distribution
<
float
>
;
using
Cast
=
kps
::
IdentityFunctor
<
T
>
;
int
deal_size
=
BLOCK_NUM_X
*
kCount
;
auto
dst_functor
=
DstMaskGenerator
<
T
,
float
>
(
dropout_prob
,
is_upscale_in_train
);
size_t
fix
=
idx
*
kCount
;
for
(;
fix
<
main_offset
;
fix
+=
stride
)
{
kps
::
ReadData
<
T
,
kCount
,
1
,
1
,
false
>
(
&
dst_mask
[
0
],
src
+
fix
,
deal_size
);
kps
::
ElementwiseRandom
<
SType
,
float
,
kCount
,
1
,
Rand
>
(
&
rands
[
0
],
Rand
(),
&
state
);
// dst
kps
::
OperatorTernary
<
T
,
float
,
T
,
DstMaskGenerator
<
T
,
float
>>
(
&
dst_mask
[
0
],
&
dst_mask
[
0
],
&
rands
[
0
],
dst_functor
,
kCount
);
kps
::
WriteData
<
T
,
kCount
,
1
,
1
,
false
>
(
dst
+
fix
,
&
dst_mask
[
0
],
deal_size
);
// mask
kps
::
ElementwiseUnary
<
T
,
MaskType
,
kCount
,
1
,
1
,
Cast
>
(
&
mask_result
[
0
],
&
dst_mask
[
kCount
],
Cast
());
kps
::
WriteData
<
MaskType
,
kCount
,
1
,
1
,
false
>
(
mask
+
fix
,
&
mask_result
[
0
],
deal_size
);
}
int
remainder
=
n
-
fix
;
if
(
remainder
>
0
)
{
kps
::
ReadData
<
T
,
kCount
,
1
,
1
,
true
>
(
&
dst_mask
[
0
],
src
+
fix
,
remainder
);
kps
::
ElementwiseRandom
<
SType
,
float
,
kCount
,
1
,
Rand
>
(
&
rands
[
0
],
Rand
(),
&
state
);
// dst
kps
::
OperatorTernary
<
T
,
float
,
T
,
DstMaskGenerator
<
T
,
float
>>
(
&
dst_mask
[
0
],
&
dst_mask
[
0
],
&
rands
[
0
],
dst_functor
,
kCount
);
kps
::
WriteData
<
T
,
kCount
,
1
,
1
,
true
>
(
dst
+
fix
,
&
dst_mask
[
0
],
remainder
);
// mask
kps
::
ElementwiseUnary
<
T
,
MaskType
,
kCount
,
1
,
1
,
Cast
>
(
&
mask_result
[
0
],
&
dst_mask
[
kCount
],
Cast
());
kps
::
WriteData
<
MaskType
,
kCount
,
1
,
1
,
true
>
(
mask
+
fix
,
&
mask_result
[
0
],
remainder
);
if
(
use_curand
)
{
auto
dst_functor
=
DstMaskFunctor
<
T
,
float
>
(
1.0
f
-
dropout_prob
,
is_upscale_in_train
);
for
(;
fix
<
main_offset
;
fix
+=
stride
)
{
kps
::
ReadData
<
T
,
kCount
,
1
,
1
,
false
>
(
&
dst_mask
[
0
],
src
+
fix
,
deal_size
);
kps
::
ElementwiseRandom
<
SType
,
float
,
kCount
,
1
,
Rand
>
(
&
rands
[
0
],
Rand
(),
&
state
);
// dst
kps
::
OperatorTernary
<
T
,
float
,
T
,
DstMaskFunctor
<
T
,
float
>>
(
&
dst_mask
[
0
],
&
dst_mask
[
0
],
&
rands
[
0
],
dst_functor
,
kCount
);
kps
::
WriteData
<
T
,
kCount
,
1
,
1
,
false
>
(
dst
+
fix
,
&
dst_mask
[
0
],
deal_size
);
// mask
kps
::
ElementwiseUnary
<
T
,
MaskType
,
kCount
,
1
,
1
,
Cast
>
(
&
mask_result
[
0
],
&
dst_mask
[
kCount
],
Cast
());
kps
::
WriteData
<
MaskType
,
kCount
,
1
,
1
,
false
>
(
mask
+
fix
,
&
mask_result
[
0
],
deal_size
);
if
(
fix
>
idx
*
kCount
+
1
)
{
__syncthreads
();
}
}
int
remainder
=
n
-
fix
;
if
(
remainder
>
0
)
{
kps
::
ReadData
<
T
,
kCount
,
1
,
1
,
true
>
(
&
dst_mask
[
0
],
src
+
fix
,
remainder
);
kps
::
ElementwiseRandom
<
SType
,
float
,
kCount
,
1
,
Rand
>
(
&
rands
[
0
],
Rand
(),
&
state
);
// dst
kps
::
OperatorTernary
<
T
,
float
,
T
,
DstMaskFunctor
<
T
,
float
>>
(
&
dst_mask
[
0
],
&
dst_mask
[
0
],
&
rands
[
0
],
dst_functor
,
kCount
);
kps
::
WriteData
<
T
,
kCount
,
1
,
1
,
true
>
(
dst
+
fix
,
&
dst_mask
[
0
],
remainder
);
// mask
kps
::
ElementwiseUnary
<
T
,
MaskType
,
kCount
,
1
,
1
,
Cast
>
(
&
mask_result
[
0
],
&
dst_mask
[
kCount
],
Cast
());
kps
::
WriteData
<
MaskType
,
kCount
,
1
,
1
,
true
>
(
mask
+
fix
,
&
mask_result
[
0
],
remainder
);
__syncthreads
();
}
}
else
{
auto
dst_functor
=
DstMaskGenerator
<
T
,
float
>
(
dropout_prob
,
is_upscale_in_train
);
for
(;
fix
<
main_offset
;
fix
+=
stride
)
{
kps
::
ReadData
<
T
,
kCount
,
1
,
1
,
false
>
(
&
dst_mask
[
0
],
src
+
fix
,
deal_size
);
kps
::
ElementwiseRandom
<
SType
,
float
,
kCount
,
1
,
Rand
>
(
&
rands
[
0
],
Rand
(),
&
state
);
// dst
kps
::
OperatorTernary
<
T
,
float
,
T
,
DstMaskGenerator
<
T
,
float
>>
(
&
dst_mask
[
0
],
&
dst_mask
[
0
],
&
rands
[
0
],
dst_functor
,
kCount
);
kps
::
WriteData
<
T
,
kCount
,
1
,
1
,
false
>
(
dst
+
fix
,
&
dst_mask
[
0
],
deal_size
);
// mask
kps
::
ElementwiseUnary
<
T
,
MaskType
,
kCount
,
1
,
1
,
Cast
>
(
&
mask_result
[
0
],
&
dst_mask
[
kCount
],
Cast
());
kps
::
WriteData
<
MaskType
,
kCount
,
1
,
1
,
false
>
(
mask
+
fix
,
&
mask_result
[
0
],
deal_size
);
}
int
remainder
=
n
-
fix
;
if
(
remainder
>
0
)
{
kps
::
ReadData
<
T
,
kCount
,
1
,
1
,
true
>
(
&
dst_mask
[
0
],
src
+
fix
,
remainder
);
kps
::
ElementwiseRandom
<
SType
,
float
,
kCount
,
1
,
Rand
>
(
&
rands
[
0
],
Rand
(),
&
state
);
// dst
kps
::
OperatorTernary
<
T
,
float
,
T
,
DstMaskGenerator
<
T
,
float
>>
(
&
dst_mask
[
0
],
&
dst_mask
[
0
],
&
rands
[
0
],
dst_functor
,
kCount
);
kps
::
WriteData
<
T
,
kCount
,
1
,
1
,
true
>
(
dst
+
fix
,
&
dst_mask
[
0
],
remainder
);
// mask
kps
::
ElementwiseUnary
<
T
,
MaskType
,
kCount
,
1
,
1
,
Cast
>
(
&
mask_result
[
0
],
&
dst_mask
[
kCount
],
Cast
());
kps
::
WriteData
<
MaskType
,
kCount
,
1
,
1
,
true
>
(
mask
+
fix
,
&
mask_result
[
0
],
remainder
);
}
}
}
...
...
@@ -164,31 +241,34 @@ void DropoutFwGPUKernelDriver(const phi::GPUContext& dev_ctx, bool is_test,
return
;
}
// increment is used to set the args(offset) of curand_init, which defines
// offset in subsequence.
// The detail:
// https://docs.nvidia.com/cuda/curand/device-api-overview.html
// Increment should be at least the number of curand() random numbers used
// in each thread to avoid the random number generated this time being the
// same as the previous calls.
uint64_t
seed_data
;
uint64_t
increment
;
// VectorizedRandomGenerator use curand_uniform4, so we only support
// kVecSize is 4;
// VectorizedRandomGenerator use curand_uniform4, so kVecSize is 4;
constexpr
int
kVecSize
=
phi
::
funcs
::
uniform_distribution
<
float
>::
kReturnsCount
;
auto
gpu_config
=
phi
::
backends
::
gpu
::
GetGpuLaunchConfig1D
(
dev_ctx
,
x_numel
,
kVecSize
);
size_t
grid_size
=
gpu_config
.
GetGridSize
();
size_t
block_size
=
gpu_config
.
GetBlockSize
();
if
(
FLAGS_use_curand
)
{
int64_t
device_id
=
dev_ctx
.
GetPlace
().
GetDeviceId
();
const
auto
&
prop
=
platform
::
GetDeviceProperties
(
device_id
);
size_t
max_grid_size
=
prop
.
maxThreadsPerMultiProcessor
*
prop
.
multiProcessorCount
/
block_size
;
grid_size
=
std
::
min
(
grid_size
,
max_grid_size
);
}
auto
offset
=
((
x_numel
-
1
)
/
(
g
pu_config
.
GetThreadNum
()
*
kVecSize
)
+
1
)
*
kVecSize
;
((
x_numel
-
1
)
/
(
g
rid_size
*
block_size
*
kVecSize
)
+
1
)
*
kVecSize
;
GetSeedDataAndIncrement
(
dev_ctx
,
seed
,
is_fix_seed
,
seed_val
,
offset
,
&
seed_data
,
&
increment
);
size_t
main_offset
=
size
/
(
gpu_config
.
GetBlockSize
()
*
kVecSize
)
*
(
gpu_config
.
GetBlockSize
()
*
kVecSize
);
VectorizedRandomGenerator
<
T
,
uint8_t
><<<
gpu_config
.
GetGridSize
(),
gpu_config
.
GetBlockSize
()
,
0
,
stream
>>>
(
size_t
main_offset
=
size
/
(
block_size
*
kVecSize
)
*
(
block_size
*
kVecSize
);
VectorizedRandomGenerator
<
T
,
uint8_t
><<<
grid_size
,
block_size
,
0
,
stream
>>>
(
size
,
seed_data
,
dropout_prob
,
x_data
,
mask_data
,
y_data
,
upscale_in_train
,
increment
,
main_offset
);
upscale_in_train
,
increment
,
main_offset
,
FLAGS_use_curand
);
}
else
{
if
(
upscale_in_train
)
{
// todo: can y share with data with x directly?
...
...
python/paddle/fluid/tests/unittests/test_dropout_op.py
浏览文件 @
1c01d1cc
...
...
@@ -22,6 +22,7 @@ import paddle
import
paddle.static
as
static
import
paddle.fluid
as
fluid
from
paddle.fluid
import
Program
,
program_guard
import
os
class
TestDropoutOp
(
OpTest
):
...
...
@@ -992,6 +993,62 @@ class TestDropoutBackward(unittest.TestCase):
),
self
.
cal_grad_upscale_train
(
mask
.
numpy
(),
prob
)))
class
TestRandomValue
(
unittest
.
TestCase
):
def
test_fixed_random_number
(
self
):
# Test GPU Fixed random number, which is generated by 'curandStatePhilox4_32_10_t'
if
not
paddle
.
is_compiled_with_cuda
():
return
# Different GPU generate different random value. Only test V100 here.
if
not
"V100"
in
paddle
.
device
.
cuda
.
get_device_name
():
return
if
os
.
getenv
(
"FLAGS_use_curand"
,
None
)
in
(
'0'
,
'False'
,
None
):
return
print
(
"Test Fixed Random number on V100 GPU------>"
)
paddle
.
disable_static
()
paddle
.
set_device
(
'gpu'
)
paddle
.
seed
(
100
)
x
=
paddle
.
rand
([
32
,
1024
,
1024
],
dtype
=
'float32'
)
out
=
paddle
.
nn
.
functional
.
dropout
(
x
,
0.25
).
numpy
()
index0
,
index1
,
index2
=
np
.
nonzero
(
out
)
self
.
assertEqual
(
np
.
sum
(
index0
),
390094540
)
self
.
assertEqual
(
np
.
sum
(
index1
),
12871475125
)
self
.
assertEqual
(
np
.
sum
(
index2
),
12872777397
)
self
.
assertEqual
(
np
.
sum
(
out
),
16778744.0
)
expect
=
[
0.6914956
,
0.5294584
,
0.19032137
,
0.6996228
,
0.3338527
,
0.8442094
,
0.96965003
,
1.1726775
,
0.
,
0.28037727
]
self
.
assertTrue
(
np
.
allclose
(
out
[
10
,
100
,
500
:
510
],
expect
))
x
=
paddle
.
rand
([
32
,
1024
,
1024
],
dtype
=
'float64'
)
out
=
paddle
.
nn
.
functional
.
dropout
(
x
).
numpy
()
index0
,
index1
,
index2
=
np
.
nonzero
(
out
)
self
.
assertEqual
(
np
.
sum
(
index0
),
260065137
)
self
.
assertEqual
(
np
.
sum
(
index1
),
8582636095
)
self
.
assertEqual
(
np
.
sum
(
index2
),
8582219962
)
self
.
assertEqual
(
np
.
sum
(
out
),
16778396.563660286
)
expect
=
[
1.28587354
,
0.15563703
,
0.
,
0.28799703
,
0.
,
0.
,
0.
,
0.54964
,
0.51355682
,
0.33818988
]
self
.
assertTrue
(
np
.
allclose
(
out
[
20
,
100
,
500
:
510
],
expect
))
x
=
paddle
.
ones
([
32
,
1024
,
1024
],
dtype
=
'float16'
)
out
=
paddle
.
nn
.
functional
.
dropout
(
x
,
0.75
).
numpy
()
index0
,
index1
,
index2
=
np
.
nonzero
(
out
)
self
.
assertEqual
(
np
.
sum
(
index0
),
130086900
)
self
.
assertEqual
(
np
.
sum
(
index1
),
4291190105
)
self
.
assertEqual
(
np
.
sum
(
index2
),
4292243807
)
expect
=
[
0.
,
0.
,
0.
,
0.
,
0.
,
0.
,
0.
,
0.
,
4.
,
4.
]
self
.
assertTrue
(
np
.
allclose
(
out
[
0
,
100
,
500
:
510
],
expect
))
paddle
.
enable_static
()
if
__name__
==
'__main__'
:
paddle
.
enable_static
()
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录