Unverified commit 7a476608, authored by Zhang Zheng, committed via GitHub

Reduce build time by deleting the template param BlockDim (#33901)

Parent 70ecf3b1
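
Context for the change (an illustrative sketch, not code from this patch; the kernel names SumDemoStatic/SumDemoDynamic and the host launchers are invented): when the block size is a compile-time template parameter, the host-side dispatch switch forces one kernel instantiation per candidate block size, whereas a kernel that reads blockDim.x at run time is compiled exactly once, which is where the build-time saving comes from.

#include <cuda_runtime.h>

// Block size baked in as a template parameter: the switch in LaunchStatic
// below needs one instantiation per case.
template <typename T, int BLOCK>
__global__ void SumDemoStatic(const T* in, T* out, int n) {
  T acc = 0;
  for (int i = threadIdx.x; i < n; i += BLOCK) acc += in[i];
  atomicAdd(out, acc);  // *out is assumed zero-initialized
}

// Block size read at run time: compiled once, launched with any block size.
template <typename T>
__global__ void SumDemoDynamic(const T* in, T* out, int n) {
  T acc = 0;
  for (int i = threadIdx.x; i < n; i += blockDim.x) acc += in[i];
  atomicAdd(out, acc);
}

// Mirrors the removed CUB_BLOCK_DIM_CASE pattern: every case is a separate
// template instantiation that nvcc has to compile.
void LaunchStatic(const float* in, float* out, int n, int block_dim) {
  switch (block_dim) {
    case 256: SumDemoStatic<float, 256><<<1, 256>>>(in, out, n); break;
    case 128: SumDemoStatic<float, 128><<<1, 128>>>(in, out, n); break;
    case 64:  SumDemoStatic<float, 64><<<1, 64>>>(in, out, n); break;
  }
}

// One instantiation total, whatever block size the config chooses.
void LaunchDynamic(const float* in, float* out, int n, int block_dim) {
  SumDemoDynamic<float><<<1, block_dim>>>(in, out, n);
}
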
......@@ -33,6 +33,7 @@ namespace cub = hipcub;
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/platform/cuda_device_function.h"
// Whether to split the reduce, i.e. whether to use ReduceHigherDim
#define REDUCE_SPLIT_BOUNDARY 512
......@@ -86,8 +87,10 @@ static inline std::vector<int> GetDimStrides(const std::vector<int>& dims,
#ifdef __HIPCC__
constexpr int kMaxThread = 256;
constexpr int kWarpSize = 64;
#else
constexpr int kMaxThread = 128;
constexpr int kWarpSize = 32;
#endif
// Get blockDim for ReduceLastDim and ReduceAny
......@@ -392,27 +395,70 @@ struct ReduceConfig {
dim3 grid;
};
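// Reduce val across the lanes of a warp with shuffle-down; after the loop,
// lane 0 holds the reduced result.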
template <typename T, typename ReduceOp>
__device__ __forceinline__ T WarpReduce(T val, ReduceOp reducer) {
unsigned mask = 0u;
CREATE_SHFL_MASK(mask, true);
for (int stride = detail::kWarpSize / 2; stride > 0; stride >>= 1) {
T temp = paddle::platform::CudaShuffleDownSync(mask, val, stride);
val = reducer(val, temp);
}
return val;
}
/* e.g.
 * |---------block---------|
 * |warp0|warp1|warp2|warp3|
 * |0~31 |32~63|64~95|96~127|  ----> blockDim.x = 128
 *  \|/   \|/   \|/    \|/    ----> 1. First WarpReduce in each warp
 *  res0  res1  res2   res3   ----> 2. Store the result of each warp to shared memory
 *   \      \    /      /     ----> 3. Load the results above from shared memory
 *         res                       to warp0 and perform the second WarpReduce
 */
template <typename T, typename ReduceOp>
__device__ __forceinline__ T BlockReduce(T val, ReduceOp reducer) {
using detail::kWarpSize;
__shared__ T shared[kWarpSize];
int block_dim_x = blockDim.x;
if (blockDim.x > kWarpSize) {
block_dim_x = blockDim.x / kWarpSize;
int lane = threadIdx.x % kWarpSize;
int wid = threadIdx.x / kWarpSize;
val = WarpReduce(val, reducer);
if (lane == 0) {
shared[wid] = val;
}
__syncthreads();
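// Lane i now holds warp i's partial result; the shuffle loop below folds
// the first block_dim_x lanes into lane 0.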
val = shared[lane];
}
unsigned mask = 0u;
CREATE_SHFL_MASK(mask, true);
for (int stride = 1; stride < block_dim_x; stride <<= 1) {
T temp = paddle::platform::CudaShuffleDownSync(mask, val, stride);
val = reducer(val, temp);
}
return val;
}
// This function is used when reduce_dim.size() == 1 and
// reduce_dim[0] == x_dim.size() - 1
// blockIdx.x -> left_num, threadIdx.x -> reduce_num
template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp,
int BlockDim>
template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp>
__device__ __forceinline__ void ReduceLastDim(const Tx* x, Ty* y,
ReduceOp reducer,
TransformOp transformer, Ty init,
int reduce_num) {
__shared__ typename cub::BlockReduce<Ty, BlockDim>::TempStorage temp_storage;
int idx_x = blockIdx.x * reduce_num;
int idx_y = threadIdx.x;
Ty reduce_var = init;
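// Block-stride loop: each thread accumulates the elements spaced blockDim.x
// apart, then the block-wide reduce below combines the per-thread partials.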
for (int idx_y = threadIdx.x; idx_y < reduce_num; idx_y += BlockDim) {
for (int idx_y = threadIdx.x; idx_y < reduce_num; idx_y += blockDim.x) {
reduce_var =
reducer(reduce_var, static_cast<Ty>(transformer(x[idx_x + idx_y])));
}
__syncthreads();
reduce_var =
cub::BlockReduce<Ty, BlockDim>(temp_storage).Reduce(reduce_var, reducer);
reduce_var = BlockReduce(reduce_var, reducer);
if (threadIdx.x == 0) {
y[blockIdx.x] = reduce_var;
......@@ -453,7 +499,7 @@ __device__ __forceinline__ void ReduceHigherDim(const Tx* x, Ty* y,
// function will be used
// blockIdx.x -> left_num, threadIdx.x -> reduce_num
template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp,
int BlockDim, int Rank, int ReduceRank>
int Rank, int ReduceRank>
__device__ __forceinline__ void ReduceAny(
const Tx* x, Ty* y, ReduceOp reducer, TransformOp transformer,
int reduce_num, paddle::framework::Array<int, Rank> x_strides,
......@@ -461,8 +507,6 @@ __device__ __forceinline__ void ReduceAny(
paddle::framework::Array<int, ReduceRank> reduce_strides,
paddle::framework::Array<int, Rank - ReduceRank> left_dim,
paddle::framework::Array<int, Rank - ReduceRank> left_strides) {
__shared__ typename cub::BlockReduce<Ty, BlockDim>::TempStorage temp_storage;
int sub_index[Rank];
int left_idx = blockIdx.x;
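// Decode the flat left index into per-dimension sub-indices that locate this
// block's slice of the input.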
for (int i = 0; i < Rank - ReduceRank; ++i) {
......@@ -482,7 +526,7 @@ __device__ __forceinline__ void ReduceAny(
}
Ty reduce_var = static_cast<Ty>(transformer(x[idx_x]));
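// Each thread then strides through the remaining reduce elements, re-deriving
// the flat input offset from its multi-dimensional reduce index.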
for (int i = threadIdx.x + BlockDim; i < reduce_num; i += BlockDim) {
for (int i = threadIdx.x + blockDim.x; i < reduce_num; i += blockDim.x) {
int reduce_idx = i;
for (int j = 0; j < ReduceRank; ++j) {
......@@ -500,9 +544,7 @@ __device__ __forceinline__ void ReduceAny(
}
__syncthreads();
reduce_var =
cub::BlockReduce<Ty, BlockDim>(temp_storage).Reduce(reduce_var, reducer);
reduce_var = BlockReduce(reduce_var, reducer);
if (threadIdx.x == 0) {
y[blockIdx.x] = reduce_var;
}
......@@ -510,7 +552,7 @@ __device__ __forceinline__ void ReduceAny(
// Device-side dispatch function called from the global kernel function
template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp,
int BlockDim, int Rank, int ReduceRank>
int Rank, int ReduceRank>
__device__ __forceinline__ void ReduceModule(
const Tx* x, Ty* y, ReduceOp reducer, TransformOp transformer, Ty init,
int reduce_num, int left_num, int blocking_size, int reduce_type,
......@@ -521,8 +563,8 @@ __device__ __forceinline__ void ReduceModule(
paddle::framework::Array<int, Rank - ReduceRank> left_strides) {
// reduce_rank == 1 && reduce_dim[0] == x_dim.size() - 1
if (reduce_type == ReduceType::kReduceLastDim) {
ReduceLastDim<Tx, Ty, ReduceOp, TransformOp, BlockDim>(
x, y, reducer, transformer, init, reduce_num);
ReduceLastDim<Tx, Ty, ReduceOp, TransformOp>(x, y, reducer, transformer,
init, reduce_num);
// reduce_rank == 1 && reduce_dim[0] != x_dim.size() - 1
} else if (reduce_type == ReduceType::kReduceHigherDim) {
......@@ -531,14 +573,14 @@ __device__ __forceinline__ void ReduceModule(
// reduce_rank >= 2
} else {
ReduceAny<Tx, Ty, ReduceOp, TransformOp, BlockDim, Rank, ReduceRank>(
ReduceAny<Tx, Ty, ReduceOp, TransformOp, Rank, ReduceRank>(
x, y, reducer, transformer, reduce_num, x_strides, reduce_dim,
reduce_strides, left_dim, left_strides);
}
}
template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp,
int BlockDim, int Rank, int ReduceRank>
int Rank, int ReduceRank>
__global__ void ReduceKernelFunction(
const Tx* x, Ty* y, ReduceOp reducer, TransformOp transformer, Ty init,
int reduce_num, int left_num, int block_size, int reduce_type,
......@@ -547,47 +589,46 @@ __global__ void ReduceKernelFunction(
paddle::framework::Array<int, ReduceRank> reduce_strides,
paddle::framework::Array<int, Rank - ReduceRank> left_dim,
paddle::framework::Array<int, Rank - ReduceRank> left_strides) {
ReduceModule<Tx, Ty, ReduceOp, TransformOp, BlockDim, Rank, ReduceRank>(
ReduceModule<Tx, Ty, ReduceOp, TransformOp, Rank, ReduceRank>(
x, y, reducer, transformer, init, reduce_num, left_num, block_size,
reduce_type, x_strides, reduce_dim, reduce_strides, left_dim,
left_strides);
}
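// Host-side launcher: converts the ReduceConfig strides/dims into device
// arrays and launches ReduceKernelFunction; if should_reduce_again is set,
// a second launch finishes the partial results.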
template <typename Tx, typename Ty, int BlockDim, typename ReduceOp, int kRank,
int kReduceRank>
template <typename Tx, typename Ty, typename ReduceOp, int Rank, int ReduceRank>
static void LaunchReduceKernel(const Tx* x_data, Ty* y_data,
const ReduceOp& reducer, Ty init,
gpuStream_t stream, ReduceConfig<Ty> config) {
using TransformOp = typename ReduceOp::Transformer;
ReduceKernelFunction<Tx, Ty, ReduceOp, TransformOp, BlockDim, kRank,
kReduceRank><<<config.grid, config.block, 0, stream>>>(
ReduceKernelFunction<Tx, Ty, ReduceOp, TransformOp, Rank,
ReduceRank><<<config.grid, config.block, 0, stream>>>(
x_data, config.output_data, reducer, TransformOp(config.reduce_num), init,
config.reduce_num, config.left_num, config.blocking_size,
config.reduce_type, detail::VectorToArray<int, kRank>(config.x_strides),
detail::VectorToArray<int, kReduceRank>(config.reduce_dim),
detail::VectorToArray<int, kReduceRank>(config.reduce_strides),
detail::VectorToArray<int, kRank - kReduceRank>(config.left_dim),
detail::VectorToArray<int, kRank - kReduceRank>(config.left_strides));
config.reduce_type, detail::VectorToArray<int, Rank>(config.x_strides),
detail::VectorToArray<int, ReduceRank>(config.reduce_dim),
detail::VectorToArray<int, ReduceRank>(config.reduce_strides),
detail::VectorToArray<int, Rank - ReduceRank>(config.left_dim),
detail::VectorToArray<int, Rank - ReduceRank>(config.left_strides));
if (config.should_reduce_again) {
dim3 block(config.block.x, 1, 1);
dim3 grid(config.grid.x, 1, config.grid.z);
ReduceKernelFunction<Ty, Ty, ReduceOp, detail::IdentityFunctor<Ty>, 128,
kRank, kReduceRank><<<grid, block, 0, stream>>>(
ReduceKernelFunction<Ty, Ty, ReduceOp, detail::IdentityFunctor<Ty>, Rank,
ReduceRank><<<grid, block, 0, stream>>>(
config.output_data, y_data, reducer,
detail::IdentityFunctor<Ty>(config.grid.y), init, config.grid.y,
config.left_num, config.grid.y, ReduceType::kReduceHigherDim,
detail::VectorToArray<int, kRank>(config.x_strides),
detail::VectorToArray<int, kReduceRank>(config.reduce_dim),
detail::VectorToArray<int, kReduceRank>(config.reduce_strides),
detail::VectorToArray<int, kRank - kReduceRank>(config.left_dim),
detail::VectorToArray<int, kRank - kReduceRank>(config.left_strides));
detail::VectorToArray<int, Rank>(config.x_strides),
detail::VectorToArray<int, ReduceRank>(config.reduce_dim),
detail::VectorToArray<int, ReduceRank>(config.reduce_strides),
detail::VectorToArray<int, Rank - ReduceRank>(config.left_dim),
detail::VectorToArray<int, Rank - ReduceRank>(config.left_strides));
}
}
template <typename Tx, typename Ty, int BlockDim, typename ReduceOp>
template <typename Tx, typename Ty, typename ReduceOp>
static void ReduceKernelImpl(const Tx* x_data, Ty* y_data,
const ReduceOp& reducer, Ty init,
gpuStream_t stream, ReduceConfig<Ty> config) {
......@@ -596,15 +637,15 @@ static void ReduceKernelImpl(const Tx* x_data, Ty* y_data,
#define CUB_RANK_CASE(i, ...) \
case i: { \
constexpr auto kRank = i; \
constexpr auto Rank = i; \
switch (reduce_rank) { __VA_ARGS__; } \
} break
#define CUB_REDUCE_RANK_CASE(i, ...) \
case i: { \
constexpr auto kReduceRank = i; \
LaunchReduceKernel<Tx, Ty, BlockDim, ReduceOp, kRank, kReduceRank>( \
x_data, y_data, reducer, init, stream, config); \
#define CUB_REDUCE_RANK_CASE(i, ...) \
case i: { \
constexpr auto ReduceRank = i; \
LaunchReduceKernel<Tx, Ty, ReduceOp, Rank, ReduceRank>( \
x_data, y_data, reducer, init, stream, config); \
} break
detail::CheckReduceRank(reduce_rank, rank);
......@@ -677,24 +718,8 @@ void TensorReduceFunctorImpl(const framework::Tensor& x, framework::Tensor* y,
return;
}
#define CUB_BLOCK_DIM_CASE(block_dim) \
case block_dim: { \
constexpr auto kBlockDim = block_dim; \
ReduceKernelImpl<Tx, Ty, block_dim, ReduceOp<Tx, Ty>>( \
x_data, y_data, reducer, reducer.initial(), stream, config); \
} break
switch (detail::GetBlockDim(config.reduce_num)) {
CUB_BLOCK_DIM_CASE(256);
CUB_BLOCK_DIM_CASE(128);
CUB_BLOCK_DIM_CASE(64);
CUB_BLOCK_DIM_CASE(32);
CUB_BLOCK_DIM_CASE(16);
CUB_BLOCK_DIM_CASE(8);
CUB_BLOCK_DIM_CASE(4);
CUB_BLOCK_DIM_CASE(2);
}
#undef CUB_BLOCK_DIM_CASE
ReduceKernelImpl<Tx, Ty, ReduceOp<Tx, Ty>>(x_data, y_data, reducer,
reducer.initial(), stream, config);
}
template <typename Tx, template <typename, typename> class ReduceOp>
......