diff --git a/paddle/fluid/operators/reduce_ops/reduce_op.cu.h b/paddle/fluid/operators/reduce_ops/reduce_op.cu.h
index 45279a224ac8dc6895f35427496de37ec6279198..ee2beded713085ed5232b9f05f4250a73315c978 100644
--- a/paddle/fluid/operators/reduce_ops/reduce_op.cu.h
+++ b/paddle/fluid/operators/reduce_ops/reduce_op.cu.h
@@ -33,6 +33,7 @@ namespace cub = hipcub;
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/tensor.h"
 #include "paddle/fluid/framework/tensor_util.h"
+#include "paddle/fluid/platform/cuda_device_function.h"

 // Reduce split or not, Whether to use ReduceHigherDim
 #define REDUCE_SPLIT_BOUNDARY 512
@@ -86,8 +87,10 @@ static inline std::vector<int> GetDimStrides(const std::vector<int>& dims,

 #ifdef __HIPCC__
 constexpr int kMaxThread = 256;
+constexpr int kWarpSize = 64;
 #else
 constexpr int kMaxThread = 128;
+constexpr int kWarpSize = 32;
 #endif

 // get blockDim for reduceLastDim and reduceAny
@@ -392,27 +395,70 @@ struct ReduceConfig {
   dim3 grid;
 };

+template <typename T, typename ReduceOp>
+__device__ __forceinline__ T WarpReduce(T val, ReduceOp reducer) {
+  unsigned mask = 0u;
+  CREATE_SHFL_MASK(mask, true);
+  for (int stride = detail::kWarpSize / 2; stride > 0; stride >>= 1) {
+    T temp = paddle::platform::CudaShuffleDownSync(mask, val, stride);
+    val = reducer(val, temp);
+  }
+  return val;
+}
+
+/* e.g.
+ * |---------block---------|
+ * |warp0|warp1|warp2|warp3|
+ * |0~31 |32~63|64~95|96~127|  ---->blockDim.x = 128
+ *  \|/   \|/   \|/    \|/     ---->1. First WarpReduce in each warp
+ *  res0  res1  res2   res3    ---->2. Store result of each warp to shared memory
+ *    \     \    /      /      ---->3. Load the result above from shared memory
+ *        res to warp0 and process the second WarpReduce
+ */
+template <typename T, typename ReduceOp>
+__device__ __forceinline__ T BlockReduce(T val, ReduceOp reducer) {
+  using detail::kWarpSize;
+  __shared__ T shared[kWarpSize];
+  int block_dim_x = blockDim.x;
+  if (blockDim.x > kWarpSize) {
+    block_dim_x = blockDim.x / kWarpSize;
+    int lane = threadIdx.x % kWarpSize;
+    int wid = threadIdx.x / kWarpSize;
+    val = WarpReduce(val, reducer);
+    if (lane == 0) {
+      shared[wid] = val;
+    }
+    __syncthreads();
+    val = shared[lane];
+  }
+
+  unsigned mask = 0u;
+  CREATE_SHFL_MASK(mask, true);
+  for (int stride = 1; stride < block_dim_x; stride <<= 1) {
+    T temp = paddle::platform::CudaShuffleDownSync(mask, val, stride);
+    val = reducer(val, temp);
+  }
+  return val;
+}
+
 // when reduce_dim.size() == 1 and reduce_dim[0] == x_dim.size() - 1, this
 // function will be used
 // blockId.x -> left_num, threadId.x -> reduce_num
-template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp,
-          int BlockDim>
+template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp>
 __device__ __forceinline__ void ReduceLastDim(const Tx* x, Ty* y,
                                               ReduceOp reducer,
                                               TransformOp transformer, Ty init,
                                               int reduce_num) {
-  __shared__ typename cub::BlockReduce<Ty, BlockDim>::TempStorage temp_storage;
   int idx_x = blockIdx.x * reduce_num;
   int idx_y = threadIdx.x;
   Ty reduce_var = init;
-  for (int idx_y = threadIdx.x; idx_y < reduce_num; idx_y += BlockDim) {
+  for (int idx_y = threadIdx.x; idx_y < reduce_num; idx_y += blockDim.x) {
     reduce_var =
         reducer(reduce_var, static_cast<Ty>(transformer(x[idx_x + idx_y])));
   }
   __syncthreads();

-  reduce_var =
-      cub::BlockReduce<Ty, BlockDim>(temp_storage).Reduce(reduce_var, reducer);
+  reduce_var = BlockReduce(reduce_var, reducer);

   if (threadIdx.x == 0) {
     y[blockIdx.x] = reduce_var;
@@ -453,7 +499,7 @@ __device__ __forceinline__ void ReduceHigherDim(const Tx* x, Ty* y,
 // function will be used
 // blockId.x -> left_num, threadId.x -> reduce_num
 template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp,
-          int BlockDim, int Rank, int ReduceRank>
+          int Rank, int ReduceRank>
 __device__ __forceinline__ void ReduceAny(
     const Tx* x, Ty* y, ReduceOp reducer, TransformOp transformer,
     int reduce_num, paddle::framework::Array<int, Rank> x_strides,
@@ -461,8 +507,6 @@ __device__ __forceinline__ void ReduceAny(
     paddle::framework::Array<int, ReduceRank> reduce_strides,
     paddle::framework::Array<int, Rank - ReduceRank> left_dim,
     paddle::framework::Array<int, Rank - ReduceRank> left_strides) {
-  __shared__ typename cub::BlockReduce<Ty, BlockDim>::TempStorage temp_storage;
-
   int sub_index[Rank];
   int left_idx = blockIdx.x;
   for (int i = 0; i < Rank - ReduceRank; ++i) {
@@ -482,7 +526,7 @@ __device__ __forceinline__ void ReduceAny(
   }

   Ty reduce_var = static_cast<Ty>(transformer(x[idx_x]));
-  for (int i = threadIdx.x + BlockDim; i < reduce_num; i += BlockDim) {
+  for (int i = threadIdx.x + blockDim.x; i < reduce_num; i += blockDim.x) {
     int reduce_idx = i;

     for (int j = 0; j < ReduceRank; ++j) {
@@ -500,9 +544,7 @@ __device__ __forceinline__ void ReduceAny(
   }
   __syncthreads();

-  reduce_var =
-      cub::BlockReduce<Ty, BlockDim>(temp_storage).Reduce(reduce_var, reducer);
-
+  reduce_var = BlockReduce(reduce_var, reducer);
   if (threadIdx.x == 0) {
     y[blockIdx.x] = reduce_var;
   }
@@ -510,7 +552,7 @@

 // module function designed for global function
 template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp,
-          int BlockDim, int Rank, int ReduceRank>
+          int Rank, int ReduceRank>
 __device__ __forceinline__ void ReduceModule(
     const Tx* x, Ty* y, ReduceOp reducer, TransformOp transformer, Ty init,
     int reduce_num, int left_num, int blocking_size, int reduce_type,
@@ -521,8 +563,8 @@ __device__ __forceinline__ void ReduceModule(
     paddle::framework::Array<int, Rank - ReduceRank> left_strides) {
   // reduce_rank == 1 && reduce_dim[0] == x_dim.size() - 1
   if (reduce_type == ReduceType::kReduceLastDim) {
-    ReduceLastDim<Tx, Ty, ReduceOp, TransformOp, BlockDim>(
-        x, y, reducer, transformer, init, reduce_num);
+    ReduceLastDim<Tx, Ty, ReduceOp, TransformOp>(x, y, reducer, transformer,
+                                                 init, reduce_num);

     // reduce_rank == 1 && reduce_dim[0] != x_dim.size() - 1
   } else if (reduce_type == ReduceType::kReduceHigherDim) {
@@ -531,14 +573,14 @@ __device__ __forceinline__ void ReduceModule(

     // reduce_rank >= 2
   } else {
-    ReduceAny<Tx, Ty, ReduceOp, TransformOp, BlockDim, Rank, ReduceRank>(
+    ReduceAny<Tx, Ty, ReduceOp, TransformOp, Rank, ReduceRank>(
         x, y, reducer, transformer, reduce_num, x_strides, reduce_dim,
         reduce_strides, left_dim, left_strides);
   }
 }

 template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp,
-          int BlockDim, int Rank, int ReduceRank>
+          int Rank, int ReduceRank>
 __global__ void ReduceKernelFunction(
     const Tx* x, Ty* y, ReduceOp reducer, TransformOp transformer, Ty init,
     int reduce_num, int left_num, int block_size, int reduce_type,
@@ -547,47 +589,46 @@ __global__ void ReduceKernelFunction(
     paddle::framework::Array<int, ReduceRank> reduce_strides,
     paddle::framework::Array<int, Rank - ReduceRank> left_dim,
     paddle::framework::Array<int, Rank - ReduceRank> left_strides) {
-  ReduceModule<Tx, Ty, ReduceOp, TransformOp, BlockDim, Rank, ReduceRank>(
+  ReduceModule<Tx, Ty, ReduceOp, TransformOp, Rank, ReduceRank>(
       x, y, reducer, transformer, init, reduce_num, left_num, block_size,
       reduce_type, x_strides, reduce_dim, reduce_strides, left_dim,
       left_strides);
 }

-template <typename Tx, typename Ty, typename ReduceOp, int BlockDim, int kRank,
-          int kReduceRank>
+template <typename Tx, typename Ty, typename ReduceOp, int Rank, int ReduceRank>
 static void LaunchReduceKernel(const Tx* x_data, Ty* y_data,
                                const ReduceOp& reducer, Ty init,
                                gpuStream_t stream, ReduceConfig<Ty> config) {
   using TransformOp = typename ReduceOp::Transformer;

-  ReduceKernelFunction<Tx, Ty, ReduceOp, TransformOp, BlockDim, kRank,
-                       kReduceRank><<<config.grid, config.block, 0, stream>>>(
+  ReduceKernelFunction<Tx, Ty, ReduceOp, TransformOp, Rank,
+                       ReduceRank><<<config.grid, config.block, 0, stream>>>(
       x_data, config.output_data, reducer, TransformOp(config.reduce_num),
      init, config.reduce_num, config.left_num, config.blocking_size,
-      config.reduce_type, detail::VectorToArray<int, kRank>(config.x_strides),
-      detail::VectorToArray<int, kReduceRank>(config.reduce_dim),
-      detail::VectorToArray<int, kReduceRank>(config.reduce_strides),
-      detail::VectorToArray<int, kRank - kReduceRank>(config.left_dim),
-      detail::VectorToArray<int, kRank - kReduceRank>(config.left_strides));
+      config.reduce_type, detail::VectorToArray<int, Rank>(config.x_strides),
+      detail::VectorToArray<int, ReduceRank>(config.reduce_dim),
+      detail::VectorToArray<int, ReduceRank>(config.reduce_strides),
+      detail::VectorToArray<int, Rank - ReduceRank>(config.left_dim),
+      detail::VectorToArray<int, Rank - ReduceRank>(config.left_strides));

   if (config.should_reduce_again) {
     dim3 block(config.block.x, 1, 1);
     dim3 grid(config.grid.x, 1, config.grid.z);
-    ReduceKernelFunction<Ty, Ty, ReduceOp, detail::IdentityFunctor<Ty>, 128,
-                         kRank, kReduceRank><<<grid, block, 0, stream>>>(
+    ReduceKernelFunction<Ty, Ty, ReduceOp, detail::IdentityFunctor<Ty>, Rank,
+                         ReduceRank><<<grid, block, 0, stream>>>(
         config.output_data, y_data, reducer,
         detail::IdentityFunctor<Ty>(config.grid.y), init, config.grid.y,
         config.left_num, config.grid.y, ReduceType::kReduceHigherDim,
-        detail::VectorToArray<int, kRank>(config.x_strides),
-        detail::VectorToArray<int, kReduceRank>(config.reduce_dim),
-        detail::VectorToArray<int, kReduceRank>(config.reduce_strides),
-        detail::VectorToArray<int, kRank - kReduceRank>(config.left_dim),
-        detail::VectorToArray<int, kRank - kReduceRank>(config.left_strides));
+        detail::VectorToArray<int, Rank>(config.x_strides),
+        detail::VectorToArray<int, ReduceRank>(config.reduce_dim),
+        detail::VectorToArray<int, ReduceRank>(config.reduce_strides),
+        detail::VectorToArray<int, Rank - ReduceRank>(config.left_dim),
+        detail::VectorToArray<int, Rank - ReduceRank>(config.left_strides));
   }
 }

-template <typename Tx, typename Ty, typename ReduceOp, int BlockDim>
+template <typename Tx, typename Ty, typename ReduceOp>
 static void ReduceKernelImpl(const Tx* x_data, Ty* y_data,
                              const ReduceOp& reducer, Ty init,
                              gpuStream_t stream, ReduceConfig<Ty> config) {
@@ -596,15 +637,15 @@ static void ReduceKernelImpl(const Tx* x_data, Ty* y_data,

 #define CUB_RANK_CASE(i, ...)             \
   case i: {                               \
-    constexpr auto kRank = i;             \
+    constexpr auto Rank = i;              \
     switch (reduce_rank) { __VA_ARGS__; } \
   } break

-#define CUB_REDUCE_RANK_CASE(i, ...)                                 \
-  case i: {                                                          \
-    constexpr auto kReduceRank = i;                                  \
-    LaunchReduceKernel<Tx, Ty, ReduceOp, BlockDim, kRank,            \
-                       kReduceRank>(                                 \
-        x_data, y_data, reducer, init, stream, config);              \
+#define CUB_REDUCE_RANK_CASE(i, ...)                       \
+  case i: {                                                \
+    constexpr auto ReduceRank = i;                         \
+    LaunchReduceKernel<Tx, Ty, ReduceOp, Rank, ReduceRank>( \
+        x_data, y_data, reducer, init, stream, config);    \
   } break

   detail::CheckReduceRank(reduce_rank, rank);
@@ -677,24 +718,8 @@ void TensorReduceFunctorImpl(const framework::Tensor& x, framework::Tensor* y,
     return;
   }

-#define CUB_BLOCK_DIM_CASE(block_dim)                                  \
-  case block_dim: {                                                    \
-    constexpr auto kBlockDim = block_dim;                              \
-    ReduceKernelImpl<Tx, Ty, ReduceOp<Tx, Ty>, kBlockDim>(             \
-        x_data, y_data, reducer, reducer.initial(), stream, config);   \
-  } break
-
-  switch (detail::GetBlockDim(config.reduce_num)) {
-    CUB_BLOCK_DIM_CASE(256);
-    CUB_BLOCK_DIM_CASE(128);
-    CUB_BLOCK_DIM_CASE(64);
-    CUB_BLOCK_DIM_CASE(32);
-    CUB_BLOCK_DIM_CASE(16);
-    CUB_BLOCK_DIM_CASE(8);
-    CUB_BLOCK_DIM_CASE(4);
-    CUB_BLOCK_DIM_CASE(2);
-  }
-#undef CUB_BLOCK_DIM_CASE
+  ReduceKernelImpl<Tx, Ty, ReduceOp<Tx, Ty>>(x_data, y_data, reducer,
+                                             reducer.initial(), stream, config);
 }

 template <typename T, template <typename, typename> class ReduceOp>
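
Note (not part of the patch): the hand-written `WarpReduce`/`BlockReduce` replace `cub::BlockReduce`, so the block size no longer needs to be a compile-time constant. That is why the `BlockDim`/`kBlockDim` template parameters and the `CUB_BLOCK_DIM_CASE` switch can be dropped and kernels can simply use `blockDim.x`. The sketch below is a minimal, self-contained CUDA illustration of the same two-stage shuffle reduction, assuming a CUDA (non-HIP) build, warp size 32, and a block made of full warps (`blockDim.x` a multiple of 32, at most 1024). It uses raw `__shfl_down_sync` and hypothetical `Sum`/`SumKernel` helpers in place of Paddle's `CREATE_SHFL_MASK`, `CudaShuffleDownSync`, and `ReduceOp` machinery.

```cuda
// Standalone sketch of a shuffle-based block reduction (not Paddle code).
#include <cstdio>
#include <cuda_runtime.h>

constexpr int kWarpSize = 32;  // assumption: CUDA warp size

template <typename T, typename ReduceOp>
__device__ __forceinline__ T WarpReduce(T val, ReduceOp reducer) {
  unsigned mask = 0xFFFFFFFFu;  // all lanes of the warp participate
  for (int stride = kWarpSize / 2; stride > 0; stride >>= 1) {
    T other = __shfl_down_sync(mask, val, stride);
    val = reducer(val, other);
  }
  return val;  // lane 0 now holds the warp's reduction
}

template <typename T, typename ReduceOp>
__device__ __forceinline__ T BlockReduce(T val, ReduceOp reducer) {
  __shared__ T shared[kWarpSize];  // one slot per warp (<= 32 warps per block)
  int lane = threadIdx.x % kWarpSize;
  int wid = threadIdx.x / kWarpSize;

  val = WarpReduce(val, reducer);    // 1. reduce within each warp
  if (lane == 0) shared[wid] = val;  // 2. publish per-warp partials
  __syncthreads();

  // 3. the first warp reduces the per-warp partials. Inactive lanes are
  // filled with T(0), the identity of the Sum used below; the patch avoids
  // needing an identity by shuffling across only blockDim.x / kWarpSize lanes.
  int num_warps = (blockDim.x + kWarpSize - 1) / kWarpSize;
  val = (threadIdx.x < num_warps) ? shared[lane] : T(0);
  if (wid == 0) val = WarpReduce(val, reducer);
  return val;  // valid on thread 0
}

struct Sum {
  __device__ float operator()(float a, float b) const { return a + b; }
};

__global__ void SumKernel(const float* x, float* out, int n) {
  // Grid-stride-style accumulation within one block, then a block reduction,
  // mirroring the ReduceLastDim loop-then-BlockReduce structure.
  float partial = 0.f;
  for (int i = threadIdx.x; i < n; i += blockDim.x) partial += x[i];
  partial = BlockReduce(partial, Sum());
  if (threadIdx.x == 0) *out = partial;
}

int main() {
  const int n = 1 << 20;
  float *x, *out;
  cudaMallocManaged(&x, n * sizeof(float));
  cudaMallocManaged(&out, sizeof(float));
  for (int i = 0; i < n; ++i) x[i] = 1.f;
  SumKernel<<<1, 128>>>(x, out, n);
  cudaDeviceSynchronize();
  printf("sum = %f (expected %d)\n", *out, n);
  cudaFree(x);
  cudaFree(out);
  return 0;
}
```

The design difference worth noting: this sketch pads inactive lanes with the reducer's identity before the second warp pass, while the patched `BlockReduce` instead limits the second shuffle loop to `block_dim_x = blockDim.x / kWarpSize` lanes, so uninitialized shared-memory slots never reach lane 0. Both variants only guarantee the final value on thread 0, which is why the kernels write `y[blockIdx.x]` from `threadIdx.x == 0` only.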