Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
34d4b40d
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
34d4b40d
编写于
3月 11, 2022
作者:
Y
Yiqun Liu
提交者:
GitHub
3月 11, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Simplify the softmax kernel and add the check of whether cudnn softmax can be used. (#40424)
上级
f452ad5c
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
165 addition
and
179 deletion
+165
-179
paddle/phi/kernels/gpudnn/softmax_gpudnn.h
paddle/phi/kernels/gpudnn/softmax_gpudnn.h
+165
-179
未找到文件。
paddle/phi/kernels/gpudnn/softmax_gpudnn.h
浏览文件 @
34d4b40d
...
...
@@ -79,7 +79,7 @@ class VecT2<phi::dtype::bfloat16> {
using
Type
=
int
;
};
static
inline
int
log2_c
eil
(
int
value
)
{
static
inline
int
Log2C
eil
(
int
value
)
{
int
log2_value
=
0
;
while
((
1
<<
log2_value
)
<
value
)
++
log2_value
;
return
log2_value
;
...
...
@@ -577,8 +577,8 @@ static void GetBlockDim(int mid_dim, int low_dim, dim3* block) {
#else
constexpr
int
max_num_threads
=
1024
;
#endif
int
block_x
=
1
<<
log2_c
eil
(
low_dim
);
int
block_y
=
1
<<
log2_c
eil
(
mid_dim
);
int
block_x
=
1
<<
Log2C
eil
(
low_dim
);
int
block_y
=
1
<<
Log2C
eil
(
mid_dim
);
block
->
x
=
std
::
min
(
block_x
,
32
);
block
->
y
=
std
::
min
(
block_y
,
static_cast
<
int
>
(
max_num_threads
/
block
->
x
));
block
->
x
=
std
::
min
(
block_x
,
static_cast
<
int
>
(
max_num_threads
/
block
->
y
));
...
...
@@ -739,6 +739,131 @@ void LaunchNormalSoftmaxBackward(const GPUContext& dev_ctx,
}
}
static
std
::
vector
<
int
>
GetSoftmaxTensorDims
(
const
phi
::
DDim
&
dims
,
const
int
axis
)
{
int
dim
=
dims
[
axis
];
int
N
=
phi
::
funcs
::
SizeToAxis
(
axis
,
dims
);
int
D
=
phi
::
funcs
::
SizeOutAxis
(
axis
,
dims
);
return
{
N
,
dim
,
D
,
1
};
}
template
<
typename
T
>
void
SoftmaxForwardCudnnKernel
(
const
GPUContext
&
dev_ctx
,
const
DenseTensor
&
x
,
const
int
axis
,
const
bool
log_mode
,
DenseTensor
*
out
)
{
auto
*
out_data
=
out
->
data
<
T
>
();
const
int
rank
=
x
.
dims
().
size
();
std
::
vector
<
int
>
tensor_dims
=
GetSoftmaxTensorDims
(
x
.
dims
(),
axis
);
auto
handle
=
dev_ctx
.
cudnn_handle
();
GPUDNNDataLayout
layout
=
GPUDNNDataLayout
::
kNCHW
;
ScopedTensorDescriptor
scoped_desc
;
#ifdef PADDLE_WITH_HIP
miopenTensorDescriptor_t
desc
=
scoped_desc
.
descriptor
<
T
>
(
layout
,
tensor_dims
);
auto
mode
=
axis
==
rank
-
1
?
MIOPEN_SOFTMAX_MODE_INSTANCE
:
MIOPEN_SOFTMAX_MODE_CHANNEL
;
auto
algo
=
log_mode
?
MIOPEN_SOFTMAX_LOG
:
MIOPEN_SOFTMAX_ACCURATE
;
PADDLE_ENFORCE_GPU_SUCCESS
(
paddle
::
platform
::
dynload
::
miopenSoftmaxForward_V2
(
handle
,
paddle
::
platform
::
CudnnDataType
<
T
>::
kOne
(),
desc
,
x
.
data
<
T
>
(),
paddle
::
platform
::
CudnnDataType
<
T
>::
kZero
(),
desc
,
out_data
,
algo
,
mode
));
#else
cudnnTensorDescriptor_t
desc
=
scoped_desc
.
descriptor
<
T
>
(
layout
,
tensor_dims
);
auto
mode
=
axis
==
rank
-
1
?
CUDNN_SOFTMAX_MODE_INSTANCE
:
CUDNN_SOFTMAX_MODE_CHANNEL
;
auto
algo
=
log_mode
?
CUDNN_SOFTMAX_LOG
:
CUDNN_SOFTMAX_ACCURATE
;
PADDLE_ENFORCE_GPU_SUCCESS
(
paddle
::
platform
::
dynload
::
cudnnSoftmaxForward
(
handle
,
algo
,
mode
,
paddle
::
platform
::
CudnnDataType
<
T
>::
kOne
(),
desc
,
x
.
data
<
T
>
(),
paddle
::
platform
::
CudnnDataType
<
T
>::
kZero
(),
desc
,
out_data
));
#endif
}
template
<
typename
T
>
void
SoftmaxBackwardCudnnKernel
(
const
GPUContext
&
dev_ctx
,
const
DenseTensor
&
out
,
const
DenseTensor
&
dout
,
const
int
axis
,
const
bool
log_mode
,
DenseTensor
*
dx
)
{
auto
*
dx_data
=
dx
->
data
<
T
>
();
int
rank
=
out
.
dims
().
size
();
std
::
vector
<
int
>
tensor_dims
=
GetSoftmaxTensorDims
(
out
.
dims
(),
axis
);
auto
handle
=
dev_ctx
.
cudnn_handle
();
GPUDNNDataLayout
layout
=
GPUDNNDataLayout
::
kNCHW
;
ScopedTensorDescriptor
scoped_desc
;
#ifdef PADDLE_WITH_HIP
miopenTensorDescriptor_t
desc
=
scoped_desc
.
descriptor
<
T
>
(
layout
,
tensor_dims
);
auto
mode
=
axis
==
rank
-
1
?
MIOPEN_SOFTMAX_MODE_INSTANCE
:
MIOPEN_SOFTMAX_MODE_CHANNEL
;
auto
algo
=
log_mode
?
MIOPEN_SOFTMAX_LOG
:
MIOPEN_SOFTMAX_ACCURATE
;
PADDLE_ENFORCE_GPU_SUCCESS
(
paddle
::
platform
::
dynload
::
miopenSoftmaxBackward_V2
(
handle
,
paddle
::
platform
::
CudnnDataType
<
T
>::
kOne
(),
desc
,
out
.
data
<
T
>
(),
desc
,
dout
.
data
<
T
>
(),
paddle
::
platform
::
CudnnDataType
<
T
>::
kZero
(),
desc
,
dx_data
,
algo
,
mode
));
#else
cudnnTensorDescriptor_t
desc
=
scoped_desc
.
descriptor
<
T
>
(
layout
,
tensor_dims
);
auto
mode
=
axis
==
rank
-
1
?
CUDNN_SOFTMAX_MODE_INSTANCE
:
CUDNN_SOFTMAX_MODE_CHANNEL
;
auto
algo
=
log_mode
?
CUDNN_SOFTMAX_LOG
:
CUDNN_SOFTMAX_ACCURATE
;
PADDLE_ENFORCE_GPU_SUCCESS
(
paddle
::
platform
::
dynload
::
cudnnSoftmaxBackward
(
handle
,
algo
,
mode
,
paddle
::
platform
::
CudnnDataType
<
T
>::
kOne
(),
desc
,
out
.
data
<
T
>
(),
desc
,
dout
.
data
<
T
>
(),
paddle
::
platform
::
CudnnDataType
<
T
>::
kZero
(),
desc
,
dx_data
));
#endif
}
template
<
typename
T
>
static
bool
CanUseCudnnSoftmax
(
const
GPUContext
&
dev_ctx
)
{
if
(
dev_ctx
.
cudnn_handle
()
!=
nullptr
)
{
if
(
std
::
is_same
<
T
,
phi
::
dtype
::
bfloat16
>::
value
)
{
#if CUDNN_VERSION < 8100
return
false
;
#endif
}
return
true
;
}
return
false
;
}
template
<
typename
T
,
bool
LogMode
=
false
>
void
SoftmaxForwardCUDAKernelDriver
(
const
GPUContext
&
dev_ctx
,
const
DenseTensor
&
x
,
...
...
@@ -746,29 +871,29 @@ void SoftmaxForwardCUDAKernelDriver(const GPUContext& dev_ctx,
DenseTensor
*
out
)
{
auto
*
out_data
=
out
->
data
<
T
>
();
auto
dims
=
x
.
dims
();
const
int
rank
=
dims
.
size
(
);
const
int
axis
=
phi
::
funcs
::
CanonicalAxis
(
input_axis
,
rank
);
const
int
dim
=
dims
[
axis
];
const
int
N
=
phi
::
funcs
::
SizeToAxis
(
axis
,
dims
)
;
const
int
D
=
phi
::
funcs
::
SizeOutAxis
(
axis
,
dims
)
;
int
rank
=
x
.
dims
().
size
();
int
axis
=
phi
::
funcs
::
CanonicalAxis
(
input_axis
,
rank
);
std
::
vector
<
int
>
tensor_dims
=
GetSoftmaxTensorDims
(
x
.
dims
(),
axis
);
int
N
=
tensor_dims
[
0
];
int
dim
=
tensor_dims
[
1
]
;
int
D
=
tensor_dims
[
2
]
;
constexpr
int
max_dim
=
512
;
constexpr
int
warps_per_block
=
4
;
if
(
D
==
1
&&
dim
<=
max_dim
&&
sizeof
(
T
)
<=
4
)
{
const
int
kDimLog2
=
static_cast
<
int
>
(
log2_ceil
(
dim
));
const
int
kDimCeil
=
1
<<
kDimLog2
;
int
kWarpSize
=
(
kDimCeil
<
32
)
?
kDimCeil
:
32
;
int
batches_per_warp
=
(
kDimCeil
<=
32
)
?
2
:
1
;
if
(
D
==
1
&&
(
!
CanUseCudnnSoftmax
<
T
>
(
dev_ctx
)
||
(
dim
<=
max_dim
&&
sizeof
(
T
)
<=
4
)))
{
int
dim_log2
=
static_cast
<
int
>
(
Log2Ceil
(
dim
));
int
dim_ceil
=
1
<<
dim_log2
;
int
warp_size
=
(
dim_ceil
<
32
)
?
dim_ceil
:
32
;
int
batches_per_warp
=
(
dim_ceil
<=
32
)
?
2
:
1
;
// use 128 threads per block to maximimize gpu utilization
constexpr
int
threads_per_block
=
128
;
int
warps_per_block
=
(
threads_per_block
/
kWarpS
ize
);
int
warps_per_block
=
(
threads_per_block
/
warp_s
ize
);
int
batches_per_block
=
warps_per_block
*
batches_per_warp
;
int
blocks
=
(
N
+
batches_per_block
-
1
)
/
batches_per_block
;
dim3
threads
(
kWarpS
ize
,
warps_per_block
,
1
);
dim3
threads
(
warp_s
ize
,
warps_per_block
,
1
);
// vectorization read/write
using
T4
=
typename
VecT4
<
T
>::
Type
;
...
...
@@ -783,7 +908,7 @@ void SoftmaxForwardCUDAKernelDriver(const GPUContext& dev_ctx,
N
,
dim
,
dim
,
kDimL
og2
);
dim_l
og2
);
}
else
if
(
dim
%
2
==
0
)
{
SwitchWarpSoftmaxForward
<
T
,
T2
,
LogMode
>
(
blocks
,
threads
,
...
...
@@ -793,7 +918,7 @@ void SoftmaxForwardCUDAKernelDriver(const GPUContext& dev_ctx,
N
,
dim
,
dim
,
kDimL
og2
);
dim_l
og2
);
}
else
{
SwitchWarpSoftmaxForward
<
T
,
T
,
LogMode
>
(
blocks
,
threads
,
...
...
@@ -803,78 +928,13 @@ void SoftmaxForwardCUDAKernelDriver(const GPUContext& dev_ctx,
N
,
dim
,
dim
,
kDimL
og2
);
dim_l
og2
);
}
}
else
if
(
D
>
1
)
{
LaunchNormalSoftmaxForward
<
T
,
LogMode
>
(
dev_ctx
,
out_data
,
x
.
data
<
T
>
(),
N
,
dim
,
D
);
}
else
{
ScopedTensorDescriptor
desc
;
std
::
vector
<
int
>
tensor_dims
=
{
N
,
dim
,
D
,
1
};
GPUDNNDataLayout
layout
=
GPUDNNDataLayout
::
kNCHW
;
#ifdef PADDLE_WITH_HIP
miopenTensorDescriptor_t
desc_
=
desc
.
descriptor
<
T
>
(
layout
,
tensor_dims
);
#else
cudnnTensorDescriptor_t
desc_
=
desc
.
descriptor
<
T
>
(
layout
,
tensor_dims
);
#endif
auto
handle
=
dev_ctx
.
cudnn_handle
();
#ifdef PADDLE_WITH_HIP
auto
mode
=
axis
==
rank
-
1
?
MIOPEN_SOFTMAX_MODE_INSTANCE
:
MIOPEN_SOFTMAX_MODE_CHANNEL
;
if
(
LogMode
)
{
PADDLE_ENFORCE_GPU_SUCCESS
(
paddle
::
platform
::
dynload
::
miopenSoftmaxForward_V2
(
handle
,
paddle
::
platform
::
CudnnDataType
<
T
>::
kOne
(),
desc_
,
x
.
data
<
T
>
(),
paddle
::
platform
::
CudnnDataType
<
T
>::
kZero
(),
desc_
,
out_data
,
MIOPEN_SOFTMAX_LOG
,
mode
));
}
else
{
PADDLE_ENFORCE_GPU_SUCCESS
(
paddle
::
platform
::
dynload
::
miopenSoftmaxForward_V2
(
handle
,
paddle
::
platform
::
CudnnDataType
<
T
>::
kOne
(),
desc_
,
x
.
data
<
T
>
(),
paddle
::
platform
::
CudnnDataType
<
T
>::
kZero
(),
desc_
,
out_data
,
MIOPEN_SOFTMAX_ACCURATE
,
mode
));
}
#else
auto
mode
=
axis
==
rank
-
1
?
CUDNN_SOFTMAX_MODE_INSTANCE
:
CUDNN_SOFTMAX_MODE_CHANNEL
;
if
(
LogMode
)
{
PADDLE_ENFORCE_GPU_SUCCESS
(
paddle
::
platform
::
dynload
::
cudnnSoftmaxForward
(
handle
,
CUDNN_SOFTMAX_LOG
,
mode
,
paddle
::
platform
::
CudnnDataType
<
T
>::
kOne
(),
desc_
,
x
.
data
<
T
>
(),
paddle
::
platform
::
CudnnDataType
<
T
>::
kZero
(),
desc_
,
out_data
));
}
else
{
PADDLE_ENFORCE_GPU_SUCCESS
(
paddle
::
platform
::
dynload
::
cudnnSoftmaxForward
(
handle
,
CUDNN_SOFTMAX_ACCURATE
,
mode
,
paddle
::
platform
::
CudnnDataType
<
T
>::
kOne
(),
desc_
,
x
.
data
<
T
>
(),
paddle
::
platform
::
CudnnDataType
<
T
>::
kZero
(),
desc_
,
out_data
));
}
#endif
SoftmaxForwardCudnnKernel
<
T
>
(
dev_ctx
,
x
,
axis
,
LogMode
,
out
);
}
}
...
...
@@ -886,27 +946,28 @@ void SoftmaxBackwardCUDAKernelDriver(const GPUContext& dev_ctx,
DenseTensor
*
dx
)
{
auto
*
dx_data
=
dx
->
data
<
T
>
();
auto
dims
=
out
.
dims
();
const
int
rank
=
dims
.
size
(
);
const
int
axis
=
phi
::
funcs
::
CanonicalAxis
(
input_axis
,
rank
);
const
int
dim
=
dims
[
axis
];
const
int
N
=
phi
::
funcs
::
SizeToAxis
(
axis
,
dims
)
;
const
int
D
=
phi
::
funcs
::
SizeOutAxis
(
axis
,
dims
)
;
int
rank
=
out
.
dims
().
size
();
int
axis
=
phi
::
funcs
::
CanonicalAxis
(
input_axis
,
rank
);
std
::
vector
<
int
>
tensor_dims
=
GetSoftmaxTensorDims
(
out
.
dims
(),
axis
);
int
N
=
tensor_dims
[
0
];
int
dim
=
tensor_dims
[
1
]
;
int
D
=
tensor_dims
[
2
]
;
constexpr
int
max_dim
=
512
;
constexpr
int
warps_per_block
=
4
;
if
(
D
==
1
&&
dim
<=
max_dim
&&
sizeof
(
T
)
<=
4
)
{
const
int
kDimLog2
=
log2_ceil
(
dim
);
const
int
kDimCeil
=
1
<<
kDimLog2
;
int
kWarpSize
=
(
kDimCeil
<
32
)
?
kDimCeil
:
32
;
int
batches_per_warp
=
(
kDimCeil
<=
128
)
?
2
:
1
;
if
(
D
==
1
&&
(
!
CanUseCudnnSoftmax
<
T
>
(
dev_ctx
)
||
(
dim
<=
max_dim
&&
sizeof
(
T
)
<=
4
)))
{
int
dim_log2
=
Log2Ceil
(
dim
);
int
dim_ceil
=
1
<<
dim_log2
;
int
warp_size
=
(
dim_ceil
<
32
)
?
dim_ceil
:
32
;
int
batches_per_warp
=
(
dim_ceil
<=
128
)
?
2
:
1
;
constexpr
int
threads_per_block
=
128
;
int
warps_per_block
=
(
threads_per_block
/
kWarpS
ize
);
int
warps_per_block
=
(
threads_per_block
/
warp_s
ize
);
int
batches_per_block
=
warps_per_block
*
batches_per_warp
;
int
blocks
=
(
N
+
batches_per_block
-
1
)
/
batches_per_block
;
dim3
threads
(
kWarpS
ize
,
warps_per_block
,
1
);
dim3
threads
(
warp_s
ize
,
warps_per_block
,
1
);
// vectorization read/write
using
T4
=
typename
VecT4
<
T
>::
Type
;
...
...
@@ -921,7 +982,7 @@ void SoftmaxBackwardCUDAKernelDriver(const GPUContext& dev_ctx,
N
,
dim
,
dim
,
kDimL
og2
);
dim_l
og2
);
}
else
if
(
dim
%
2
==
0
)
{
SwitchWarpSoftmaxBackward
<
T
,
T2
,
LogMode
>
(
blocks
,
threads
,
...
...
@@ -932,7 +993,7 @@ void SoftmaxBackwardCUDAKernelDriver(const GPUContext& dev_ctx,
N
,
dim
,
dim
,
kDimL
og2
);
dim_l
og2
);
}
else
{
SwitchWarpSoftmaxBackward
<
T
,
T
,
LogMode
>
(
blocks
,
threads
,
...
...
@@ -943,88 +1004,13 @@ void SoftmaxBackwardCUDAKernelDriver(const GPUContext& dev_ctx,
N
,
dim
,
dim
,
kDimL
og2
);
dim_l
og2
);
}
}
else
if
(
D
>
1
)
{
LaunchNormalSoftmaxBackward
<
T
,
LogMode
>
(
dev_ctx
,
dx_data
,
dout
.
data
<
T
>
(),
out
.
data
<
T
>
(),
N
,
dim
,
D
);
}
else
{
ScopedTensorDescriptor
desc
;
std
::
vector
<
int
>
tensor_dims
=
{
N
,
dim
,
D
,
1
};
GPUDNNDataLayout
layout
=
GPUDNNDataLayout
::
kNCHW
;
#ifdef PADDLE_WITH_HIP
miopenTensorDescriptor_t
desc_
=
desc
.
descriptor
<
T
>
(
layout
,
tensor_dims
);
#else
cudnnTensorDescriptor_t
desc_
=
desc
.
descriptor
<
T
>
(
layout
,
tensor_dims
);
#endif
auto
handle
=
dev_ctx
.
cudnn_handle
();
#ifdef PADDLE_WITH_HIP
auto
mode
=
axis
==
rank
-
1
?
MIOPEN_SOFTMAX_MODE_INSTANCE
:
MIOPEN_SOFTMAX_MODE_CHANNEL
;
if
(
LogMode
)
{
PADDLE_ENFORCE_GPU_SUCCESS
(
paddle
::
platform
::
dynload
::
miopenSoftmaxBackward_V2
(
handle
,
paddle
::
platform
::
CudnnDataType
<
T
>::
kOne
(),
desc_
,
out
.
data
<
T
>
(),
desc_
,
dout
.
data
<
T
>
(),
paddle
::
platform
::
CudnnDataType
<
T
>::
kZero
(),
desc_
,
dx_data
,
MIOPEN_SOFTMAX_LOG
,
mode
));
}
else
{
PADDLE_ENFORCE_GPU_SUCCESS
(
paddle
::
platform
::
dynload
::
miopenSoftmaxBackward_V2
(
handle
,
paddle
::
platform
::
CudnnDataType
<
T
>::
kOne
(),
desc_
,
out
.
data
<
T
>
(),
desc_
,
dout
.
data
<
T
>
(),
paddle
::
platform
::
CudnnDataType
<
T
>::
kZero
(),
desc_
,
dx_data
,
MIOPEN_SOFTMAX_ACCURATE
,
mode
));
}
#else
auto
mode
=
axis
==
rank
-
1
?
CUDNN_SOFTMAX_MODE_INSTANCE
:
CUDNN_SOFTMAX_MODE_CHANNEL
;
if
(
LogMode
)
{
PADDLE_ENFORCE_GPU_SUCCESS
(
paddle
::
platform
::
dynload
::
cudnnSoftmaxBackward
(
handle
,
CUDNN_SOFTMAX_LOG
,
mode
,
paddle
::
platform
::
CudnnDataType
<
T
>::
kOne
(),
desc_
,
out
.
data
<
T
>
(),
desc_
,
dout
.
data
<
T
>
(),
paddle
::
platform
::
CudnnDataType
<
T
>::
kZero
(),
desc_
,
dx_data
));
}
else
{
PADDLE_ENFORCE_GPU_SUCCESS
(
paddle
::
platform
::
dynload
::
cudnnSoftmaxBackward
(
handle
,
CUDNN_SOFTMAX_ACCURATE
,
mode
,
paddle
::
platform
::
CudnnDataType
<
T
>::
kOne
(),
desc_
,
out
.
data
<
T
>
(),
desc_
,
dout
.
data
<
T
>
(),
paddle
::
platform
::
CudnnDataType
<
T
>::
kZero
(),
desc_
,
dx_data
));
}
#endif
SoftmaxBackwardCudnnKernel
<
T
>
(
dev_ctx
,
out
,
dout
,
axis
,
LogMode
,
dx
);
}
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录