Unverified commit 7a476608
Authored on Jul 05, 2021 by Zhang Zheng
Committed by GitHub on Jul 05, 2021
Reduce build time by deleting the template param BlockDim (#33901)
Parent: 70ecf3b1
Showing 1 changed file with 85 additions and 60 deletions:
paddle/fluid/operators/reduce_ops/reduce_op.cu.h  (+85, -60)
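Context for the change (a sketch, not part of the commit): when the reduce kernels are templated on BlockDim, every block size reachable from the CUB_BLOCK_DIM_CASE switch (2, 4, 8, ..., 256; see the last hunk below) produces its own instantiation for every (Tx, Ty, Rank, ReduceRank) combination, which is what inflates build time. Reading blockDim.x at runtime needs only one instantiation per type/rank combination. The kernels below (SumTemplated, SumRuntime) are hypothetical illustrations of the two styles, not code from this file.

#include <cuda_runtime.h>

// Old style: BlockDim is a compile-time parameter, so each dispatched block
// size is a separate kernel the compiler must build.
template <typename T, int BlockDim>
__global__ void SumTemplated(const T* x, T* y, int n) {
  __shared__ T buf[BlockDim];
  T val = 0;
  for (int i = threadIdx.x; i < n; i += BlockDim) val += x[i];
  buf[threadIdx.x] = val;
  __syncthreads();
  for (int s = BlockDim / 2; s > 0; s >>= 1) {  // BlockDim assumed a power of two
    if (threadIdx.x < s) buf[threadIdx.x] += buf[threadIdx.x + s];
    __syncthreads();
  }
  if (threadIdx.x == 0) *y = buf[0];
}

// New style: the block size is read from blockDim.x at runtime, so a single
// instantiation per element type serves every launch configuration.
template <typename T>
__global__ void SumRuntime(const T* x, T* y, int n) {
  extern __shared__ unsigned char smem[];
  T* buf = reinterpret_cast<T*>(smem);
  T val = 0;
  for (int i = threadIdx.x; i < n; i += blockDim.x) val += x[i];
  buf[threadIdx.x] = val;
  __syncthreads();
  for (int s = blockDim.x / 2; s > 0; s >>= 1) {  // blockDim.x assumed a power of two
    if (threadIdx.x < s) buf[threadIdx.x] += buf[threadIdx.x + s];
    __syncthreads();
  }
  if (threadIdx.x == 0) *y = buf[0];
}

// SumTemplated<float, 256>, SumTemplated<float, 128>, ... are all distinct
// kernels, while SumRuntime<float> is launched for any block_size as
//   SumRuntime<float><<<1, block_size, block_size * sizeof(float)>>>(x, y, n);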
paddle/fluid/operators/reduce_ops/reduce_op.cu.h
@@ -33,6 +33,7 @@ namespace cub = hipcub;
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/tensor.h"
 #include "paddle/fluid/framework/tensor_util.h"
+#include "paddle/fluid/platform/cuda_device_function.h"
 
 // Reduce split or not, Whether to use ReduceHigherDim
 #define REDUCE_SPLIT_BOUNDARY 512
@@ -86,8 +87,10 @@ static inline std::vector<int> GetDimStrides(const std::vector<int>& dims,
 #ifdef __HIPCC__
 constexpr int kMaxThread = 256;
+constexpr int kWarpSize = 64;
 #else
 constexpr int kMaxThread = 128;
+constexpr int kWarpSize = 32;
 #endif
 
 // get blockDim for reduceLastDim and reduceAny
@@ -392,27 +395,70 @@ struct ReduceConfig {
   dim3 grid;
 };
 
+template <typename T, typename ReduceOp>
+__device__ __forceinline__ T WarpReduce(T val, ReduceOp reducer) {
+  unsigned mask = 0u;
+  CREATE_SHFL_MASK(mask, true);
+  for (int stride = detail::kWarpSize / 2; stride > 0; stride >>= 1) {
+    T temp = paddle::platform::CudaShuffleDownSync(mask, val, stride);
+    val = reducer(val, temp);
+  }
+  return val;
+}
+
+/* e.g.
+ * |---------block---------|
+ * |warp0|warp1|warp2|warp3|
+ * |0~31|32~63|64~95|96~127|  ---->blockDim.x = 128
+ *   \|/   \|/   \|/   \|/   ---->1. First WarpReduce in each warp
+ *  res0  res1  res2  res3   ---->2. Store result of each warp to shared memory
+ *    \     \    /     /     ---->3. Load the result above from shared memory
+ *                                   res to warp0 and process the second WarpReduce
+ */
+template <typename T, typename ReduceOp>
+__device__ __forceinline__ T BlockReduce(T val, ReduceOp reducer) {
+  using detail::kWarpSize;
+  __shared__ T shared[kWarpSize];
+  int block_dim_x = blockDim.x;
+  if (blockDim.x > kWarpSize) {
+    block_dim_x = blockDim.x / kWarpSize;
+    int lane = threadIdx.x % kWarpSize;
+    int wid = threadIdx.x / kWarpSize;
+    val = WarpReduce(val, reducer);
+    if (lane == 0) {
+      shared[wid] = val;
+    }
+    __syncthreads();
+    val = shared[lane];
+  }
+
+  unsigned mask = 0u;
+  CREATE_SHFL_MASK(mask, true);
+  for (int stride = 1; stride < block_dim_x; stride <<= 1) {
+    T temp = paddle::platform::CudaShuffleDownSync(mask, val, stride);
+    val = reducer(val, temp);
+  }
+  return val;
+}
+
 // when reduce_dim.size() == 1 and reduce_dim[0] == x_dim.size() - 1, this
 // function will be used
 // blockId.x -> left_num, threadId.x -> reduce_num
-template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp,
-          int BlockDim>
+template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp>
 __device__ __forceinline__ void ReduceLastDim(const Tx* x, Ty* y,
                                               ReduceOp reducer,
                                               TransformOp transformer, Ty init,
                                               int reduce_num) {
-  __shared__ typename cub::BlockReduce<Ty, BlockDim>::TempStorage temp_storage;
   int idx_x = blockIdx.x * reduce_num;
   int idx_y = threadIdx.x;
   Ty reduce_var = init;
-  for (int idx_y = threadIdx.x; idx_y < reduce_num; idx_y += BlockDim) {
+  for (int idx_y = threadIdx.x; idx_y < reduce_num; idx_y += blockDim.x) {
     reduce_var =
         reducer(reduce_var, static_cast<Ty>(transformer(x[idx_x + idx_y])));
   }
   __syncthreads();
 
-  reduce_var =
-      cub::BlockReduce<Ty, BlockDim>(temp_storage).Reduce(reduce_var, reducer);
+  reduce_var = BlockReduce(reduce_var, reducer);
 
   if (threadIdx.x == 0) {
     y[blockIdx.x] = reduce_var;
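The comment block in the hunk above describes the two-stage pattern that the new BlockReduce implements: every warp first reduces its own lanes with shuffle-down, each warp's lane 0 publishes its partial result to shared memory, and a final shuffle pass over those partials produces the block-wide result. Below is a self-contained sketch of the same idea written with raw CUDA intrinsics instead of Paddle's CREATE_SHFL_MASK / CudaShuffleDownSync wrappers; it assumes 32-thread warps and a blockDim.x that is a multiple of 32, and the names WarpReduceSum / BlockReduceSum / BlockSumKernel are illustrative only, not code from this diff.

#include <cuda_runtime.h>

// Stage 1: reduce the 32 lanes of one warp; lane 0 ends up with the warp sum.
__device__ float WarpReduceSum(float val) {
  for (int stride = 16; stride > 0; stride >>= 1)
    val += __shfl_down_sync(0xffffffffu, val, stride);
  return val;
}

// Stages 2-3: warp leaders publish partials to shared memory, then warp 0
// reduces those partials so thread 0 holds the block-wide sum.
__device__ float BlockReduceSum(float val) {
  __shared__ float shared[32];           // one slot per warp (<= 1024 threads)
  int lane = threadIdx.x % 32;
  int wid = threadIdx.x / 32;
  val = WarpReduceSum(val);              // 1. reduce within each warp
  if (lane == 0) shared[wid] = val;      // 2. store each warp's partial result
  __syncthreads();
  int num_warps = blockDim.x / 32;
  val = (threadIdx.x < num_warps) ? shared[lane] : 0.0f;  // 3. reload partials
  if (wid == 0) val = WarpReduceSum(val);
  return val;
}

__global__ void BlockSumKernel(const float* x, float* y, int n) {
  float v = 0.0f;
  for (int i = threadIdx.x; i < n; i += blockDim.x) v += x[i];
  v = BlockReduceSum(v);
  if (threadIdx.x == 0) *y = v;
}

As in the diff's ReduceLastDim, only thread 0 ends up holding the final value, which is why the result is written from threadIdx.x == 0.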
@@ -453,7 +499,7 @@ __device__ __forceinline__ void ReduceHigherDim(const Tx* x, Ty* y,
 // function will be used
 // blockId.x -> left_num, threadId.x -> reduce_num
 template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp,
-          int BlockDim, int Rank, int ReduceRank>
+          int Rank, int ReduceRank>
 __device__ __forceinline__ void ReduceAny(
     const Tx* x, Ty* y, ReduceOp reducer, TransformOp transformer,
     int reduce_num, paddle::framework::Array<int, Rank> x_strides,
@@ -461,8 +507,6 @@ __device__ __forceinline__ void ReduceAny(
     paddle::framework::Array<int, ReduceRank> reduce_strides,
     paddle::framework::Array<int, Rank - ReduceRank> left_dim,
     paddle::framework::Array<int, Rank - ReduceRank> left_strides) {
-  __shared__ typename cub::BlockReduce<Ty, BlockDim>::TempStorage temp_storage;
-
   int sub_index[Rank];
   int left_idx = blockIdx.x;
   for (int i = 0; i < Rank - ReduceRank; ++i) {
@@ -482,7 +526,7 @@ __device__ __forceinline__ void ReduceAny(
   }
 
   Ty reduce_var = static_cast<Ty>(transformer(x[idx_x]));
 
-  for (int i = threadIdx.x + BlockDim; i < reduce_num; i += BlockDim) {
+  for (int i = threadIdx.x + blockDim.x; i < reduce_num; i += blockDim.x) {
     int reduce_idx = i;
     for (int j = 0; j < ReduceRank; ++j) {
@@ -500,9 +544,7 @@ __device__ __forceinline__ void ReduceAny(
   }
   __syncthreads();
 
-  reduce_var =
-      cub::BlockReduce<Ty, BlockDim>(temp_storage).Reduce(reduce_var, reducer);
-
+  reduce_var = BlockReduce(reduce_var, reducer);
   if (threadIdx.x == 0) {
     y[blockIdx.x] = reduce_var;
   }
@@ -510,7 +552,7 @@ __device__ __forceinline__ void ReduceAny(
 
 // module function designed for global function
 template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp,
-          int BlockDim, int Rank, int ReduceRank>
+          int Rank, int ReduceRank>
 __device__ __forceinline__ void ReduceModule(
     const Tx* x, Ty* y, ReduceOp reducer, TransformOp transformer, Ty init,
     int reduce_num, int left_num, int blocking_size, int reduce_type,
@@ -521,8 +563,8 @@ __device__ __forceinline__ void ReduceModule(
     paddle::framework::Array<int, Rank - ReduceRank> left_strides) {
   // reduce_rank == 1 && reduce_dim[0] == x_dim.size() - 1
   if (reduce_type == ReduceType::kReduceLastDim) {
-    ReduceLastDim<Tx, Ty, ReduceOp, TransformOp, BlockDim>(
-        x, y, reducer, transformer, init, reduce_num);
+    ReduceLastDim<Tx, Ty, ReduceOp, TransformOp>(x, y, reducer, transformer,
+                                                 init, reduce_num);
 
     // reduce_rank == 1 && reduce_dim[0] != x_dim.size() - 1
   } else if (reduce_type == ReduceType::kReduceHigherDim) {
@@ -531,14 +573,14 @@ __device__ __forceinline__ void ReduceModule(
     // reduce_rank >= 2
   } else {
-    ReduceAny<Tx, Ty, ReduceOp, TransformOp, BlockDim, Rank, ReduceRank>(
+    ReduceAny<Tx, Ty, ReduceOp, TransformOp, Rank, ReduceRank>(
         x, y, reducer, transformer, reduce_num, x_strides, reduce_dim,
         reduce_strides, left_dim, left_strides);
   }
 }
 
 template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp,
-          int BlockDim, int Rank, int ReduceRank>
+          int Rank, int ReduceRank>
 __global__ void ReduceKernelFunction(
     const Tx* x, Ty* y, ReduceOp reducer, TransformOp transformer, Ty init,
     int reduce_num, int left_num, int block_size, int reduce_type,
@@ -547,47 +589,46 @@ __global__ void ReduceKernelFunction(
     paddle::framework::Array<int, ReduceRank> reduce_strides,
     paddle::framework::Array<int, Rank - ReduceRank> left_dim,
     paddle::framework::Array<int, Rank - ReduceRank> left_strides) {
-  ReduceModule<Tx, Ty, ReduceOp, TransformOp, BlockDim, Rank, ReduceRank>(
+  ReduceModule<Tx, Ty, ReduceOp, TransformOp, Rank, ReduceRank>(
       x, y, reducer, transformer, init, reduce_num, left_num, block_size,
       reduce_type, x_strides, reduce_dim, reduce_strides, left_dim,
      left_strides);
 }
 
-template <typename Tx, typename Ty, int BlockDim, typename ReduceOp, int kRank,
-          int kReduceRank>
+template <typename Tx, typename Ty, typename ReduceOp, int Rank, int ReduceRank>
 static void LaunchReduceKernel(const Tx* x_data, Ty* y_data,
                                const ReduceOp& reducer, Ty init,
                                gpuStream_t stream, ReduceConfig<Ty> config) {
   using TransformOp = typename ReduceOp::Transformer;
 
-  ReduceKernelFunction<Tx, Ty, ReduceOp, TransformOp, BlockDim, kRank,
-                       kReduceRank><<<config.grid, config.block, 0, stream>>>(
+  ReduceKernelFunction<Tx, Ty, ReduceOp, TransformOp, Rank,
+                       ReduceRank><<<config.grid, config.block, 0, stream>>>(
      x_data, config.output_data, reducer, TransformOp(config.reduce_num),
      init, config.reduce_num, config.left_num, config.blocking_size,
-      config.reduce_type, detail::VectorToArray<int, kRank>(config.x_strides),
-      detail::VectorToArray<int, kReduceRank>(config.reduce_dim),
-      detail::VectorToArray<int, kReduceRank>(config.reduce_strides),
-      detail::VectorToArray<int, kRank - kReduceRank>(config.left_dim),
-      detail::VectorToArray<int, kRank - kReduceRank>(config.left_strides));
+      config.reduce_type, detail::VectorToArray<int, Rank>(config.x_strides),
+      detail::VectorToArray<int, ReduceRank>(config.reduce_dim),
+      detail::VectorToArray<int, ReduceRank>(config.reduce_strides),
+      detail::VectorToArray<int, Rank - ReduceRank>(config.left_dim),
+      detail::VectorToArray<int, Rank - ReduceRank>(config.left_strides));
 
   if (config.should_reduce_again) {
     dim3 block(config.block.x, 1, 1);
     dim3 grid(config.grid.x, 1, config.grid.z);
 
-    ReduceKernelFunction<Ty, Ty, ReduceOp, detail::IdentityFunctor<Ty>, 128,
-                         kRank, kReduceRank><<<grid, block, 0, stream>>>(
+    ReduceKernelFunction<Ty, Ty, ReduceOp, detail::IdentityFunctor<Ty>, Rank,
+                         ReduceRank><<<grid, block, 0, stream>>>(
        config.output_data, y_data, reducer,
        detail::IdentityFunctor<Ty>(config.grid.y), init, config.grid.y,
        config.left_num, config.grid.y, ReduceType::kReduceHigherDim,
-        detail::VectorToArray<int, kRank>(config.x_strides),
-        detail::VectorToArray<int, kReduceRank>(config.reduce_dim),
-        detail::VectorToArray<int, kReduceRank>(config.reduce_strides),
-        detail::VectorToArray<int, kRank - kReduceRank>(config.left_dim),
-        detail::VectorToArray<int, kRank - kReduceRank>(config.left_strides));
+        detail::VectorToArray<int, Rank>(config.x_strides),
+        detail::VectorToArray<int, ReduceRank>(config.reduce_dim),
+        detail::VectorToArray<int, ReduceRank>(config.reduce_strides),
+        detail::VectorToArray<int, Rank - ReduceRank>(config.left_dim),
+        detail::VectorToArray<int, Rank - ReduceRank>(config.left_strides));
   }
 }
 
-template <typename Tx, typename Ty, int BlockDim, typename ReduceOp>
+template <typename Tx, typename Ty, typename ReduceOp>
 static void ReduceKernelImpl(const Tx* x_data, Ty* y_data,
                              const ReduceOp& reducer, Ty init,
                              gpuStream_t stream, ReduceConfig<Ty> config) {
@@ -596,15 +637,15 @@ static void ReduceKernelImpl(const Tx* x_data, Ty* y_data,
 #define CUB_RANK_CASE(i, ...)             \
   case i: {                               \
-    constexpr auto kRank = i;             \
+    constexpr auto Rank = i;              \
     switch (reduce_rank) { __VA_ARGS__; } \
   } break
 
-#define CUB_REDUCE_RANK_CASE(i, ...)                                    \
-  case i: {                                                             \
-    constexpr auto kReduceRank = i;                                     \
-    LaunchReduceKernel<Tx, Ty, BlockDim, ReduceOp, kRank, kReduceRank>( \
-        x_data, y_data, reducer, init, stream, config);                 \
+#define CUB_REDUCE_RANK_CASE(i, ...)                        \
+  case i: {                                                 \
+    constexpr auto ReduceRank = i;                          \
+    LaunchReduceKernel<Tx, Ty, ReduceOp, Rank, ReduceRank>( \
+        x_data, y_data, reducer, init, stream, config);     \
   } break
 
   detail::CheckReduceRank(reduce_rank, rank);
@@ -677,24 +718,8 @@ void TensorReduceFunctorImpl(const framework::Tensor& x, framework::Tensor* y,
     return;
   }
 
-#define CUB_BLOCK_DIM_CASE(block_dim)                                 \
-  case block_dim: {                                                   \
-    constexpr auto kBlockDim = block_dim;                             \
-    ReduceKernelImpl<Tx, Ty, block_dim, ReduceOp<Tx, Ty>>(            \
-        x_data, y_data, reducer, reducer.initial(), stream, config);  \
-  } break
-
-  switch (detail::GetBlockDim(config.reduce_num)) {
-    CUB_BLOCK_DIM_CASE(256);
-    CUB_BLOCK_DIM_CASE(128);
-    CUB_BLOCK_DIM_CASE(64);
-    CUB_BLOCK_DIM_CASE(32);
-    CUB_BLOCK_DIM_CASE(16);
-    CUB_BLOCK_DIM_CASE(8);
-    CUB_BLOCK_DIM_CASE(4);
-    CUB_BLOCK_DIM_CASE(2);
-  }
-#undef CUB_BLOCK_DIM_CASE
+  ReduceKernelImpl<Tx, Ty, ReduceOp<Tx, Ty>>(
+      x_data, y_data, reducer, reducer.initial(), stream, config);
 }
 
 template <typename Tx, template <typename, typename> class ReduceOp>