diff --git a/paddle/fluid/operators/reduce_ops/reduce_op.cu.h b/paddle/fluid/operators/reduce_ops/reduce_op.cu.h
index ee2beded713085ed5232b9f05f4250a73315c978..61efa409b90c3ed7bcedffbd08896ab13ec2b74c 100644
--- a/paddle/fluid/operators/reduce_ops/reduce_op.cu.h
+++ b/paddle/fluid/operators/reduce_ops/reduce_op.cu.h
@@ -34,9 +34,11 @@ namespace cub = hipcub;
 #include "paddle/fluid/framework/tensor.h"
 #include "paddle/fluid/framework/tensor_util.h"
 #include "paddle/fluid/platform/cuda_device_function.h"
+#include "paddle/fluid/platform/fast_divmod.h"
 
 // Reduce split or not, Whether to use ReduceHigherDim
 #define REDUCE_SPLIT_BOUNDARY 512
+#define REDUCE_VEC_SIZE 4
 
 namespace paddle {
 namespace operators {
@@ -72,6 +74,8 @@ static inline int GetLastPow2(int n) {
   return std::max(1, n - (n >> 1));
 }
 
+static inline int64_t AlignUp(int64_t a, int64_t b) { return (a + b - 1) / b; }
+
 // get strides of x_dim, reduce_dim and left_dim for reduceLastDim and reduceAny
 static inline std::vector GetDimStrides(const std::vector& dims,
                                         const std::vector& idx) {
@@ -122,10 +126,10 @@ static inline void CheckReduceRank(int reduce_rank, int rank) {
 template
 static inline paddle::framework::Array VectorToArray(
     const VectorLikeType& vec) {
-  PADDLE_ENFORCE_EQ(vec.size(), ElementCount,
+  PADDLE_ENFORCE_LE(vec.size(), ElementCount,
                     platform::errors::InvalidArgument(
                         "Cub reduce Array: size not match. Received "
-                        "vec.size() %d != ElementCount %d.",
+                        "vec.size() %d > ElementCount %d.",
                         vec.size(), ElementCount));
   size_t n = static_cast(vec.size());
   paddle::framework::Array ret;
@@ -138,6 +142,7 @@ static inline paddle::framework::Array VectorToArray(
 }  // namespace detail
 
 using Tensor = framework::Tensor;
+constexpr int kMaxRank = framework::DDim::kMaxRank;
 
 enum ReduceType {
   kReduceAll = 0x00,        // when reduce_rank == x_rank
@@ -146,6 +151,41 @@ enum ReduceType {
   kReduceAny = 0x03,        // when reduce_dim.size() > 1
 };
 
+struct IndexCalculator {
+  IndexCalculator(int dim, const std::vector& cal_dims,
+                  const std::vector& cal_strides,
+                  const std::vector& full_strides)
+      : dim(dim) {
+    dims = detail::VectorToArray(cal_dims);
+    strides = detail::VectorToArray(full_strides);
+    std::vector cal_divmoders;
+    // fast divmod
+    for (auto i : cal_strides) {
+      cal_divmoders.push_back(FastDivMod(i));
+    }
+    divmoders = detail::VectorToArray(cal_divmoders);
+  }
+
+  __device__ inline int Get(int offset) const {
+    int index = 0;
+#pragma unroll
+    for (int i = 0; i < kMaxRank; ++i) {
+      if (i == dim) {
+        break;
+      }
+      auto divmod = divmoders[i].Divmod(offset);
+      index += (divmod.val[0] * strides[dims[i]]);
+      offset = divmod.val[1];
+    }
+    return index;
+  }
+
+  int dim;
+  framework::Array dims;
+  framework::Array strides;
+  framework::Array divmoders;
+};
+
 // reduce config
 template
 struct ReduceConfig {
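The new IndexCalculator replaces the old per-rank template dispatch: it walks the reduce (or left) dims with precomputed strides and FastDivMod to map a linear offset in that sub-space back to an offset in the full input. The standalone host-side sketch below is not part of the patch; it shows the same mapping with plain / and % standing in for FastDivMod::Divmod, and the shape in main is made up for illustration.

// Standalone C++ sketch of the IndexCalculator::Get mapping (not part of the patch).
#include <cstdio>
#include <vector>

// dims:         which dims of the full tensor this calculator walks (e.g. reduce dims)
// strides:      strides of those dims inside their own compacted space
// full_strides: strides of the full input tensor, indexed by original dim id
int GetSourceOffset(int offset, const std::vector<int>& dims,
                    const std::vector<int>& strides,
                    const std::vector<int>& full_strides) {
  int index = 0;
  for (size_t i = 0; i < dims.size(); ++i) {
    int q = offset / strides[i];  // FastDivMod::Divmod returns {q, r} in one step
    int r = offset % strides[i];
    index += q * full_strides[dims[i]];
    offset = r;
  }
  return index;
}

int main() {
  // Example: x has shape [2, 3, 4] (row-major strides {12, 4, 1});
  // we reduce over dims {0, 2}, whose compacted shape is [2, 4] (strides {4, 1}).
  std::vector<int> reduce_dims = {0, 2};
  std::vector<int> reduce_strides = {4, 1};
  std::vector<int> x_strides = {12, 4, 1};
  // The 6th element of the reduce space (i0 = 1, i2 = 2) lives at 1*12 + 2*1 = 14.
  std::printf("%d\n", GetSourceOffset(6, reduce_dims, reduce_strides, x_strides));
  return 0;
}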
@@ -264,6 +304,9 @@ struct ReduceConfig {
     }
 
     left_dim.assign(left_set.begin(), left_set.end());
+
+    // whether the last dim is involved in the reduction
+    reduce_lastdim = (reduce_dim.back() == x_dim.size() - 1);
   }
 
   // set x_strides, reduce_strides, left_strides for reduceLastDim and reduceAny
@@ -300,20 +343,76 @@ struct ReduceConfig {
     if (rank == reduce_rank) {
       reduce_type = static_cast(ReduceType::kReduceAll);
-
     } else if (rank == 2 && reduce_rank == 1 && reduce_dim[0] == 1) {
       reduce_type = static_cast(ReduceType::kReduceLastDim);
-
     } else if (reduce_rank == 1 &&
               ((rank == 2 && is_large_enough) || rank != 2)) {
       // ReduceFirstDim and reduceSecondDim
       reduce_type = static_cast(ReduceType::kReduceHigherDim);
-
     } else {
       reduce_type = static_cast(ReduceType::kReduceAny);
     }
   }
 
+  void SetBlockDimForReduceAny(dim3* block_dim, dim3* grid_dim) {
+    constexpr int min_reduce_num_per_thread = 16;
+    constexpr int max_reduce_num_per_thread = 256;
+    constexpr int max_num_threads = detail::kMaxThread;
+
+    // set block size.
+    // 1. if reduce_lastdim == true, the block is 1-D and no reduction in
+    //    block y is needed;
+    // 2. if reduce_lastdim == false, the block is 2-D and, when necessary,
+    //    it reduces in block y.
+    int grid_num, reduce_num_per_thread;
+    if (reduce_lastdim) {
+      block_dim->x = detail::GetBlockDim(reduce_num);
+      block_dim->y = 1;
+      grid_num = left_num;
+      reduce_num_per_thread =
+          detail::AlignUp(reduce_num, block_dim->x * block_dim->y);
+    } else {
+      int block_x = detail::GetBlockDim(left_num);
+      int block_y = detail::GetBlockDim(reduce_num);
+      block_dim->x = std::min(block_x, 32);
+      block_dim->y =
+          std::min(block_y, static_cast(max_num_threads / block_dim->x));
+      block_dim->x =
+          std::min(block_x, static_cast(max_num_threads / block_dim->y));
+      grid_num = detail::AlignUp(left_num, block_dim->x);
+      reduce_num_per_thread = detail::AlignUp(reduce_num, block_dim->y);
+    }
+    int device_id = platform::GetCurrentDeviceId();
+    int max_mp = platform::GetCUDAMultiProcessors(device_id);
+    int max_threads_per_mp =
+        platform::GetCUDAMaxThreadsPerMultiProcessor(device_id);
+    int max_threads = max_threads_per_mp * max_mp;
+    int num_threads = block_dim->x * block_dim->y;
+    int max_num_blocks = max_threads / num_threads;
+
+    // set grid size.
+    // Whether grid.y is set larger than 1 follows these rules:
+    // 1. the number of elements each thread processes should be no less than
+    //    min_reduce_num_per_thread but no more than max_reduce_num_per_thread;
+    // 2. the utilization of the SMs should be maximized.
+    // So we take the minimum of input_split_num_1 and input_split_num_3 to
+    // make each thread process as much data as possible. Meanwhile, the split
+    // count must be large enough that no thread exceeds
+    // max_reduce_num_per_thread, so we take the maximum of the result above
+    // and input_split_num_2.
+    int input_split_num_1 =
+        detail::AlignUp(reduce_num_per_thread, min_reduce_num_per_thread);
+    int input_split_num_2 =
+        detail::AlignUp(reduce_num_per_thread, max_reduce_num_per_thread);
+    int input_split_num_3 = detail::AlignUp(max_num_blocks, grid_num);
+
+    grid_dim->x = grid_num;
+    grid_dim->y = std::max(std::min(input_split_num_1, input_split_num_3),
                            input_split_num_2);
+    // if grid.y > 1, we need to launch the reduce kernel again.
+    if (grid_dim->y > 1) {
+      should_reduce_again = true;
+    }
+  }
+
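For reference, the grid.y heuristic above can be exercised on the host in isolation. The sketch below is not part of the patch; it re-implements the reduce_lastdim branch with plain ints, and the SM count and block size are made-up placeholders, whereas the real code queries them through platform::GetCUDAMultiProcessors and friends.

// Host-side sketch of the grid.y splitting heuristic (not from the patch).
#include <algorithm>
#include <cstdio>

static int AlignUp(int a, int b) { return (a + b - 1) / b; }

int main() {
  const int min_reduce_num_per_thread = 16;
  const int max_reduce_num_per_thread = 256;

  int reduce_num = 100000;      // elements reduced into one output
  int left_num = 64;            // number of outputs -> grid.x
  int block_x = 512;            // threads along the reduce dimension
  int max_threads = 80 * 2048;  // placeholder for #SMs * max threads per SM

  int reduce_num_per_thread = AlignUp(reduce_num, block_x);
  int max_num_blocks = max_threads / block_x;

  // Rules 1/2: between min_ and max_reduce_num_per_thread elements per thread.
  int split_1 = AlignUp(reduce_num_per_thread, min_reduce_num_per_thread);
  int split_2 = AlignUp(reduce_num_per_thread, max_reduce_num_per_thread);
  // Rule 3: do not request more blocks than the device can keep resident.
  int split_3 = AlignUp(max_num_blocks, left_num);

  int grid_y = std::max(std::min(split_1, split_3), split_2);
  std::printf("grid = (%d, %d), block = (%d, 1)\n", left_num, grid_y, block_x);
  return 0;
}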
   // set block and grid for launch kernel
   // for ReduceHigherDim: if block is enough -> splite reduce_num
   //                      else init block(32, 1) grid(block_num, 1)
@@ -368,6 +467,8 @@ struct ReduceConfig {
         grid_dim.x = (left_num + block_dim.x - 1) / block_dim.x;
         grid_dim.y = 1;
       }
+    } else if (reduce_type == ReduceType::kReduceAny) {
+      SetBlockDimForReduceAny(&block_dim, &grid_dim);
     }
 
     block = block_dim;
@@ -388,6 +489,7 @@ struct ReduceConfig {
   int left_num;
   int blocking_size;
   bool should_reduce_again;
+  bool reduce_lastdim;
 
   Ty* output_data;
 
@@ -395,8 +497,12 @@ struct ReduceConfig {
   dim3 grid;
 };
 
+static __device__ int SharedMemoryIndex(int index) {
+  return (threadIdx.y + index) * blockDim.x + threadIdx.x;
+}
+
 template
-__device__ __forceinline__ T WarpReduce(T val, ReduceOp reducer) {
+static __device__ T WarpReduce(T val, ReduceOp reducer) {
   unsigned mask = 0u;
   CREATE_SHFL_MASK(mask, true);
   for (int stride = detail::kWarpSize / 2; stride > 0; stride >>= 1) {
@@ -416,7 +522,7 @@ __device__ __forceinline__ T WarpReduce(T val, ReduceOp reducer) {
 * res to warp0 and process the second WarpReduce
 */
 template
-__device__ __forceinline__ T BlockReduce(T val, ReduceOp reducer) {
+static __device__ T BlockXReduce(T val, ReduceOp reducer) {
   using detail::kWarpSize;
   __shared__ T shared[kWarpSize];
   int block_dim_x = blockDim.x;
@@ -441,14 +547,28 @@ __device__ __forceinline__ T BlockReduce(T val, ReduceOp reducer) {
   return val;
 }
 
+template
+static __device__ T BlockYReduce(T val, ReduceOp reducer) {
+  __shared__ T shared_memory[detail::kMaxThread];
+  shared_memory[SharedMemoryIndex(0)] = val;
+  for (int stride = blockDim.y / 2; stride > 0; stride >>= 1) {
+    __syncthreads();
+    if (threadIdx.y < stride && threadIdx.y + stride < blockDim.y) {
+      T temp = shared_memory[SharedMemoryIndex(stride)];
+      val = reducer(val, temp);
+    }
+    shared_memory[SharedMemoryIndex(0)] = val;
+  }
+  return val;
+}
+
 // when reduce_dim.size() == 1 and reduce_dim[0] == x_dim.size() - 1, this
 // function will be used
 // blockId.x -> left_num, threadId.x -> reduce_num
 template
-__device__ __forceinline__ void ReduceLastDim(const Tx* x, Ty* y,
-                                              ReduceOp reducer,
-                                              TransformOp transformer, Ty init,
-                                              int reduce_num) {
+__device__ void ReduceLastDim(const Tx* x, Ty* y, ReduceOp reducer,
+                              TransformOp transformer, Ty init,
+                              int reduce_num) {
   int idx_x = blockIdx.x * reduce_num;
   int idx_y = threadIdx.x;
   Ty reduce_var = init;
@@ -458,7 +578,7 @@ __device__ __forceinline__ void ReduceLastDim(const Tx* x, Ty* y,
   }
   __syncthreads();
 
-  reduce_var = BlockReduce(reduce_var, reducer);
+  reduce_var = BlockXReduce(reduce_var, reducer);
 
   if (threadIdx.x == 0) {
     y[blockIdx.x] = reduce_var;
@@ -471,11 +591,9 @@ __device__ __forceinline__ void ReduceLastDim(const Tx* x, Ty* y,
 // if axis = 1 then grid.z = nz, grid.y = ny / block_size, grid.x = nx / 32
 // else grid.z = 1, grid.y = ny / block_size, grid.x = nx /32
 template
-__device__ __forceinline__ void ReduceHigherDim(const Tx* x, Ty* y,
-                                                ReduceOp reducer,
-                                                TransformOp transformer,
-                                                Ty init, int reduce_num,
-                                                int left_num, int block_size) {
+__device__ void ReduceHigherDim(const Tx* x, Ty* y, ReduceOp reducer,
+                                TransformOp transformer, Ty init,
+                                int reduce_num, int left_num, int block_size) {
   int idx = blockIdx.x * blockDim.x + threadIdx.x;
   int idy = blockIdx.y * block_size;
 
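BlockYReduce is a shared-memory tree reduction over threadIdx.y, producing one partial result per x lane. The self-contained CUDA demo below is not part of the patch; it shows a slightly simplified variant of the same pattern with a fixed power-of-two block and summation as the reducer.

// Standalone CUDA demo of the BlockYReduce pattern (not part of the patch).
#include <cstdio>

#define BLOCK_X 4
#define BLOCK_Y 8

// Tree reduction over threadIdx.y: each x lane ends up with the sum of the
// BLOCK_Y values contributed along y (here 1 + 2 + ... + 8 = 36).
__global__ void BlockYReduceDemo(float* out) {
  __shared__ float shared[BLOCK_X * BLOCK_Y];
  float val = static_cast<float>(threadIdx.y + 1);
  shared[threadIdx.y * blockDim.x + threadIdx.x] = val;
  for (int stride = blockDim.y / 2; stride > 0; stride >>= 1) {
    __syncthreads();
    if (threadIdx.y < stride) {
      val += shared[(threadIdx.y + stride) * blockDim.x + threadIdx.x];
      shared[threadIdx.y * blockDim.x + threadIdx.x] = val;
    }
  }
  if (threadIdx.y == 0) {
    out[threadIdx.x] = val;
  }
}

int main() {
  float* d_out;
  cudaMalloc(&d_out, BLOCK_X * sizeof(float));
  BlockYReduceDemo<<<1, dim3(BLOCK_X, BLOCK_Y)>>>(d_out);
  float h_out[BLOCK_X];
  cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);
  for (int i = 0; i < BLOCK_X; ++i) {
    std::printf("lane %d: %g\n", i, h_out[i]);
  }
  cudaFree(d_out);
  return 0;
}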
@@ -497,71 +615,97 @@ __device__ __forceinline__ void ReduceHigherDim(const Tx* x, Ty* y,
 
 // when reduce_dim.size() != 1 and reduce_dim.size() != x_dim.size(), this
 // function will be used
-// blockId.x -> left_num, threadId.x -> reduce_num
-template
-__device__ __forceinline__ void ReduceAny(
-    const Tx* x, Ty* y, ReduceOp reducer, TransformOp transformer,
-    int reduce_num, paddle::framework::Array x_strides,
-    paddle::framework::Array reduce_dim,
-    paddle::framework::Array reduce_strides,
-    paddle::framework::Array left_dim,
-    paddle::framework::Array left_strides) {
-  int sub_index[Rank];
-  int left_idx = blockIdx.x;
-  for (int i = 0; i < Rank - ReduceRank; ++i) {
-    sub_index[left_dim[i]] = left_idx / left_strides[i];
-    left_idx %= left_strides[i];
-  }
-
-  int reduce_idx = threadIdx.x;
-  for (int j = 0; j < ReduceRank; ++j) {
-    sub_index[reduce_dim[j]] = reduce_idx / reduce_strides[j];
-    reduce_idx %= reduce_strides[j];
-  }
-
-  int idx_x = 0;
-  for (int k = 0; k < Rank; ++k) {
-    idx_x += (sub_index[k] * x_strides[k]);
+template
+__device__ void ReduceAny(const Tx* x, Ty* y, ReduceOp reducer,
+                          TransformOp transformer, Ty init, int reduce_num,
+                          int left_num, bool reduce_lastdim,
+                          const IndexCalculator& reduce_index_calculator,
+                          const IndexCalculator& left_index_calculator) {
+  int input_idx, left_idx, stride;
+  // the last dim is involved in the reduction
+  if (reduce_lastdim) {
+    input_idx = blockIdx.y * blockDim.x + threadIdx.x;
+    left_idx = blockIdx.x;
+    stride = gridDim.y * blockDim.x;
+  } else {
+    input_idx = blockIdx.y * blockDim.y + threadIdx.y;
+    left_idx = blockIdx.x * blockDim.x + threadIdx.x;
+    stride = gridDim.y * blockDim.y;
   }
-  Ty reduce_var = static_cast(transformer(x[idx_x]));
-
-  for (int i = threadIdx.x + blockDim.x; i < reduce_num; i += blockDim.x) {
-    int reduce_idx = i;
+  // calculate the offset, i.e. the address where each thread actually starts.
+  int input_offset = left_index_calculator.Get(left_idx);
+  const Tx* input = x + input_offset;
+  Ty reduce_var = init;
 
-    for (int j = 0; j < ReduceRank; ++j) {
-      sub_index[reduce_dim[j]] = reduce_idx / reduce_strides[j];
-      reduce_idx %= reduce_strides[j];
+  // 1. reduce for each thread
+  if (left_idx < left_num) {
+    // load REDUCE_VEC_SIZE elements at once, then compute
+    Tx input_reg[REDUCE_VEC_SIZE];
+    int bound = reduce_num - (REDUCE_VEC_SIZE - 1) * stride;
+    while (input_idx < bound) {
+#pragma unroll
+      for (int i = 0; i < REDUCE_VEC_SIZE; ++i) {
+        int reduce_idx = input_idx + i * stride;
+        int idx_x = reduce_index_calculator.Get(reduce_idx);
+        input_reg[i] = input[idx_x];
+      }
+#pragma unroll
+      for (int i = 0; i < REDUCE_VEC_SIZE; ++i) {
+        reduce_var = reducer(reduce_var, transformer(input_reg[i]));
+      }
+      input_idx += REDUCE_VEC_SIZE * stride;
     }
-    int idx_x = 0;
-    for (int k = 0; k < Rank; ++k) {
-      idx_x += (sub_index[k] * x_strides[k]);
+    // deal with the remaining part
+    int input_idx_tmp = input_idx;
+#pragma unroll
+    for (int i = 0; i < REDUCE_VEC_SIZE; ++i) {
+      if (input_idx >= reduce_num) {
+        break;
+      }
+      int reduce_idx = input_idx;
+      int idx_x = reduce_index_calculator.Get(reduce_idx);
+      input_reg[i] = input[idx_x];
+      input_idx += stride;
     }
+    input_idx = input_idx_tmp;
+#pragma unroll
+    for (int i = 0; i < REDUCE_VEC_SIZE; ++i) {
+      if (input_idx >= reduce_num) {
+        break;
+      }
+      reduce_var = reducer(reduce_var, transformer(input_reg[i]));
+      input_idx += stride;
+    }
+  }
-    reduce_var = static_cast(
-        reducer(reduce_var, static_cast(transformer(x[idx_x]))));
+  // 2. reduce in block y
+  if (blockDim.y > 1) {
+    reduce_var = BlockYReduce(reduce_var, reducer);
   }
   __syncthreads();
 
-  reduce_var = BlockReduce(reduce_var, reducer);
-  if (threadIdx.x == 0) {
-    y[blockIdx.x] = reduce_var;
+  if (reduce_lastdim) {
+    // 3. reduce in block x
+    reduce_var = BlockXReduce(reduce_var, reducer);
+    if (threadIdx.x == 0) {
+      y[blockIdx.x + blockIdx.y * gridDim.x] = reduce_var;
+    }
+  } else {
+    if (left_idx < left_num && threadIdx.y == 0) {
+      y[blockIdx.y * left_num + left_idx] = reduce_var;
+    }
   }
 }
 
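The core of the new ReduceAny is the register-tiling loop: each thread loads REDUCE_VEC_SIZE strided elements into registers, accumulates them, and finishes the tail one element at a time. The standalone CUDA sketch below is not part of the patch; it applies the same loop shape to a flat sum, and the atomicAdd finish and the launch sizes are simplifications chosen for the demo.

// Standalone CUDA sketch of the register-tiling reduce loop (not from the patch).
#include <cstdio>

#define VEC_SIZE 4

__global__ void StridedSum(const float* x, float* out, int n) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  int stride = gridDim.x * blockDim.x;
  float acc = 0.0f;

  float reg[VEC_SIZE];
  int bound = n - (VEC_SIZE - 1) * stride;  // last idx with a full VEC_SIZE tile
  while (idx < bound) {
#pragma unroll
    for (int i = 0; i < VEC_SIZE; ++i) reg[i] = x[idx + i * stride];
#pragma unroll
    for (int i = 0; i < VEC_SIZE; ++i) acc += reg[i];
    idx += VEC_SIZE * stride;
  }
  for (; idx < n; idx += stride) acc += x[idx];  // remainder, one element at a time

  atomicAdd(out, acc);
}

int main() {
  const int n = 10000;
  float *d_x, *d_out;
  cudaMalloc(&d_x, n * sizeof(float));
  cudaMalloc(&d_out, sizeof(float));
  cudaMemset(d_out, 0, sizeof(float));

  float* h_x = new float[n];
  for (int i = 0; i < n; ++i) h_x[i] = 1.0f;
  cudaMemcpy(d_x, h_x, n * sizeof(float), cudaMemcpyHostToDevice);

  StridedSum<<<8, 256>>>(d_x, d_out, n);
  float result = 0.0f;
  cudaMemcpy(&result, d_out, sizeof(float), cudaMemcpyDeviceToHost);
  std::printf("sum = %g (expected %d)\n", result, n);

  delete[] h_x;
  cudaFree(d_x);
  cudaFree(d_out);
  return 0;
}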
 // module function designed for global function
-template
-__device__ __forceinline__ void ReduceModule(
-    const Tx* x, Ty* y, ReduceOp reducer, TransformOp transformer, Ty init,
-    int reduce_num, int left_num, int blocking_size, int reduce_type,
-    paddle::framework::Array x_strides,
-    paddle::framework::Array reduce_dim,
-    paddle::framework::Array reduce_strides,
-    paddle::framework::Array left_dim,
-    paddle::framework::Array left_strides) {
-  // reduce_rank == 1 && reduce_dim[0] == x_dim.size() - 1
+template
+__device__ void ReduceModule(const Tx* x, Ty* y, ReduceOp reducer,
+                             TransformOp transformer, Ty init, int reduce_num,
+                             int left_num, int blocking_size, int reduce_type,
+                             bool reduce_lastdim,
+                             const IndexCalculator& reduce_index_calculator,
+                             const IndexCalculator& left_index_calculator) {
   if (reduce_type == ReduceType::kReduceLastDim) {
     ReduceLastDim(x, y, reducer, transformer, init, reduce_num);
 
@@ -573,104 +717,66 @@ __device__ __forceinline__ void ReduceModule(
   // reduce_rank >= 2
   } else {
-    ReduceAny(
-        x, y, reducer, transformer, reduce_num, x_strides, reduce_dim,
-        reduce_strides, left_dim, left_strides);
+    ReduceAny(
+        x, y, reducer, transformer, init, reduce_num, left_num, reduce_lastdim,
+        reduce_index_calculator, left_index_calculator);
   }
 }
 
-template
-__global__ void ReduceKernelFunction(
-    const Tx* x, Ty* y, ReduceOp reducer, TransformOp transformer, Ty init,
-    int reduce_num, int left_num, int block_size, int reduce_type,
-    paddle::framework::Array x_strides,
-    paddle::framework::Array reduce_dim,
-    paddle::framework::Array reduce_strides,
-    paddle::framework::Array left_dim,
-    paddle::framework::Array left_strides) {
-  ReduceModule(
-      x, y, reducer, transformer, init, reduce_num, left_num, block_size,
-      reduce_type, x_strides, reduce_dim, reduce_strides, left_dim,
-      left_strides);
+template
+__global__ void ReduceKernelFunction(const Tx* x, Ty* y, ReduceOp reducer,
+                                     TransformOp transformer, Ty init,
+                                     int reduce_num, int left_num,
+                                     int blocking_size, int reduce_type,
+                                     bool reduce_lastdim,
+                                     IndexCalculator reduce_index_calculator,
+                                     IndexCalculator left_index_calculator) {
+  ReduceModule(
+      x, y, reducer, transformer, init, reduce_num, left_num, blocking_size,
+      reduce_type, reduce_lastdim, reduce_index_calculator,
+      left_index_calculator);
 }
 
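ReduceAny finishes with BlockYReduce and, for the last-dim case, BlockXReduce, both of which bottom out in warp shuffles (WarpReduce). The minimal CUDA demo below is not part of the patch; it shows the underlying __shfl_down_sync reduction for a single warp with summation as the reducer.

// Standalone CUDA demo of the warp-shuffle reduction (not from the patch).
#include <cstdio>

__global__ void WarpSumDemo(float* out) {
  float val = static_cast<float>(threadIdx.x + 1);  // lanes hold 1..32
  unsigned mask = 0xffffffffu;                      // all 32 lanes participate
  for (int stride = 16; stride > 0; stride >>= 1) {
    // each lane folds in the value held by the lane `stride` positions above it
    val += __shfl_down_sync(mask, val, stride);
  }
  if (threadIdx.x == 0) {
    *out = val;  // 1 + 2 + ... + 32 = 528
  }
}

int main() {
  float* d_out;
  cudaMalloc(&d_out, sizeof(float));
  WarpSumDemo<<<1, 32>>>(d_out);
  float h_out = 0.0f;
  cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
  std::printf("warp sum = %g\n", h_out);
  cudaFree(d_out);
  return 0;
}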
-template
+template
 static void LaunchReduceKernel(const Tx* x_data, Ty* y_data,
                                const ReduceOp& reducer, Ty init,
                                gpuStream_t stream, ReduceConfig config) {
   using TransformOp = typename ReduceOp::Transformer;
 
-  ReduceKernelFunction<<>>(
+  int reduce_rank = config.reduce_strides.size();
+  int left_rank = config.left_strides.size();
+  auto reduce_index_calculator = IndexCalculator(
+      reduce_rank, config.reduce_dim, config.reduce_strides, config.x_strides);
+  auto left_index_calculator = IndexCalculator(
+      left_rank, config.left_dim, config.left_strides, config.x_strides);
+
+  ReduceKernelFunction<<>>(
       x_data, config.output_data, reducer, TransformOp(config.reduce_num),
       init, config.reduce_num, config.left_num, config.blocking_size,
-      config.reduce_type, detail::VectorToArray(config.x_strides),
-      detail::VectorToArray(config.reduce_dim),
-      detail::VectorToArray(config.reduce_strides),
-      detail::VectorToArray(config.left_dim),
-      detail::VectorToArray(config.left_strides));
+      config.reduce_type, config.reduce_lastdim, reduce_index_calculator,
+      left_index_calculator);
 
   if (config.should_reduce_again) {
-    dim3 block(config.block.x, 1, 1);
-    dim3 grid(config.grid.x, 1, config.grid.z);
+    dim3 block;
+    dim3 grid;
+    if (config.reduce_lastdim) {
+      block = dim3(32, 1, 1);
+      grid = dim3(detail::AlignUp(config.left_num, 32), 1, 1);
+    } else {
+      block = dim3(config.block.x, 1, 1);
+      grid = dim3(config.grid.x, 1, config.grid.z);
+    }
 
-    ReduceKernelFunction, Rank,
-                         ReduceRank><<>>(
+    ReduceKernelFunction><<>>(
        config.output_data, y_data, reducer,
        detail::IdentityFunctor(config.grid.y), init, config.grid.y,
        config.left_num, config.grid.y, ReduceType::kReduceHigherDim,
-        detail::VectorToArray(config.x_strides),
-        detail::VectorToArray(config.reduce_dim),
-        detail::VectorToArray(config.reduce_strides),
-        detail::VectorToArray(config.left_dim),
-        detail::VectorToArray(config.left_strides));
+        config.reduce_lastdim, reduce_index_calculator, left_index_calculator);
   }
 }
 
-template
-static void ReduceKernelImpl(const Tx* x_data, Ty* y_data,
-                             const ReduceOp& reducer, Ty init,
-                             gpuStream_t stream, ReduceConfig config) {
-  int reduce_rank = config.reduce_strides.size();
-  int rank = config.x_strides.size();
-
-#define CUB_RANK_CASE(i, ...)             \
-  case i: {                               \
-    constexpr auto Rank = i;              \
-    switch (reduce_rank) { __VA_ARGS__; } \
-  } break
-
-#define CUB_REDUCE_RANK_CASE(i, ...)                    \
-  case i: {                                             \
-    constexpr auto ReduceRank = i;                      \
-    LaunchReduceKernel(                                 \
-        x_data, y_data, reducer, init, stream, config); \
-  } break
-
-  detail::CheckReduceRank(reduce_rank, rank);
-  switch (rank) {
-    CUB_RANK_CASE(2, CUB_REDUCE_RANK_CASE(1););
-
-    CUB_RANK_CASE(3, CUB_REDUCE_RANK_CASE(1); CUB_REDUCE_RANK_CASE(2););
-
-    CUB_RANK_CASE(4, CUB_REDUCE_RANK_CASE(2););
-
-    CUB_RANK_CASE(5, CUB_REDUCE_RANK_CASE(2); CUB_REDUCE_RANK_CASE(3););
-
-    CUB_RANK_CASE(6, CUB_REDUCE_RANK_CASE(3););
-
-    CUB_RANK_CASE(7, CUB_REDUCE_RANK_CASE(3); CUB_REDUCE_RANK_CASE(4););
-
-    CUB_RANK_CASE(8, CUB_REDUCE_RANK_CASE(4););
-
-    CUB_RANK_CASE(9, CUB_REDUCE_RANK_CASE(4); CUB_REDUCE_RANK_CASE(5););
-  }
-
-#undef CUB_REDUCE_RANK_CASE
-#undef CUB_RANK_CASE
-}
-
 template class ReduceOp>
 void TensorReduceFunctorImpl(const framework::Tensor& x, framework::Tensor* y,
@@ -682,8 +788,8 @@ void TensorReduceFunctorImpl(const framework::Tensor& x, framework::Tensor* y,
 
   // after config.run()
   // SetOutputData for ReduceHigherDim when should_reduce_again is true,
-  // temp_output should be stored temp_data in output_data space or stored in
-  // y_data;
+  // the temporary output (temp_data) is stored in the output_data space or
+  // directly in y_data;
   framework::Tensor tmp;
   auto x_data = x.data();
   auto y_data = y->mutable_data(x.place());
@@ -718,8 +824,8 @@ void TensorReduceFunctorImpl(const framework::Tensor& x, framework::Tensor* y,
     return;
   }
 
-  ReduceKernelImpl>(x_data, y_data, reducer,
-                    reducer.initial(), stream, config);
+  LaunchReduceKernel>(
+      x_data, y_data, reducer, reducer.initial(), stream, config);
 }
 
 template class ReduceOp>
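When should_reduce_again is set, the first launch leaves grid.y partial results per output in config.output_data and the second ReduceKernelFunction launch (with detail::IdentityFunctor) folds them. The standalone CUDA sketch below is not part of the patch; it reproduces that two-pass layout with plain kernels, and the naive in-block finish and the fixed sizes are demo simplifications.

// Standalone CUDA sketch of the two-pass (should_reduce_again) scheme (not from the patch).
#include <cstdio>

__global__ void PartialSums(const float* x, float* partial, int reduce_num) {
  // One block per (output index = blockIdx.x, chunk = blockIdx.y).
  float acc = 0.0f;
  for (int i = blockIdx.y * blockDim.x + threadIdx.x; i < reduce_num;
       i += gridDim.y * blockDim.x) {
    acc += x[blockIdx.x * reduce_num + i];
  }
  // Naive in-block finish for the demo; the patch uses BlockXReduce here.
  __shared__ float sh[256];
  sh[threadIdx.x] = acc;
  __syncthreads();
  if (threadIdx.x == 0) {
    float s = 0.0f;
    for (int i = 0; i < blockDim.x; ++i) s += sh[i];
    partial[blockIdx.x + blockIdx.y * gridDim.x] = s;  // same layout as pass 1
  }
}

__global__ void FinalSums(const float* partial, float* y, int num_partials,
                          int left_num) {
  int out = blockIdx.x * blockDim.x + threadIdx.x;
  if (out < left_num) {
    float s = 0.0f;
    for (int i = 0; i < num_partials; ++i) s += partial[out + i * left_num];
    y[out] = s;
  }
}

int main() {
  const int left_num = 8, reduce_num = 4096, grid_y = 4;
  float *d_x, *d_partial, *d_y;
  cudaMalloc(&d_x, left_num * reduce_num * sizeof(float));
  cudaMalloc(&d_partial, left_num * grid_y * sizeof(float));
  cudaMalloc(&d_y, left_num * sizeof(float));

  float* h_x = new float[left_num * reduce_num];
  for (int i = 0; i < left_num * reduce_num; ++i) h_x[i] = 1.0f;
  cudaMemcpy(d_x, h_x, left_num * reduce_num * sizeof(float),
             cudaMemcpyHostToDevice);

  PartialSums<<<dim3(left_num, grid_y), 256>>>(d_x, d_partial, reduce_num);
  FinalSums<<<1, 32>>>(d_partial, d_y, grid_y, left_num);

  float h_y[left_num];
  cudaMemcpy(h_y, d_y, sizeof(h_y), cudaMemcpyDeviceToHost);
  std::printf("y[0] = %g (expected %d)\n", h_y[0], reduce_num);

  delete[] h_x;
  cudaFree(d_x); cudaFree(d_partial); cudaFree(d_y);
  return 0;
}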
diff --git a/paddle/fluid/platform/fast_divmod.h b/paddle/fluid/platform/fast_divmod.h
index 5c5903d62cd277f6d86e542f04de066c96374704..c6c22bb2f9203b00e924f06f6fe4bf1b0b4ffc65 100644
--- a/paddle/fluid/platform/fast_divmod.h
+++ b/paddle/fluid/platform/fast_divmod.h
@@ -54,7 +54,7 @@ struct FastDivMod {
     return (t + n) >> shift_val;
   }
 
-  __device__ __forceinline__ DivModT Divmod(uint32_t n) {
+  __device__ __forceinline__ DivModT Divmod(uint32_t n) const {
     uint32_t q = Div(n);
     DivModT result = {q, n - q * divisor};
    return result;
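The only change to FastDivMod is the const qualifier on Divmod, which is needed because IndexCalculator::Get is itself a const member function. The minimal host-side C++ example below is not part of the patch and uses made-up names; it shows why the qualifier matters: without const on the callee, calling it through a const object does not compile.

// Minimal illustration of the const-correctness requirement (not from the patch).
#include <cstdio>
#include <cstdint>

struct DivMod {
  uint32_t divisor;
  // Without "const" here, Calculator::Get below would fail to compile with
  // "passing 'const DivMod' as 'this' argument discards qualifiers".
  uint32_t Div(uint32_t n) const { return n / divisor; }
};

struct Calculator {
  DivMod dm;
  int Get(uint32_t offset) const { return static_cast<int>(dm.Div(offset)); }
};

int main() {
  const Calculator calc{DivMod{4}};
  std::printf("%d\n", calc.Get(10));  // prints 2
  return 0;
}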