Unverified commit 95d3ebc8
Authored on Mar 23, 2022 by niuliling123; committed via GitHub on Mar 23, 2022
Modified dropout Kernel with Kernel Primitive API (#40766)
Parent: 17b8335b
Showing 3 changed files with 121 additions and 157 deletions:
paddle/fluid/operators/dropout_impl.cu.h  (+103, -152)
paddle/phi/kernels/funcs/distribution_helper.h  (+16, -2)
paddle/phi/kernels/gpu/masked_select_grad_kernel.cu  (+2, -3)
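The kernels touched by this commit all implement the same elementwise dropout rule, just vectorized and fused in different ways. For orientation, here is a minimal host-side C++ sketch of that rule; this is illustrative code, not part of the commit, and the names (DropoutReference, etc.) are hypothetical:

// Reference semantics only -- illustrative host code, not from this commit.
#include <cstdint>
#include <cstdio>
#include <random>
#include <vector>

// Inverted ("upscale_in_train") dropout: kept elements are scaled by
// 1/(1 - p) so the expected output equals the input; dropped elements and
// their mask entries become 0. This is the rule the GPU kernels vectorize.
void DropoutReference(const std::vector<float>& src, float p,
                      bool upscale_in_train, std::vector<float>* dst,
                      std::vector<uint8_t>* mask) {
  std::mt19937 gen(0);  // the real kernels use per-thread Philox states
  std::uniform_real_distribution<float> uniform(0.0f, 1.0f);
  const float factor = 1.0f / (1.0f - p);
  dst->resize(src.size());
  mask->resize(src.size());
  for (size_t i = 0; i < src.size(); ++i) {
    if (uniform(gen) < p) {
      (*dst)[i] = 0.0f;
      (*mask)[i] = 0;
    } else {
      (*dst)[i] = upscale_in_train ? src[i] * factor : src[i];
      (*mask)[i] = 1;
    }
  }
}

int main() {
  std::vector<float> x = {1.f, 2.f, 3.f, 4.f};
  std::vector<float> y;
  std::vector<uint8_t> m;
  DropoutReference(x, /*p=*/0.5f, /*upscale_in_train=*/true, &y, &m);
  for (size_t i = 0; i < x.size(); ++i)
    std::printf("x=%g y=%g mask=%d\n", x[i], y[i], static_cast<int>(m[i]));
  return 0;
}

The forward kernels below fill dst and mask exactly this way; the backward pass multiplies dout by the mask and the same factor (see CudaDropoutGradFunctor).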
paddle/fluid/operators/dropout_impl.cu.h
@@ -35,143 +35,99 @@ limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
#include "paddle/fluid/platform/aligned_vector.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/kernels/funcs/distribution_helper.h"
#include "paddle/phi/kernels/funcs/functors.h"
namespace paddle {
namespace operators {

template <typename T1, typename T2 = T1, typename OutT = T1>
struct DstMaskGenerator {
  const float dropout_prob_;
  const bool is_upscale_in_train_;
  using MT = typename details::MPTypeTrait<T1>::Type;
  MT factor;
  HOSTDEVICE inline DstMaskGenerator(const float dropout_prob,
                                     const bool is_upscale_in_train)
      : dropout_prob_(dropout_prob), is_upscale_in_train_(is_upscale_in_train) {
    factor = static_cast<MT>(1.0f / (1.0f - dropout_prob_));
  }
template <typename T, typename MaskType>
__global__ void RandomGenerator(const size_t n, uint64_t seed,
                                const float dropout_prob, const T* src,
                                MaskType* mask, T* dst,
                                bool is_upscale_in_train, uint64_t increment) {
  using MT = typename details::MPTypeTrait<T>::Type;
  int idx = blockDim.x * blockIdx.x + threadIdx.x;
#ifdef PADDLE_WITH_HIP
  hiprandStatePhilox4_32_10_t state;
  hiprand_init(seed, idx, increment, &state);
#else
  curandStatePhilox4_32_10_t state;
  curand_init(seed, idx, increment, &state);
#endif
  MaskType mask_val;
  T dst_val;
  MT factor = static_cast<MT>(1.0f / (1.0f - dropout_prob));
  for (; idx < n; idx += blockDim.x * gridDim.x) {
    T src_val = src[idx];
#ifdef PADDLE_WITH_HIP
    if (hiprand_uniform(&state) < dropout_prob) {
#else
    if (curand_uniform(&state) < dropout_prob) {
#endif
      mask_val = 0;
      dst_val = 0;
    } else {
      mask_val = 1;
      dst_val = is_upscale_in_train
                    ? static_cast<T>(static_cast<MT>(src_val) * factor)
                    : src_val;
  HOSTDEVICE inline void operator()(OutT* dst, const T1* src_val,
                                    const T2* rand, int num) const {
    static constexpr int kCount =
        phi::funcs::uniform_distribution<T2>::kReturnsCount;
    // 0 ~ kCount -1 is dist , kCount ~ 2 * kCount - 1 is mask
#pragma unroll
    for (int i = 0; i < kCount; i++) {
      if (rand[i] < dropout_prob_) {
        dst[i] = static_cast<T1>(0);
        dst[i + kCount] = dst[i];
      } else {
        dst[i] = is_upscale_in_train_
                     ? static_cast<T1>(static_cast<MT>(src_val[i]) * factor)
                     : static_cast<T1>(src_val[i]);
        dst[i + kCount] = static_cast<T1>(1);
      }
    }
      mask[idx] = mask_val;
      dst[idx] = dst_val;
    }
  }
};
template <typename T, typename MaskType, int VecSize>
template <typename T, typename MaskType>
__global__ void VectorizedRandomGenerator(const size_t n, uint64_t seed,
                                          const float dropout_prob,
                                          const T* src, MaskType* mask, T* dst,
                                          bool is_upscale_in_train,
                                          uint64_t increment) {
  using MT = typename details::MPTypeTrait<T>::Type;
  using LoadT = phi::AlignedVector<T, VecSize>;
  using MaskLoadT = phi::AlignedVector<MaskType, VecSize>;
                                          uint64_t increment,
                                          size_t main_offset) {
  size_t idx = static_cast<size_t>(BLOCK_ID_X * BLOCK_NUM_X);
  static constexpr int kCount =
      phi::funcs::uniform_distribution<float>::kReturnsCount;
  size_t stride = BLOCK_NUM_X * GRID_NUM_X * kCount;
#ifdef PADDLE_WITH_HIP
  int64_t idx = hipBlockDim_x * hipBlockIdx_x + hipThreadIdx_x;
  hiprandStatePhilox4_32_10_t state;
  hiprand_init(seed, idx, increment, &state);
  hiprand_init(seed, idx + THREAD_ID_X, increment, &state);
  using SType = hiprandStatePhilox4_32_10_t;
#else
  int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;
  curandStatePhilox4_32_10_t state;
  curand_init(seed, idx, increment, &state);
#endif
  MT factor = static_cast<MT>(1.0f / (1.0f - dropout_prob));
  for (int i = idx * VecSize; i < n; i += blockDim.x * gridDim.x * VecSize) {
    LoadT src_val;
    phi::Load<T, VecSize>(&src[i], &src_val);
#ifdef PADDLE_WITH_HIP
    float4 rand = hiprand_uniform4(&state);
#else
    float4 rand = curand_uniform4(&state);
    curand_init(seed, idx + THREAD_ID_X, increment, &state);
    using SType = curandStatePhilox4_32_10_t;
#endif
    LoadT dst_val;
    MaskLoadT mask_val;
#pragma unroll
    for (int j = 0; j < VecSize; j++) {
      if ((&rand.x)[j] < dropout_prob) {
        dst_val[j] = 0;
        mask_val[j] = 0;
      } else {
        dst_val[j] = is_upscale_in_train
                         ? static_cast<T>(static_cast<MT>(src_val[j]) * factor)
                         : src_val[j];
        mask_val[j] = 1;
      }
    }
    phi::Store<T, VecSize>(dst_val, &dst[i]);
    phi::Store<MaskType, VecSize>(mask_val, &mask[i]);
  T dst_mask[kCount * 2];  // 0 ~ kCount -1 : dst;kCount ~ 2 * kCount - 1: mask
  float rands[kCount];
  MaskType mask_result[kCount];
  using Rand = phi::funcs::uniform_distribution<float>;
  using Cast = kps::IdentityFunctor<T>;
  int deal_size = BLOCK_NUM_X * kCount;
  auto dst_functor =
      DstMaskGenerator<T, float>(dropout_prob, is_upscale_in_train);
  size_t fix = idx * kCount;
  for (; fix < main_offset; fix += stride) {
    kps::ReadData<T, kCount, 1, 1, false>(&dst_mask[0], src + fix, deal_size);
    kps::ElementwiseRandom<SType, float, kCount, 1, Rand>(&rands[0], Rand(),
                                                          &state);
    // dst
    kps::OperatorTernary<T, float, T, DstMaskGenerator<T, float>>(
        &dst_mask[0], &dst_mask[0], &rands[0], dst_functor, kCount);
    kps::WriteData<T, kCount, 1, 1, false>(dst + fix, &dst_mask[0], deal_size);
    // mask
    kps::ElementwiseUnary<T, MaskType, kCount, 1, 1, Cast>(
        &mask_result[0], &dst_mask[kCount], Cast());
    kps::WriteData<MaskType, kCount, 1, 1, false>(mask + fix, &mask_result[0],
                                                  deal_size);
  }
}
template <typename T, typename MaskType>
struct CudaDropoutGradFunctor {
  using MT = typename details::MPTypeTrait<T>::Type;

  explicit CudaDropoutGradFunctor(const MT factor) : factor_(factor) {}

  __device__ __forceinline__ T operator()(const T dout,
                                          const MaskType mask) const {
    return static_cast<T>(static_cast<MT>(dout) * static_cast<MT>(mask) *
                          factor_);
  }

 private:
  MT factor_;
};
template <typename T, typename MaskType, int VecSize>
__global__ void DropoutGradCUDAKernel(
    const T* dout, const MaskType* mask,
    const typename details::MPTypeTrait<T>::Type factor, const int64_t size,
    T* dx) {
  using MT = typename details::MPTypeTrait<T>::Type;
  using LoadT = phi::AlignedVector<T, VecSize>;
  using MaskLoadT = phi::AlignedVector<MaskType, VecSize>;

  int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;
  for (int i = idx * VecSize; i < size;
       i += blockDim.x * gridDim.x * VecSize) {
    LoadT dout_val;
    phi::Load<T, VecSize>(&dout[i], &dout_val);

    MaskLoadT mask_val;
    phi::Load<MaskType, VecSize>(&mask[i], &mask_val);

    LoadT dx_val;

#pragma unroll
    for (int j = 0; j < VecSize; j++) {
      dx_val[j] = static_cast<T>(static_cast<MT>(dout_val[j]) *
                                 static_cast<MT>(mask_val[j]) * factor);
    }

    phi::Store<T, VecSize>(dx_val, &dx[i]);
  int remainder = n - fix;
  if (remainder > 0) {
    kps::ReadData<T, kCount, 1, 1, true>(&dst_mask[0], src + fix, remainder);
    kps::ElementwiseRandom<SType, float, kCount, 1, Rand>(&rands[0], Rand(),
                                                          &state);
    // dst
    kps::OperatorTernary<T, float, T, DstMaskGenerator<T, float>>(
        &dst_mask[0], &dst_mask[0], &rands[0], dst_functor, kCount);
    kps::WriteData<T, kCount, 1, 1, true>(dst + fix, &dst_mask[0], remainder);
    // mask
    kps::ElementwiseUnary<T, MaskType, kCount, 1, 1, Cast>(
        &mask_result[0], &dst_mask[kCount], Cast());
    kps::WriteData<MaskType, kCount, 1, 1, true>(mask + fix, &mask_result[0],
                                                 remainder);
  }
}
@@ -218,42 +174,21 @@ void DropoutFwGPUKernelDriver(const phi::GPUContext& dev_ctx, bool is_test,
    uint64_t seed_data;
    uint64_t increment;
    // VectorizedRandomGenerator use curand_uniform4, so we only support
    // vec_size is 4;
    int vec_size = (phi::GetVectorizedSize<T>(x_data) == 4) ? 4 : 1;
    // kVecSize is 4;
    constexpr int kVecSize =
        phi::funcs::uniform_distribution<float>::kReturnsCount;
    auto gpu_config =
        phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, x_numel, vec_size);
        phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, x_numel, kVecSize);
    auto offset =
        ((x_numel - 1) / (gpu_config.GetThreadNum() * vec_size) + 1) * vec_size;
        ((x_numel - 1) / (gpu_config.GetThreadNum() * kVecSize) + 1) * kVecSize;
    GetSeedDataAndIncrement(dev_ctx, seed, is_fix_seed, seed_val, offset,
                            &seed_data, &increment);
#ifdef __HIPCC__
    if (vec_size == 4 && size % 4 == 0) {
      hipLaunchKernelGGL(
          HIP_KERNEL_NAME(VectorizedRandomGenerator<T, uint8_t, 4>),
          gpu_config.GetGridSize(), gpu_config.GetBlockSize(), 0, stream, size,
          seed_data, dropout_prob, x_data, mask_data, y_data, upscale_in_train,
          increment);
    } else {
      hipLaunchKernelGGL(HIP_KERNEL_NAME(RandomGenerator<T, uint8_t>),
                         gpu_config.GetGridSize(), gpu_config.GetBlockSize(), 0,
                         stream, size, seed_data, dropout_prob, x_data,
                         mask_data, y_data, upscale_in_train, increment);
    }
#else
    if (vec_size == 4 && size % 4 == 0) {
      VectorizedRandomGenerator<T, uint8_t, 4><<<
          gpu_config.block_per_grid, gpu_config.thread_per_block, 0, stream>>>(
          size, seed_data, dropout_prob, x_data, mask_data, y_data,
          upscale_in_train, increment);
    } else {
      RandomGenerator<T, uint8_t><<<gpu_config.block_per_grid,
                                    gpu_config.thread_per_block, 0, stream>>>(
          size, seed_data, dropout_prob, x_data, mask_data, y_data,
          upscale_in_train, increment);
    }
#endif
    size_t main_offset = size / (gpu_config.GetBlockSize() * kVecSize) *
                         (gpu_config.GetBlockSize() * kVecSize);
    VectorizedRandomGenerator<T, uint8_t><<<
        gpu_config.GetGridSize(), gpu_config.GetBlockSize(), 0, stream>>>(
        size, seed_data, dropout_prob, x_data, mask_data, y_data,
        upscale_in_train, increment, main_offset);
  } else {
    if (upscale_in_train) {
      // todo: can y share with data with x directly?
@@ -278,6 +213,22 @@ void DropoutFwGPUKernelDriver(const phi::GPUContext& dev_ctx, bool is_test,
}
}
template <typename T, typename MaskType>
struct CudaDropoutGradFunctor {
  using MT = typename details::MPTypeTrait<T>::Type;

  explicit CudaDropoutGradFunctor(const MT factor) : factor_(factor) {}

  __device__ __forceinline__ T operator()(const T dout,
                                          const MaskType mask) const {
    return static_cast<T>(static_cast<MT>(dout) * static_cast<MT>(mask) *
                          factor_);
  }

 private:
  MT factor_;
};
template <typename T>
void DropoutGradGPUKernelDriver(const phi::GPUContext& dev_ctx,
                                const std::string dropout_implementation,
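A note on the new launch path in DropoutFwGPUKernelDriver above: main_offset rounds size down to a whole number of tiles of GetBlockSize() * kVecSize elements, so the kernel's main loop runs without bounds checks and only the tail goes through the boundary-checked remainder branch. A small standalone sketch of that arithmetic, using made-up example values rather than anything from the commit:

// Worked example (illustrative values, not from the commit): how main_offset
// splits work between the full-tile loop and the boundary-checked tail in the
// new VectorizedRandomGenerator.
#include <cstddef>
#include <cstdio>

int main() {
  const size_t size = 5000;       // total number of elements (example value)
  const size_t block_size = 256;  // threads per block (example value)
  const size_t kVecSize = 4;      // uniform_distribution<float>::kReturnsCount
  const size_t tile = block_size * kVecSize;      // elements per block pass
  const size_t main_offset = size / tile * tile;  // largest multiple of tile
  const size_t remainder = size - main_offset;    // handled with bounds checks
  std::printf("tile=%zu main_offset=%zu remainder=%zu\n", tile, main_offset,
              remainder);  // tile=1024 main_offset=4096 remainder=904
  return 0;
}

With these example numbers the main loop covers the first 4096 elements and the remainder path handles the last 904.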
paddle/phi/kernels/funcs/distribution_helper.h
@@ -114,13 +114,19 @@ struct normal_transform {
namespace kps = phi::kps;

/*********************** Distribution Function *************************/
template <typename T>
struct uniform_distribution;

template <typename T>
struct normal_distribution;
#if defined(__NVCC__)
template <typename T>
struct uniform_distribution {
  __device__ inline T operator()(curandStatePhilox4_32_10_t *state) const {
    return static_cast<T>(curand_uniform(state));
  }
  static constexpr int kReturnsCount = 1;
};

template <>
struct uniform_distribution<float> {
  __device__ inline float4 operator()(
      curandStatePhilox4_32_10_t *state) const {
@@ -177,6 +183,14 @@ struct normal_distribution<double> {
};
#else
template <typename T>
struct uniform_distribution {
  __device__ inline T operator()(hiprandStatePhilox4_32_10_t *state) const {
    return hiprand_uniform(state);
  }
  static constexpr int kReturnsCount = 1;
};

template <>
struct uniform_distribution<float> {
  __device__ inline float4 operator()(
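The kReturnsCount constant introduced here is what the dropout kernel reads as kCount: curand_uniform4 / hiprand_uniform4 return a float4 (four values per call), so uniform_distribution<float> reports 4, while the generic template reports 1. A small host-only sketch of how a caller can size its per-thread buffers from that constant; this is illustrative, not Paddle code, and fake_uniform_distribution is a hypothetical stand-in:

// Illustrative only: compile-time buffer sizing driven by kReturnsCount,
// mirroring what VectorizedRandomGenerator does with
// phi::funcs::uniform_distribution<float>::kReturnsCount.
#include <cstdio>

template <typename T>
struct fake_uniform_distribution {
  static constexpr int kReturnsCount = 1;
};

template <>
struct fake_uniform_distribution<float> {
  static constexpr int kReturnsCount = 4;  // curand_uniform4 yields a float4
};

int main() {
  constexpr int kCount = fake_uniform_distribution<float>::kReturnsCount;
  float rands[kCount];         // one RNG call fills all kCount slots
  float dst_mask[kCount * 2];  // first half: dst, second half: mask
  (void)rands;
  (void)dst_mask;
  std::printf("kCount = %d\n", kCount);
  return 0;
}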
paddle/phi/kernels/gpu/masked_select_grad_kernel.cu
@@ -17,11 +17,10 @@
#include <thrust/reverse.h>
#include <thrust/scan.h>
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/select_impl.cu.h"
#include "paddle/phi/kernels/masked_select_grad_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {

template <typename MT, typename InT, typename OutT>
@@ -50,7 +49,7 @@ void MaskedSelectGradKernel(const Context& dev_ctx,
                            const DenseTensor& mask,
                            DenseTensor* x_grad) {
  auto mask_size = mask.numel();
  auto* out_data = x_grad->mutable_data<T>(dev_ctx.GetPlace());
  dev_ctx.template Alloc<T>(x_grad);
  if (mask_size <= 0) return;
  using Functor = MaskedSelectGradFunctor<bool, T, T>;
  phi::funcs::SelectKernel<bool, T, T, 2, Functor>(