Unverified commit 2dde0eb0, authored by Zhang Zheng and committed by GitHub

optimize performance of multi-dimensional reduce (#33761)

Parent commit: 4d259b91
@@ -34,9 +34,11 @@ namespace cub = hipcub;
 #include "paddle/fluid/framework/tensor.h"
 #include "paddle/fluid/framework/tensor_util.h"
 #include "paddle/fluid/platform/cuda_device_function.h"
+#include "paddle/fluid/platform/fast_divmod.h"

 // Reduce split or not, Whether to use ReduceHigherDim
 #define REDUCE_SPLIT_BOUNDARY 512
+#define REDUCE_VEC_SIZE 4

 namespace paddle {
 namespace operators {

@@ -72,6 +74,8 @@ static inline int GetLastPow2(int n) {
   return std::max(1, n - (n >> 1));
 }

+static inline int64_t AlignUp(int64_t a, int64_t b) { return (a + b - 1) / b; }

 // get strides of x_dim, reduce_dim and left_dim for reduceLastDim and reduceAny
 static inline std::vector<int> GetDimStrides(const std::vector<int>& dims,
                                              const std::vector<int>& idx) {

@@ -122,10 +126,10 @@ static inline void CheckReduceRank(int reduce_rank, int rank) {
 template <typename T, size_t ElementCount, typename VectorLikeType>
 static inline paddle::framework::Array<T, ElementCount> VectorToArray(
     const VectorLikeType& vec) {
-  PADDLE_ENFORCE_EQ(vec.size(), ElementCount,
+  PADDLE_ENFORCE_LE(vec.size(), ElementCount,
                     platform::errors::InvalidArgument(
                         "Cub reduce Array: size not match. Received "
-                        "vec.size() %d != ElementCount %d.",
+                        "vec.size() %d > ElementCount %d.",
                         vec.size(), ElementCount));
   size_t n = static_cast<size_t>(vec.size());
   paddle::framework::Array<T, ElementCount> ret;

@@ -138,6 +142,7 @@ static inline paddle::framework::Array<T, ElementCount> VectorToArray(
 }  // namespace detail

 using Tensor = framework::Tensor;
+constexpr int kMaxRank = framework::DDim::kMaxRank;

 enum ReduceType {
   kReduceAll = 0x00,  // when reduce_rank == x_rank

@@ -146,6 +151,41 @@ enum ReduceType {
   kReduceAny = 0x03,  // when reduce_dim.size() > 1
 };

struct IndexCalculator {
IndexCalculator(int dim, const std::vector<int>& cal_dims,
const std::vector<int>& cal_strides,
const std::vector<int>& full_strides)
: dim(dim) {
dims = detail::VectorToArray<int, kMaxRank>(cal_dims);
strides = detail::VectorToArray<int, kMaxRank>(full_strides);
std::vector<FastDivMod> cal_divmoders;
// fast divmod
for (auto i : cal_strides) {
cal_divmoders.push_back(FastDivMod(i));
}
divmoders = detail::VectorToArray<FastDivMod, kMaxRank>(cal_divmoders);
}
__device__ inline int Get(int offset) const {
int index = 0;
#pragma unroll
for (int i = 0; i < kMaxRank; ++i) {
if (i == dim) {
break;
}
auto divmod = divmoders[i].Divmod(offset);
index += (divmod.val[0] * strides[dims[i]]);
offset = divmod.val[1];
}
return index;
}
int dim;
framework::Array<int, kMaxRank> dims;
framework::Array<int, kMaxRank> strides;
framework::Array<FastDivMod, kMaxRank> divmoders;
};
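
To make the index arithmetic in IndexCalculator::Get concrete, here is a minimal host-side sketch. It uses plain integer division instead of FastDivMod, and the helper name GetOffset and the example shape are illustrative only, not part of the header: a linear offset inside the reduce (or left) sub-space is decoded with cal_strides and re-encoded with the full tensor's strides, which is what Get() computes per thread.

```cpp
#include <cassert>
#include <cstdio>
#include <vector>

// Hypothetical host-side counterpart of IndexCalculator::Get without FastDivMod:
// decode a linear offset in a sub-space with cal_strides, then re-encode the
// resulting coordinates with the full tensor's strides.
static int GetOffset(int offset, const std::vector<int>& cal_dims,
                     const std::vector<int>& cal_strides,
                     const std::vector<int>& full_strides) {
  int index = 0;
  for (size_t i = 0; i < cal_dims.size(); ++i) {
    int q = offset / cal_strides[i];   // coordinate along cal_dims[i]
    offset = offset % cal_strides[i];  // remainder for the next dimension
    index += q * full_strides[cal_dims[i]];
  }
  return index;
}

int main() {
  // x_dim = [2, 3, 4] -> x_strides = [12, 4, 1]; reduce over dims {0, 2}.
  std::vector<int> full_strides = {12, 4, 1};
  std::vector<int> reduce_dims = {0, 2};
  std::vector<int> reduce_strides = {4, 1};  // strides of the 2 x 4 sub-space
  // linear reduce index 5 -> coordinates (1, 1) -> offset 1*12 + 1*1 = 13
  assert(GetOffset(5, reduce_dims, reduce_strides, full_strides) == 13);
  std::printf("offset of reduce index 5: %d\n",
              GetOffset(5, reduce_dims, reduce_strides, full_strides));
  return 0;
}
```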
 // reduce config
 template <typename Ty>
 struct ReduceConfig {

@@ -264,6 +304,9 @@ struct ReduceConfig {
     }
     left_dim.assign(left_set.begin(), left_set.end());
+
+    // if the last dim gets involved in reduction
+    reduce_lastdim = (reduce_dim.back() == x_dim.size() - 1);
   }

   // set x_strides, reduce_strides, left_strides for reduceLastDim and reduceAny

@@ -300,20 +343,76 @@ struct ReduceConfig {
     if (rank == reduce_rank) {
       reduce_type = static_cast<int>(ReduceType::kReduceAll);
     } else if (rank == 2 && reduce_rank == 1 && reduce_dim[0] == 1) {
       reduce_type = static_cast<int>(ReduceType::kReduceLastDim);
     } else if (reduce_rank == 1 &&
                ((rank == 2 && is_large_enough) || rank != 2)) {
       // ReduceFirstDim and reduceSecondDim
       reduce_type = static_cast<int>(ReduceType::kReduceHigherDim);
     } else {
       reduce_type = static_cast<int>(ReduceType::kReduceAny);
     }
   }
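
A small host-side sketch of the branch above may help when reading the kernels that follow. Here is_large_enough is treated as an opaque boolean because its threshold (based on REDUCE_SPLIT_BOUNDARY) lives elsewhere in this header, and the enum and function names are illustrative only.

```cpp
#include <cstdio>
#include <vector>

// Illustrative re-statement of the reduce_type selection shown above.
enum class RType { kAll, kLastDim, kHigherDim, kAny };

static RType Pick(int rank, const std::vector<int>& reduce_dim,
                  bool is_large_enough) {
  int reduce_rank = static_cast<int>(reduce_dim.size());
  if (rank == reduce_rank) return RType::kAll;
  if (rank == 2 && reduce_rank == 1 && reduce_dim[0] == 1) return RType::kLastDim;
  if (reduce_rank == 1 && ((rank == 2 && is_large_enough) || rank != 2))
    return RType::kHigherDim;
  return RType::kAny;
}

int main() {
  std::printf("%d\n", static_cast<int>(Pick(2, {1}, false)));     // kLastDim
  std::printf("%d\n", static_cast<int>(Pick(3, {0}, false)));     // kHigherDim
  std::printf("%d\n", static_cast<int>(Pick(4, {1, 3}, false)));  // kAny
  return 0;
}
```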
void SetBlockDimForReduceAny(dim3* block_dim, dim3* grid_dim) {
constexpr int min_reduce_num_per_thread = 16;
constexpr int max_reduce_num_per_thread = 256;
constexpr int max_num_threads = detail::kMaxThread;
    // set block size.
    // 1. if reduce_lastdim == true, the block is 1-D; no reduction in block y
    //    is needed;
    // 2. if reduce_lastdim == false, the block is 2-D and, when necessary, it
    //    should reduce in block y.
int grid_num, reduce_num_per_thread;
if (reduce_lastdim) {
block_dim->x = detail::GetBlockDim(reduce_num);
block_dim->y = 1;
grid_num = left_num;
reduce_num_per_thread =
detail::AlignUp(reduce_num, block_dim->x * block_dim->y);
} else {
int block_x = detail::GetBlockDim(left_num);
int block_y = detail::GetBlockDim(reduce_num);
block_dim->x = std::min(block_x, 32);
block_dim->y =
std::min(block_y, static_cast<int>(max_num_threads / block_dim->x));
block_dim->x =
std::min(block_x, static_cast<int>(max_num_threads / block_dim->y));
grid_num = detail::AlignUp(left_num, block_dim->x);
reduce_num_per_thread = detail::AlignUp(reduce_num, block_dim->y);
}
int device_id = platform::GetCurrentDeviceId();
int max_mp = platform::GetCUDAMultiProcessors(device_id);
int max_threads_per_mp =
platform::GetCUDAMaxThreadsPerMultiProcessor(device_id);
int max_threads = max_threads_per_mp * max_mp;
int num_threads = block_dim->x * block_dim->y;
int max_num_blocks = max_threads / num_threads;
// set grid size.
    // Whether to set grid.y larger than 1 is decided by the following rules:
    // 1. The number of elements each thread processes should be no less than
    //    min_reduce_num_per_thread, but no more than max_reduce_num_per_thread;
    // 2. It should maximize the utilization of the SMs.
    // So we choose the minimum of input_split_num_1 and input_split_num_3 to
    // make each thread process as much data as possible. Meanwhile, that number
    // cannot be larger than max_reduce_num_per_thread, so we take the maximum
    // of the result above and input_split_num_2.
int input_split_num_1 =
detail::AlignUp(reduce_num_per_thread, min_reduce_num_per_thread);
int input_split_num_2 =
detail::AlignUp(reduce_num_per_thread, max_reduce_num_per_thread);
int input_split_num_3 = detail::AlignUp(max_num_blocks, grid_num);
grid_dim->x = grid_num;
grid_dim->y = std::max(std::min(input_split_num_1, input_split_num_3),
input_split_num_2);
// if grid.y > 1, we need launch reduce kernel again.
if (grid_dim->y > 1) {
should_reduce_again = true;
}
}
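
As a sanity check of the grid.y rule just described, the following host-side sketch reproduces the max(min(split1, split3), split2) computation, assuming AlignUp is the ceil-division defined earlier in this header. The numbers in main are made up for illustration.

```cpp
#include <algorithm>
#include <cstdio>

// Ceil-division, as detail::AlignUp above.
static int AlignUp(int a, int b) { return (a + b - 1) / b; }

// Illustrative stand-in for the grid.y computation in SetBlockDimForReduceAny.
static int GridY(int reduce_num_per_thread, int grid_num, int max_num_blocks,
                 int min_per_thread = 16, int max_per_thread = 256) {
  int split1 = AlignUp(reduce_num_per_thread, min_per_thread);  // keep >= min work per thread
  int split2 = AlignUp(reduce_num_per_thread, max_per_thread);  // keep <= max work per thread
  int split3 = AlignUp(max_num_blocks, grid_num);               // fill the SMs
  return std::max(std::min(split1, split3), split2);
}

int main() {
  // Each thread would otherwise own 4096 elements, 8 blocks already launched
  // along x, and roughly 128 resident blocks fit on the device:
  // split1 = 256, split2 = 16, split3 = 16 -> grid.y = max(min(256, 16), 16) = 16
  std::printf("grid.y = %d\n", GridY(4096, 8, 128));
  return 0;
}
```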
   // set block and grid for launch kernel
   // for ReduceHigherDim: if block is enough -> split reduce_num
   //                      else init block(32, 1) grid(block_num, 1)

@@ -368,6 +467,8 @@ struct ReduceConfig {
         grid_dim.x = (left_num + block_dim.x - 1) / block_dim.x;
         grid_dim.y = 1;
       }
+    } else if (reduce_type == ReduceType::kReduceAny) {
+      SetBlockDimForReduceAny(&block_dim, &grid_dim);
     }
     block = block_dim;

@@ -388,6 +489,7 @@ struct ReduceConfig {
   int left_num;
   int blocking_size;
   bool should_reduce_again;
+  bool reduce_lastdim;

   Ty* output_data;

@@ -395,8 +497,12 @@ struct ReduceConfig {
   dim3 grid;
 };
static __device__ int SharedMemoryIndex(int index) {
return (threadIdx.y + index) * blockDim.x + threadIdx.x;
}
 template <typename T, typename ReduceOp>
-__device__ __forceinline__ T WarpReduce(T val, ReduceOp reducer) {
+static __device__ T WarpReduce(T val, ReduceOp reducer) {
   unsigned mask = 0u;
   CREATE_SHFL_MASK(mask, true);
   for (int stride = detail::kWarpSize / 2; stride > 0; stride >>= 1) {

@@ -416,7 +522,7 @@ __device__ __forceinline__ T WarpReduce(T val, ReduceOp reducer) {
  * res to warp0 and process the second WarpReduce
  */
 template <typename T, typename ReduceOp>
-__device__ __forceinline__ T BlockReduce(T val, ReduceOp reducer) {
+static __device__ T BlockXReduce(T val, ReduceOp reducer) {
   using detail::kWarpSize;
   __shared__ T shared[kWarpSize];
   int block_dim_x = blockDim.x;

@@ -441,12 +547,26 @@ __device__ __forceinline__ T BlockReduce(T val, ReduceOp reducer) {
   return val;
 }
template <typename T, typename ReduceOp>
static __device__ T BlockYReduce(T val, ReduceOp reducer) {
__shared__ T shared_memory[detail::kMaxThread];
shared_memory[SharedMemoryIndex(0)] = val;
for (int stride = blockDim.y / 2; stride > 0; stride >>= 1) {
__syncthreads();
if (threadIdx.y < stride && threadIdx.y + stride < blockDim.y) {
T temp = shared_memory[SharedMemoryIndex(stride)];
val = reducer(val, temp);
}
shared_memory[SharedMemoryIndex(0)] = val;
}
return val;
}
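
The loop above is the usual stride-halving tree reduction in shared memory along threadIdx.y. A host-side simulation of the same pattern for one x lane is below; it assumes a power-of-two number of y "threads", which the block-size helpers in this header appear to guarantee, and the names are illustrative.

```cpp
#include <cstdio>
#include <vector>

// Host-side simulation of a stride-halving tree reduction: after log2(n)
// rounds, index 0 holds the sum of all partials (n assumed to be a power of 2).
static float TreeReduceY(std::vector<float> partials) {
  int n = static_cast<int>(partials.size());  // plays the role of blockDim.y
  for (int stride = n / 2; stride > 0; stride >>= 1) {
    for (int y = 0; y < stride; ++y) {
      if (y + stride < n) partials[y] += partials[y + stride];
    }
  }
  return partials[0];
}

int main() {
  std::printf("%g\n", TreeReduceY({1, 2, 3, 4, 5, 6, 7, 8}));  // prints 36
  return 0;
}
```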
 // when reduce_dim.size() == 1 and reduce_dim[0] == x_dim.size() - 1, this
 // function will be used
 // blockId.x -> left_num, threadId.x -> reduce_num
 template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp>
-__device__ __forceinline__ void ReduceLastDim(const Tx* x, Ty* y,
-                                              ReduceOp reducer,
-                                              TransformOp transformer, Ty init,
-                                              int reduce_num) {
+__device__ void ReduceLastDim(const Tx* x, Ty* y, ReduceOp reducer,
+                              TransformOp transformer, Ty init,
+                              int reduce_num) {
   int idx_x = blockIdx.x * reduce_num;

@@ -458,7 +578,7 @@ __device__ __forceinline__ void ReduceLastDim(const Tx* x, Ty* y,
   }
   __syncthreads();

-  reduce_var = BlockReduce(reduce_var, reducer);
+  reduce_var = BlockXReduce(reduce_var, reducer);

   if (threadIdx.x == 0) {
     y[blockIdx.x] = reduce_var;

@@ -471,11 +591,9 @@ __device__ __forceinline__ void ReduceLastDim(const Tx* x, Ty* y,
 // if axis = 1 then grid.z = nz, grid.y = ny / block_size, grid.x = nx / 32
 // else grid.z = 1, grid.y = ny / block_size, grid.x = nx / 32
 template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp>
-__device__ __forceinline__ void ReduceHigherDim(const Tx* x, Ty* y,
-                                                ReduceOp reducer,
-                                                TransformOp transformer,
-                                                Ty init, int reduce_num,
-                                                int left_num, int block_size) {
+__device__ void ReduceHigherDim(const Tx* x, Ty* y, ReduceOp reducer,
+                                TransformOp transformer, Ty init,
+                                int reduce_num, int left_num, int block_size) {
   int idx = blockIdx.x * blockDim.x + threadIdx.x;
   int idy = blockIdx.y * block_size;

@@ -497,71 +615,97 @@ __device__ __forceinline__ void ReduceHigherDim(const Tx* x, Ty* y,
 // when reduce_dim.size() != 1 and reduce_dim.size() != x_dim.size(), this
 // function will be used
-// blockId.x -> left_num, threadId.x -> reduce_num
-template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp,
-          int Rank, int ReduceRank>
-__device__ __forceinline__ void ReduceAny(
-    const Tx* x, Ty* y, ReduceOp reducer, TransformOp transformer,
-    int reduce_num, paddle::framework::Array<int, Rank> x_strides,
-    paddle::framework::Array<int, ReduceRank> reduce_dim,
-    paddle::framework::Array<int, ReduceRank> reduce_strides,
-    paddle::framework::Array<int, Rank - ReduceRank> left_dim,
-    paddle::framework::Array<int, Rank - ReduceRank> left_strides) {
-  int sub_index[Rank];
-  int left_idx = blockIdx.x;
-  for (int i = 0; i < Rank - ReduceRank; ++i) {
-    sub_index[left_dim[i]] = left_idx / left_strides[i];
-    left_idx %= left_strides[i];
-  }
-
-  int reduce_idx = threadIdx.x;
-  for (int j = 0; j < ReduceRank; ++j) {
-    sub_index[reduce_dim[j]] = reduce_idx / reduce_strides[j];
-    reduce_idx %= reduce_strides[j];
-  }
-
-  int idx_x = 0;
-  for (int k = 0; k < Rank; ++k) {
-    idx_x += (sub_index[k] * x_strides[k]);
-  }
-  Ty reduce_var = static_cast<Ty>(transformer(x[idx_x]));
-
-  for (int i = threadIdx.x + blockDim.x; i < reduce_num; i += blockDim.x) {
-    int reduce_idx = i;
-    for (int j = 0; j < ReduceRank; ++j) {
-      sub_index[reduce_dim[j]] = reduce_idx / reduce_strides[j];
-      reduce_idx %= reduce_strides[j];
-    }
-
-    int idx_x = 0;
-    for (int k = 0; k < Rank; ++k) {
-      idx_x += (sub_index[k] * x_strides[k]);
-    }
-    reduce_var = static_cast<Ty>(
-        reducer(reduce_var, static_cast<Ty>(transformer(x[idx_x]))));
-  }
-  __syncthreads();
-
-  reduce_var = BlockReduce(reduce_var, reducer);
-
-  if (threadIdx.x == 0) {
-    y[blockIdx.x] = reduce_var;
-  }
-}
+template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp>
+__device__ void ReduceAny(const Tx* x, Ty* y, ReduceOp reducer,
+                          TransformOp transformer, Ty init, int reduce_num,
+                          int left_num, bool reduce_lastdim,
+                          const IndexCalculator& reduce_index_calculator,
+                          const IndexCalculator& left_index_calculator) {
+  int input_idx, left_idx, stride;
+  // the last dim gets involved in reduction
+  if (reduce_lastdim) {
+    input_idx = blockIdx.y * blockDim.x + threadIdx.x;
+    left_idx = blockIdx.x;
+    stride = gridDim.y * blockDim.x;
+  } else {
+    input_idx = blockIdx.y * blockDim.y + threadIdx.y;
+    left_idx = blockIdx.x * blockDim.x + threadIdx.x;
+    stride = gridDim.y * blockDim.y;
+  }
+  // calculate the offset, i.e. the address where each thread really starts
+  int input_offset = left_index_calculator.Get(left_idx);
+  const Tx* input = x + input_offset;
+  Ty reduce_var = init;
+
+  // 1. reduce for each thread
+  if (left_idx < left_num) {
+    // load REDUCE_VEC_SIZE data once, and then compute
+    Tx input_reg[REDUCE_VEC_SIZE];
+    int bound = reduce_num - (REDUCE_VEC_SIZE - 1) * stride;
+    while (input_idx < bound) {
+#pragma unroll
+      for (int i = 0; i < REDUCE_VEC_SIZE; ++i) {
+        int reduce_idx = input_idx + i * stride;
+        int idx_x = reduce_index_calculator.Get(reduce_idx);
+        input_reg[i] = input[idx_x];
+      }
+#pragma unroll
+      for (int i = 0; i < REDUCE_VEC_SIZE; ++i) {
+        reduce_var = reducer(reduce_var, transformer(input_reg[i]));
+      }
+      input_idx += REDUCE_VEC_SIZE * stride;
+    }
+
+    // deal with the remaining part
+    int input_idx_tmp = input_idx;
+#pragma unroll
+    for (int i = 0; i < REDUCE_VEC_SIZE; ++i) {
+      if (input_idx >= reduce_num) {
+        break;
+      }
+      int reduce_idx = input_idx;
+      int idx_x = reduce_index_calculator.Get(reduce_idx);
+      input_reg[i] = input[idx_x];
+      input_idx += stride;
+    }
+    input_idx = input_idx_tmp;
+#pragma unroll
+    for (int i = 0; i < REDUCE_VEC_SIZE; ++i) {
+      if (input_idx >= reduce_num) {
+        break;
+      }
+      reduce_var = reducer(reduce_var, transformer(input_reg[i]));
+      input_idx += stride;
+    }
+  }
+
+  // 2. reduce in block y
+  if (blockDim.y > 1) {
+    reduce_var = BlockYReduce(reduce_var, reducer);
+  }
+  __syncthreads();
+
+  if (reduce_lastdim) {
+    // 3. reduce in block x
+    reduce_var = BlockXReduce(reduce_var, reducer);
+    if (threadIdx.x == 0) {
+      y[blockIdx.x + blockIdx.y * gridDim.x] = reduce_var;
+    }
+  } else {
+    if (left_idx < left_num && threadIdx.y == 0) {
+      y[blockIdx.y * left_num + left_idx] = reduce_var;
+    }
+  }
+}
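
Per thread, the traversal above is a vectorized main loop that consumes REDUCE_VEC_SIZE strided elements per iteration while input_idx stays below bound, followed by a remainder loop for whatever is left. The host-side, index-only sketch below (names are illustrative, loads and the transformer are omitted) checks that cooperating threads with stride equal to the thread count cover every element exactly once.

```cpp
#include <cassert>
#include <cstdio>

constexpr int kVecSize = 4;  // stands in for REDUCE_VEC_SIZE

// Count how many reduce elements one thread visits, mirroring the main loop
// plus remainder loop structure of ReduceAny (indices only).
static int CountVisited(int start, int stride, int reduce_num) {
  int visited = 0;
  int input_idx = start;
  int bound = reduce_num - (kVecSize - 1) * stride;
  while (input_idx < bound) {  // full groups of kVecSize strided elements
    visited += kVecSize;
    input_idx += kVecSize * stride;
  }
  for (int i = 0; i < kVecSize && input_idx < reduce_num; ++i) {  // remainder
    ++visited;
    input_idx += stride;
  }
  return visited;
}

int main() {
  // 3 cooperating "threads" (stride 3) over reduce_num = 26 visit all elements.
  int total = 0;
  for (int t = 0; t < 3; ++t) total += CountVisited(t, 3, 26);
  assert(total == 26);
  std::printf("visited %d of 26 elements\n", total);
  return 0;
}
```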
 // module function designed for global function
-template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp,
-          int Rank, int ReduceRank>
-__device__ __forceinline__ void ReduceModule(
-    const Tx* x, Ty* y, ReduceOp reducer, TransformOp transformer, Ty init,
-    int reduce_num, int left_num, int blocking_size, int reduce_type,
-    paddle::framework::Array<int, Rank> x_strides,
-    paddle::framework::Array<int, ReduceRank> reduce_dim,
-    paddle::framework::Array<int, ReduceRank> reduce_strides,
-    paddle::framework::Array<int, Rank - ReduceRank> left_dim,
-    paddle::framework::Array<int, Rank - ReduceRank> left_strides) {
+template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp>
+__device__ void ReduceModule(const Tx* x, Ty* y, ReduceOp reducer,
+                             TransformOp transformer, Ty init, int reduce_num,
+                             int left_num, int blocking_size, int reduce_type,
+                             bool reduce_lastdim,
+                             const IndexCalculator& reduce_index_calculator,
+                             const IndexCalculator& left_index_calculator) {
   // reduce_rank == 1 && reduce_dim[0] == x_dim.size() - 1
   if (reduce_type == ReduceType::kReduceLastDim) {
     ReduceLastDim<Tx, Ty, ReduceOp, TransformOp>(x, y, reducer, transformer,
                                                  init, reduce_num);

@@ -573,104 +717,66 @@ __device__ __forceinline__ void ReduceModule(
     // reduce_rank >= 2
   } else {
-    ReduceAny<Tx, Ty, ReduceOp, TransformOp, Rank, ReduceRank>(
-        x, y, reducer, transformer, reduce_num, x_strides, reduce_dim,
-        reduce_strides, left_dim, left_strides);
+    ReduceAny<Tx, Ty, ReduceOp, TransformOp>(
+        x, y, reducer, transformer, init, reduce_num, left_num, reduce_lastdim,
+        reduce_index_calculator, left_index_calculator);
   }
 }
-template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp,
-          int Rank, int ReduceRank>
-__global__ void ReduceKernelFunction(
-    const Tx* x, Ty* y, ReduceOp reducer, TransformOp transformer, Ty init,
-    int reduce_num, int left_num, int block_size, int reduce_type,
-    paddle::framework::Array<int, Rank> x_strides,
-    paddle::framework::Array<int, ReduceRank> reduce_dim,
-    paddle::framework::Array<int, ReduceRank> reduce_strides,
-    paddle::framework::Array<int, Rank - ReduceRank> left_dim,
-    paddle::framework::Array<int, Rank - ReduceRank> left_strides) {
-  ReduceModule<Tx, Ty, ReduceOp, TransformOp, Rank, ReduceRank>(
-      x, y, reducer, transformer, init, reduce_num, left_num, block_size,
-      reduce_type, x_strides, reduce_dim, reduce_strides, left_dim,
-      left_strides);
-}
+template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp>
+__global__ void ReduceKernelFunction(const Tx* x, Ty* y, ReduceOp reducer,
+                                     TransformOp transformer, Ty init,
+                                     int reduce_num, int left_num,
+                                     int blocking_size, int reduce_type,
+                                     bool reduce_lastdim,
+                                     IndexCalculator reduce_index_calculator,
+                                     IndexCalculator left_index_calculator) {
+  ReduceModule<Tx, Ty, ReduceOp, TransformOp>(
+      x, y, reducer, transformer, init, reduce_num, left_num, blocking_size,
+      reduce_type, reduce_lastdim, reduce_index_calculator,
+      left_index_calculator);
+}
-template <typename Tx, typename Ty, typename ReduceOp, int Rank, int ReduceRank>
+template <typename Tx, typename Ty, typename ReduceOp>
 static void LaunchReduceKernel(const Tx* x_data, Ty* y_data,
                                const ReduceOp& reducer, Ty init,
                                gpuStream_t stream, ReduceConfig<Ty> config) {
   using TransformOp = typename ReduceOp::Transformer;

-  ReduceKernelFunction<Tx, Ty, ReduceOp, TransformOp, Rank,
-                       ReduceRank><<<config.grid, config.block, 0, stream>>>(
+  int reduce_rank = config.reduce_strides.size();
+  int left_rank = config.left_strides.size();
+  auto reduce_index_calculator = IndexCalculator(
+      reduce_rank, config.reduce_dim, config.reduce_strides, config.x_strides);
+  auto left_index_calculator = IndexCalculator(
+      left_rank, config.left_dim, config.left_strides, config.x_strides);
+  ReduceKernelFunction<Tx, Ty, ReduceOp,
+                       TransformOp><<<config.grid, config.block, 0, stream>>>(
       x_data, config.output_data, reducer, TransformOp(config.reduce_num), init,
       config.reduce_num, config.left_num, config.blocking_size,
-      config.reduce_type, detail::VectorToArray<int, Rank>(config.x_strides),
-      detail::VectorToArray<int, ReduceRank>(config.reduce_dim),
-      detail::VectorToArray<int, ReduceRank>(config.reduce_strides),
-      detail::VectorToArray<int, Rank - ReduceRank>(config.left_dim),
-      detail::VectorToArray<int, Rank - ReduceRank>(config.left_strides));
+      config.reduce_type, config.reduce_lastdim, reduce_index_calculator,
+      left_index_calculator);

   if (config.should_reduce_again) {
-    dim3 block(config.block.x, 1, 1);
-    dim3 grid(config.grid.x, 1, config.grid.z);
+    dim3 block;
+    dim3 grid;
+    if (config.reduce_lastdim) {
+      block = dim3(32, 1, 1);
+      grid = dim3(detail::AlignUp(config.left_num, 32), 1, 1);
+    } else {
+      block = dim3(config.block.x, 1, 1);
+      grid = dim3(config.grid.x, 1, config.grid.z);
+    }

-    ReduceKernelFunction<Ty, Ty, ReduceOp, detail::IdentityFunctor<Ty>, Rank,
-                         ReduceRank><<<grid, block, 0, stream>>>(
+    ReduceKernelFunction<Ty, Ty, ReduceOp, detail::IdentityFunctor<
+                             Ty>><<<grid, block, 0, stream>>>(
         config.output_data, y_data, reducer,
         detail::IdentityFunctor<Ty>(config.grid.y), init, config.grid.y,
         config.left_num, config.grid.y, ReduceType::kReduceHigherDim,
-        detail::VectorToArray<int, Rank>(config.x_strides),
-        detail::VectorToArray<int, ReduceRank>(config.reduce_dim),
-        detail::VectorToArray<int, ReduceRank>(config.reduce_strides),
-        detail::VectorToArray<int, Rank - ReduceRank>(config.left_dim),
-        detail::VectorToArray<int, Rank - ReduceRank>(config.left_strides));
+        config.reduce_lastdim, reduce_index_calculator, left_index_calculator);
   }
 }
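
When should_reduce_again is set, the launch above becomes a two-pass scheme: the first kernel writes grid.y partial results per output element, laid out as [grid.y, left_num] to match the y[blockIdx.y * left_num + left_idx] store in ReduceAny, and the second kernel reduces over grid.y. The host-side sketch below uses made-up sizes and an illustrative per-slice stride; it only demonstrates the layout and the second reduction.

```cpp
#include <cstdio>
#include <vector>

int main() {
  const int left_num = 3, grid_y = 4, reduce_num = 8;
  // Pass 1: each grid.y slice accumulates its strided share of reduce_num
  // elements (all ones here), stored as partial[grid.y][left_num].
  std::vector<float> partial(grid_y * left_num, 0.f);
  for (int gy = 0; gy < grid_y; ++gy) {
    for (int l = 0; l < left_num; ++l) {
      for (int r = gy; r < reduce_num; r += grid_y) partial[gy * left_num + l] += 1.f;
    }
  }
  // Pass 2: reduce the grid.y partials for each output element.
  std::vector<float> out(left_num, 0.f);
  for (int l = 0; l < left_num; ++l) {
    for (int gy = 0; gy < grid_y; ++gy) out[l] += partial[gy * left_num + l];
  }
  std::printf("out[0] = %g (expected 8)\n", out[0]);
  return 0;
}
```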
-template <typename Tx, typename Ty, typename ReduceOp>
-static void ReduceKernelImpl(const Tx* x_data, Ty* y_data,
-                             const ReduceOp& reducer, Ty init,
-                             gpuStream_t stream, ReduceConfig<Ty> config) {
-  int reduce_rank = config.reduce_strides.size();
-  int rank = config.x_strides.size();
-
-#define CUB_RANK_CASE(i, ...)             \
-  case i: {                               \
-    constexpr auto Rank = i;              \
-    switch (reduce_rank) { __VA_ARGS__; } \
-  } break
-
-#define CUB_REDUCE_RANK_CASE(i, ...)                        \
-  case i: {                                                 \
-    constexpr auto ReduceRank = i;                          \
-    LaunchReduceKernel<Tx, Ty, ReduceOp, Rank, ReduceRank>( \
-        x_data, y_data, reducer, init, stream, config);     \
-  } break
-
-  detail::CheckReduceRank(reduce_rank, rank);
-  switch (rank) {
-    CUB_RANK_CASE(2, CUB_REDUCE_RANK_CASE(1););
-    CUB_RANK_CASE(3, CUB_REDUCE_RANK_CASE(1); CUB_REDUCE_RANK_CASE(2););
-    CUB_RANK_CASE(4, CUB_REDUCE_RANK_CASE(2););
-    CUB_RANK_CASE(5, CUB_REDUCE_RANK_CASE(2); CUB_REDUCE_RANK_CASE(3););
-    CUB_RANK_CASE(6, CUB_REDUCE_RANK_CASE(3););
-    CUB_RANK_CASE(7, CUB_REDUCE_RANK_CASE(3); CUB_REDUCE_RANK_CASE(4););
-    CUB_RANK_CASE(8, CUB_REDUCE_RANK_CASE(4););
-    CUB_RANK_CASE(9, CUB_REDUCE_RANK_CASE(4); CUB_REDUCE_RANK_CASE(5););
-  }
-
-#undef CUB_REDUCE_RANK_CASE
-#undef CUB_RANK_CASE
-}
 template <typename Tx, typename Ty,
           template <typename, typename> class ReduceOp>
 void TensorReduceFunctorImpl(const framework::Tensor& x, framework::Tensor* y,

@@ -718,8 +824,8 @@ void TensorReduceFunctorImpl(const framework::Tensor& x, framework::Tensor* y,
     return;
   }

-  ReduceKernelImpl<Tx, Ty, ReduceOp<Tx, Ty>>(x_data, y_data, reducer,
-                                             reducer.initial(), stream, config);
+  LaunchReduceKernel<Tx, Ty, ReduceOp<Tx, Ty>>(
+      x_data, y_data, reducer, reducer.initial(), stream, config);
 }
 template <typename Tx, template <typename, typename> class ReduceOp>
 ...

@@ -54,7 +54,7 @@ struct FastDivMod {
     return (t + n) >> shift_val;
   }

-  __device__ __forceinline__ DivModT Divmod(uint32_t n) {
+  __device__ __forceinline__ DivModT Divmod(uint32_t n) const {
     uint32_t q = Div(n);
     DivModT result = {q, n - q * divisor};
     return result;
 ...
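
The Div form visible above, (t + n) >> shift_val with t the high 32 bits of n * multiplier, is the usual round-up magic-number division. The constructor formula in the sketch below is an assumption for illustration rather than a copy of fast_divmod.h; the check loop verifies quotient and remainder against ordinary integer division on the host.

```cpp
#include <cassert>
#include <cstdint>
#include <cstdio>

// Host-side sketch of a magic-number divider with the same Div shape as above.
struct HostFastDivMod {
  explicit HostFastDivMod(uint32_t d) : divisor(d) {
    for (shift_val = 0; shift_val < 32; ++shift_val) {
      if ((1u << shift_val) >= d) break;  // smallest power of two >= divisor
    }
    uint64_t one = 1;
    // Assumed round-up magic multiplier for this Div form.
    multiplier = static_cast<uint32_t>(
        ((one << 32) * ((one << shift_val) - d)) / d + 1);
  }
  uint32_t Div(uint32_t n) const {
    uint32_t t = static_cast<uint32_t>(
        (static_cast<uint64_t>(n) * multiplier) >> 32);  // __umulhi on device
    return (t + n) >> shift_val;
  }
  uint32_t divisor, multiplier, shift_val;
};

int main() {
  for (uint32_t d : {1u, 3u, 7u, 640u}) {
    HostFastDivMod dm(d);
    for (uint32_t n = 0; n < 10000; ++n) {
      assert(dm.Div(n) == n / d);          // quotient
      assert(n - dm.Div(n) * d == n % d);  // remainder, as Divmod computes it
    }
  }
  std::printf("magic division matches integer division\n");
  return 0;
}
```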