Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
0a21924a
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
0a21924a
编写于
1月 10, 2021
作者:
G
GaoWei8
提交者:
GitHub
1月 10, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
optimize softmax forward (#30217)
* optimize softmax forward
上级
af80859d
变更
1
显示空白变更内容
内联
并排
Showing
1 changed file
with
168 addition
and
18 deletion
+168
-18
paddle/fluid/operators/softmax_cudnn_op.cu
paddle/fluid/operators/softmax_cudnn_op.cu
+168
-18
未找到文件。
paddle/fluid/operators/softmax_cudnn_op.cu
浏览文件 @
0a21924a
...
@@ -15,6 +15,7 @@ limitations under the License. */
...
@@ -15,6 +15,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_cuda_utils.h"
#include "paddle/fluid/operators/math/math_cuda_utils.h"
#include "paddle/fluid/operators/softmax_op.h"
#include "paddle/fluid/operators/softmax_op.h"
#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/cudnn_helper.h"
#include "paddle/fluid/platform/cudnn_helper.h"
namespace
paddle
{
namespace
paddle
{
...
@@ -31,6 +32,13 @@ using ScopedTensorDescriptor = platform::ScopedTensorDescriptor;
...
@@ -31,6 +32,13 @@ using ScopedTensorDescriptor = platform::ScopedTensorDescriptor;
using
DataLayout
=
platform
::
DataLayout
;
using
DataLayout
=
platform
::
DataLayout
;
using
Tensor
=
framework
::
Tensor
;
using
Tensor
=
framework
::
Tensor
;
// Expands to one complete `case Log2Elements: ... break;` arm of a switch
// statement. The arm launches WarpSoftmaxForward<T, float, Log2Elements>
// (accumulating in float) with zero bytes of dynamic shared memory on the
// stream returned by ctx.cuda_device_context().stream().
//
// The macro is not self-contained: the expansion site must have `T`, `ctx`,
// `blocks`, `threads`, `out_data`, `x`, `N` and `dim` in scope. `dim` is
// passed for both the kernel's `stride` and `element_count` arguments.
// Because the expansion already ends in `break;`, a trailing `;` at the use
// site only adds an empty statement.
#define LAUNCH_SOFTMAX_WARP_FORWARD(Log2Elements) \
case Log2Elements: \
WarpSoftmaxForward<T, float, Log2Elements><<< \
blocks, threads, 0, ctx.cuda_device_context().stream()>>>( \
out_data, x->data<T>(), N, dim, dim); \
break;
static
inline
int
SizeOutAxis
(
const
int
axis
,
DDim
dims
)
{
static
inline
int
SizeOutAxis
(
const
int
axis
,
DDim
dims
)
{
int
size
=
1
;
int
size
=
1
;
for
(
int
i
=
axis
+
1
;
i
<
dims
.
size
();
i
++
)
{
for
(
int
i
=
axis
+
1
;
i
<
dims
.
size
();
i
++
)
{
...
@@ -39,6 +47,12 @@ static inline int SizeOutAxis(const int axis, DDim dims) {
...
@@ -39,6 +47,12 @@ static inline int SizeOutAxis(const int axis, DDim dims) {
return
size
;
return
size
;
}
}
// Returns the smallest non-negative k such that (1 << k) >= value, i.e.
// ceil(log2(value)) for value >= 1. Any value <= 1 (including zero and
// negative inputs) yields 0.
//
// The power-of-two probe is computed in 64 bits. With a plain `int` probe,
// any value above 2^30 would require shifting into (or past) the sign bit,
// which is undefined behaviour and in practice never terminates.
int log2_ceil(int value) {
  const long long target = value;
  int log2_value = 0;
  while ((1LL << log2_value) < target) {
    ++log2_value;
  }
  return log2_value;
}
template
<
typename
T
,
int
VLEN
>
template
<
typename
T
,
int
VLEN
>
union
vec_t
{
union
vec_t
{
static_assert
(
sizeof
(
T
)
==
-
1
,
"vec_t is only available by specialization."
);
static_assert
(
sizeof
(
T
)
==
-
1
,
"vec_t is only available by specialization."
);
...
@@ -84,6 +98,107 @@ __global__ void VecSoftmaxForward(T* dst, const T* src, const int batch_size,
...
@@ -84,6 +98,107 @@ __global__ void VecSoftmaxForward(T* dst, const T* src, const int batch_size,
reinterpret_cast
<
VECT
*>
(
&
dst
[
offset
+
idx
])[
0
]
=
buf
;
reinterpret_cast
<
VECT
*>
(
&
dst
[
offset
+
idx
])[
0
]
=
buf
;
}
}
// Combines the WARP_BATCH partial sums in `sum` across threads, in place.
//
// For each step the lane offset is halved (WARP_SIZE_SOFTMAX / 2, ..., 1);
// at every step each entry is exchanged through
// platform::CudaShuffleXorSync with the full mask 0xFFFFFFFF and the
// received value is added to the local one.
// NOTE(review): presumably CudaShuffleXorSync wraps __shfl_xor_sync, in which
// case every lane of a WARP_SIZE_SOFTMAX-wide group ends up with the group
// total -- confirm against the helper's definition.
template <typename T, int WARP_BATCH, int WARP_SIZE_SOFTMAX>
__device__ __forceinline__ void warp_reduce_sum(T* sum) {
#pragma unroll
  for (int lane_mask = WARP_SIZE_SOFTMAX >> 1; lane_mask > 0;
       lane_mask >>= 1) {
#pragma unroll
    for (int row = 0; row < WARP_BATCH; ++row) {
      T peer_sum =
          platform::CudaShuffleXorSync(0xFFFFFFFF, sum[row], lane_mask);
      sum[row] = sum[row] + peer_sum;
    }
  }
}
// Combines the WARP_BATCH partial maxima across threads, in place.
//
// Despite its name, the `sum` parameter holds running maxima here: at each
// step (lane offsets WARP_SIZE_SOFTMAX / 2, ..., 1) every entry is exchanged
// through platform::CudaShuffleXorSync with the full mask 0xFFFFFFFF and
// replaced by max(local, received).
// NOTE(review): presumably CudaShuffleXorSync wraps __shfl_xor_sync, in which
// case every lane of a WARP_SIZE_SOFTMAX-wide group ends up with the group
// maximum -- confirm against the helper's definition.
template <typename T, int WARP_BATCH, int WARP_SIZE_SOFTMAX>
__device__ __forceinline__ void warp_reduce_max(T* sum) {
#pragma unroll
  for (int lane_mask = WARP_SIZE_SOFTMAX >> 1; lane_mask > 0;
       lane_mask >>= 1) {
#pragma unroll
    for (int row = 0; row < WARP_BATCH; ++row) {
      T peer_max =
          platform::CudaShuffleXorSync(0xFFFFFFFF, sum[row], lane_mask);
      sum[row] = max(sum[row], peer_max);
    }
  }
}
// Softmax forward over rows of at most (1 << Log2Elements) elements.
//
// Template parameters:
//   T            - storage type of `src` and `dst`.
//   AccT         - type used for the max / exp / sum arithmetic.
//   Log2Elements - the kernel covers element indices [0, 1 << Log2Elements);
//                  indices at or beyond that bound are never read or written.
//
// Arguments:
//   dst, src      - output / input buffers; row r of a thread group starts at
//                   first_batch * stride + r * element_count.
//   batch_size    - number of rows; rows at or past it are neither loaded nor
//                   stored.
//   stride        - element distance used to locate a group's first row.
//   element_count - number of valid elements per row.
//
// Work split: threads sharing the same (blockIdx.x, threadIdx.y) form a group
// that handles WARP_BATCH consecutive rows. threadIdx.x is the lane inside
// the group, and each lane owns WARP_ITERATIONS elements spaced
// warp_size_softmax apart. The visible launch site sets blockDim.x to
// min(1 << Log2Elements, 32), matching warp_size_softmax.
//
// NOTE: the group's base offset uses `stride`, but rows inside a group are
// addressed with `element_count`; the visible launch macro passes the same
// value for both.
template <typename T, typename AccT, int Log2Elements>
__global__ void WarpSoftmaxForward(T* dst, const T* src, const int batch_size,
                                   const int stride, const int element_count) {
  constexpr int next_power_of_two = 1 << Log2Elements;
  constexpr int warp_size_softmax =
      (next_power_of_two < 32) ? next_power_of_two : 32;
  constexpr int WARP_ITERATIONS = next_power_of_two / warp_size_softmax;
  constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;

  int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;

  // Number of rows this group really owns; it is <= 0 for a group that lies
  // entirely past the end of the batch.
  int local_batches = batch_size - first_batch;
  if (local_batches > WARP_BATCH) {
    local_batches = WARP_BATCH;
  }

  int local_idx = threadIdx.x;

  src += first_batch * stride + local_idx;
  dst += first_batch * stride + local_idx;

  // load data from global memory
  // Out-of-range rows and elements are filled with -infinity instead of
  // being skipped, so every thread still executes both reductions below.
  AccT elements[WARP_BATCH][WARP_ITERATIONS];
  for (int i = 0; i < WARP_BATCH; ++i) {
    int batch_element_count = (i >= local_batches) ? 0 : element_count;
    for (int it = 0; it < WARP_ITERATIONS; ++it) {
      int element_index = local_idx + it * warp_size_softmax;
      if (element_index < batch_element_count) {
        // Convert to the accumulator type (previously hard-coded to float).
        elements[i][it] =
            static_cast<AccT>(src[i * element_count + it * warp_size_softmax]);
      } else {
        elements[i][it] = -std::numeric_limits<AccT>::infinity();
      }
    }
  }

  // compute max_value
  AccT max_value[WARP_BATCH];
#pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    max_value[i] = elements[i][0];
#pragma unroll
    for (int it = 1; it < WARP_ITERATIONS; ++it) {
      max_value[i] =
          (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
    }
  }
  warp_reduce_max<AccT, WARP_BATCH, warp_size_softmax>(max_value);

  // Exponentiate relative to the row maximum and accumulate the row sum.
  AccT sum[WARP_BATCH]{0.0f};
#pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
#pragma unroll
    for (int it = 0; it < WARP_ITERATIONS; ++it) {
      elements[i][it] = (std::exp((elements[i][it] - max_value[i])));
      sum[i] += elements[i][it];
    }
  }
  warp_reduce_sum<AccT, WARP_BATCH, warp_size_softmax>(sum);

// store result
#pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    if (i >= local_batches) break;
#pragma unroll
    for (int it = 0; it < WARP_ITERATIONS; ++it) {
      int element_index = local_idx + it * warp_size_softmax;
      if (element_index < element_count) {
        dst[i * element_count + it * warp_size_softmax] =
            elements[i][it] / sum[i];
      } else {
        // element_index only grows with `it`, so nothing further is valid.
        break;
      }
    }
  }
}
template
<
typename
T
,
int
VPT
,
int
WARP_PER_BLOCK
>
template
<
typename
T
,
int
VPT
,
int
WARP_PER_BLOCK
>
__global__
void
VecSoftmaxBackward
(
T
*
dst
,
const
T
*
grad
,
const
T
*
src
,
__global__
void
VecSoftmaxBackward
(
T
*
dst
,
const
T
*
grad
,
const
T
*
src
,
const
int
batch_size
,
const
int
batch_size
,
...
@@ -130,26 +245,61 @@ class SoftmaxCUDNNKernel : public framework::OpKernel<T> {
...
@@ -130,26 +245,61 @@ class SoftmaxCUDNNKernel : public framework::OpKernel<T> {
const
int
N
=
SizeToAxis
(
axis
,
dims
);
const
int
N
=
SizeToAxis
(
axis
,
dims
);
const
int
D
=
SizeOutAxis
(
axis
,
dims
);
const
int
D
=
SizeOutAxis
(
axis
,
dims
);
constexpr
int
max_dim
=
320
;
bool
optimize
=
false
;
constexpr
int
warps_per_block
=
4
;
constexpr
int
warps_per_block
=
4
;
if
(
D
==
1
&&
dim
==
128
&&
N
%
warps_per_block
==
0
&&
sizeof
(
T
)
<=
4
)
{
if
(
D
==
1
&&
dim
<=
max_dim
&&
sizeof
(
T
)
<=
4
)
{
if
(
dim
==
128
&&
N
%
warps_per_block
==
0
)
{
optimize
=
true
;
// a warp for a batch, 4 elements for a thread, only support the softmax
// a warp for a batch, 4 elements for a thread, only support the softmax
// dim size = 128 currently
// dim size = 128 currently
if
(
sizeof
(
T
)
==
2
)
{
if
(
sizeof
(
T
)
==
2
)
{
VecSoftmaxForward
<
VecSoftmaxForward
<
T
,
int2
,
4
,
warps_per_block
><<<
T
,
int2
,
4
,
N
/
warps_per_block
,
warps_per_block
*
WARP_SIZE
,
0
,
warps_per_block
><<<
N
/
warps_per_block
,
warps_per_block
*
WARP_SIZE
,
ctx
.
cuda_device_context
().
stream
()
>>>
(
out_data
,
x
->
data
<
T
>
(),
N
,
0
,
ctx
.
cuda_device_context
().
stream
()
>>>
(
dim
);
out_data
,
x
->
data
<
T
>
(),
N
,
dim
);
}
else
if
(
sizeof
(
T
)
==
4
)
{
}
else
if
(
sizeof
(
T
)
==
4
)
{
VecSoftmaxForward
<
VecSoftmaxForward
<
T
,
int4
,
4
,
warps_per_block
><<<
T
,
int4
,
4
,
N
/
warps_per_block
,
warps_per_block
*
WARP_SIZE
,
0
,
warps_per_block
><<<
N
/
warps_per_block
,
warps_per_block
*
WARP_SIZE
,
ctx
.
cuda_device_context
().
stream
()
>>>
(
out_data
,
x
->
data
<
T
>
(),
N
,
0
,
ctx
.
cuda_device_context
().
stream
()
>>>
(
dim
);
out_data
,
x
->
data
<
T
>
(),
N
,
dim
);
}
else
{
}
else
{
assert
(
false
&&
"not support"
);
assert
(
false
&&
"not support"
);
}
}
}
else
{
}
else
if
(
dim
<
max_dim
)
{
optimize
=
true
;
int
log2_elements
=
static_cast
<
int
>
(
log2_ceil
(
dim
));
const
int
next_power_of_two
=
1
<<
log2_elements
;
int
warp_size
=
(
next_power_of_two
<
32
)
?
next_power_of_two
:
32
;
int
batches_per_warp
=
(
next_power_of_two
<=
128
)
?
2
:
1
;
// use 128 threads per block to maximize gpu utilization
constexpr
int
threads_per_block
=
128
;
int
warps_per_block
=
(
threads_per_block
/
warp_size
);
int
batches_per_block
=
warps_per_block
*
batches_per_warp
;
int
blocks
=
(
N
+
batches_per_block
-
1
)
/
batches_per_block
;
dim3
threads
(
warp_size
,
warps_per_block
,
1
);
switch
(
log2_elements
)
{
LAUNCH_SOFTMAX_WARP_FORWARD
(
0
);
// 1
LAUNCH_SOFTMAX_WARP_FORWARD
(
1
);
// 2
LAUNCH_SOFTMAX_WARP_FORWARD
(
2
);
// 4
LAUNCH_SOFTMAX_WARP_FORWARD
(
3
);
// 8
LAUNCH_SOFTMAX_WARP_FORWARD
(
4
);
// 16
LAUNCH_SOFTMAX_WARP_FORWARD
(
5
);
// 32
LAUNCH_SOFTMAX_WARP_FORWARD
(
6
);
// 64
LAUNCH_SOFTMAX_WARP_FORWARD
(
7
);
// 128
LAUNCH_SOFTMAX_WARP_FORWARD
(
8
);
// 256
LAUNCH_SOFTMAX_WARP_FORWARD
(
9
);
// 512
default:
break
;
}
}
}
if
(
!
optimize
)
{
ScopedTensorDescriptor
desc
;
ScopedTensorDescriptor
desc
;
std
::
vector
<
int
>
tensor_dims
=
{
N
,
dim
,
D
,
1
};
std
::
vector
<
int
>
tensor_dims
=
{
N
,
dim
,
D
,
1
};
DataLayout
layout
=
DataLayout
::
kNCHW
;
DataLayout
layout
=
DataLayout
::
kNCHW
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录