Unverified  Commit c7cc5ac2 authored by Zhang Zheng, committed by GitHub

Unify the block/grid strategy and implementation of ReduceLastDim and ReduceAny (#34436)

Parent 80f7f7ea
@@ -360,19 +360,26 @@ struct ReduceConfig {
     constexpr int max_num_threads = detail::kMaxThread;
 
     // set block size.
-    // 1. if reduce_lastdim == true, block is 1-D, no need reduction in block y;
-    // 2. if reduce_lastdim == false, block is 2-D, if it is necessary,
-    //    it should reduce in block y.
+    // 1. If reduce_lastdim == true, all threads with the same threadIdx.y
+    //    process the reduction for one output.
+    //    The number of outputs per block is blockDim.y;
+    // 2. If reduce_lastdim == false, threads with different threadIdx.x
+    //    process different reductions and produce their outputs separately;
+    //    if necessary, they also reduce in block y.
+    //    The number of outputs per block is blockDim.x;
+    int block_x, block_y;
     int grid_num, reduce_num_per_thread;
     if (reduce_lastdim) {
-      block_dim->x = detail::GetBlockDim(reduce_num);
-      block_dim->y = 1;
-      grid_num = left_num;
-      reduce_num_per_thread =
-          detail::AlignUp(reduce_num, block_dim->x * block_dim->y);
+      block_x = detail::GetBlockDim(reduce_num);
+      block_y = detail::GetBlockDim(left_num);
+      block_dim->x = block_x;
+      block_dim->y =
+          std::min(block_y, static_cast<int>(max_num_threads / block_dim->x));
+      grid_num = detail::AlignUp(left_num, block_dim->y);
+      reduce_num_per_thread = detail::AlignUp(reduce_num, block_dim->x);
     } else {
-      int block_x = detail::GetBlockDim(left_num);
-      int block_y = detail::GetBlockDim(reduce_num);
+      block_x = detail::GetBlockDim(left_num);
+      block_y = detail::GetBlockDim(reduce_num);
       block_dim->x = std::min(block_x, 32);
       block_dim->y =
           std::min(block_y, static_cast<int>(max_num_threads / block_dim->x));
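For context on the arithmetic in the reduce_lastdim branch: detail::GetBlockDim and detail::AlignUp are not shown in this hunk, so the sketch below substitutes plausible stand-ins (largest power of two capped at an assumed kMaxThread of 128, and ceiling division) purely to make the sizing concrete; the real helpers in reduce_op.cu.h may differ.

#include <algorithm>
#include <cstdio>

constexpr int kMaxThread = 128;  // assumed value of detail::kMaxThread

static int AlignUp(int a, int b) { return (a + b - 1) / b; }  // ceil(a / b)

static int GetBlockDim(int n) {  // assumed: largest power of 2 <= min(n, cap)
  int p = 1;
  while (p * 2 <= std::min(n, kMaxThread)) p *= 2;
  return p;
}

int main() {
  // Reduce the last dim of a [left_num, reduce_num] = [1000, 50] tensor,
  // following the reduce_lastdim branch above.
  int reduce_num = 50, left_num = 1000;
  int block_x = GetBlockDim(reduce_num);           // 32 threads per output
  int block_y = std::min(GetBlockDim(left_num),
                         kMaxThread / block_x);    // 4 outputs per block
  int grid_num = AlignUp(left_num, block_y);       // 250 blocks
  int per_thread = AlignUp(reduce_num, block_x);   // 2 elements per thread
  std::printf("block=(%d,%d) grid=%d per_thread=%d\n", block_x, block_y,
              grid_num, per_thread);
  return 0;
}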
@@ -467,7 +474,7 @@ struct ReduceConfig {
       grid_dim.x = (left_num + block_dim.x - 1) / block_dim.x;
       grid_dim.y = 1;
     }
-  } else if (reduce_type == ReduceType::kReduceAny) {
+  } else {
     SetBlockDimForReduceAny(&block_dim, &grid_dim);
   }
@@ -524,18 +531,20 @@ static __device__ T WarpReduce(T val, ReduceOp reducer) {
 template <typename T, typename ReduceOp>
 static __device__ T BlockXReduce(T val, ReduceOp reducer) {
   using detail::kWarpSize;
-  __shared__ T shared[kWarpSize];
+  __shared__ T shared[2 * kWarpSize];
   int block_dim_x = blockDim.x;
   if (blockDim.x > kWarpSize) {
     block_dim_x = blockDim.x / kWarpSize;
     int lane = threadIdx.x % kWarpSize;
-    int wid = threadIdx.x / kWarpSize;
+    int tid = threadIdx.y * blockDim.x + threadIdx.x;
+    int wid = tid / kWarpSize;
+    int bid = threadIdx.y;
     val = WarpReduce(val, reducer);
     if (lane == 0) {
       shared[wid] = val;
     }
     __syncthreads();
-    val = shared[bid * block_dim_x + lane];
+    val = shared[bid * block_dim_x + lane];
   }
   unsigned mask = 0u;
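BlockXReduce is now row-aware: each threadIdx.y row reduces across its warps via WarpReduce, lane 0 of every warp parks its partial result in shared memory, and the row reads back only its own slice at shared[bid * block_dim_x + lane]. The buffer grows to 2 * kWarpSize so there is one slot per warp in the whole 2-D block, not just per warp in one row. The warp-level step it builds on is not shown in this hunk; a minimal sketch, assuming kWarpSize == 32 and the usual shuffle-based tree reduction (the file's actual WarpReduce may differ, e.g. in how it computes the active mask):

template <typename T, typename ReduceOp>
static __device__ T WarpReduceSketch(T val, ReduceOp reducer) {
  constexpr unsigned kFullMask = 0xffffffffu;
#pragma unroll
  for (int offset = 16; offset > 0; offset >>= 1) {
    // Each step folds the upper half of the active lanes onto the lower half.
    val = reducer(val, __shfl_down_sync(kFullMask, val, offset));
  }
  return val;  // lane 0 ends up holding the reduction over the whole warp
}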
@@ -562,29 +571,6 @@ static __device__ T BlockYReduce(T val, ReduceOp reducer) {
   return val;
 }
 
-// when reduce_dim.size() == 1 and reduce_dim[0] == x_dim.size() - 1, this
-// function will be used
-// blockId.x -> left_num, threadId.x -> reduce_num
-template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp>
-__device__ void ReduceLastDim(const Tx* x, Ty* y, ReduceOp reducer,
-                              TransformOp transformer, Ty init,
-                              int reduce_num) {
-  int idx_x = blockIdx.x * reduce_num;
-  int idx_y = threadIdx.x;
-  Ty reduce_var = init;
-  for (int idx_y = threadIdx.x; idx_y < reduce_num; idx_y += blockDim.x) {
-    reduce_var =
-        reducer(reduce_var, static_cast<Ty>(transformer(x[idx_x + idx_y])));
-  }
-  __syncthreads();
-  reduce_var = BlockXReduce(reduce_var, reducer);
-  if (threadIdx.x == 0) {
-    y[blockIdx.x] = reduce_var;
-  }
-}
-
 // when reduce_dim.size() == 1 and reduce_dim[0] != x_dim.size() - 1, this
 // function will be used
 // eg: x_dim = {nz, ny, nx}, nx != 1, axis can be 0 or 1
@@ -613,19 +599,21 @@ __device__ void ReduceHigherDim(const Tx* x, Ty* y, ReduceOp reducer,
   }
 }
 
+// when reduce_dim.size() == 1 and reduce_dim[0] == x_dim.size() - 1, or
 // when reduce_dim.size() != 1 and reduce_dim.size() != x_dim.size(), this
 // function will be used
-template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp>
+template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp,
+          typename ReduceIndexCal, typename LeftIndexCal>
 __device__ void ReduceAny(const Tx* x, Ty* y, ReduceOp reducer,
                           TransformOp transformer, Ty init, int reduce_num,
                           int left_num, bool reduce_lastdim,
-                          const IndexCalculator& reduce_index_calculator,
-                          const IndexCalculator& left_index_calculator) {
+                          ReduceIndexCal reduce_index_calculator,
+                          LeftIndexCal left_index_calculator) {
   int input_idx, left_idx, stride;
   // the last dim gets involved in reduction
   if (reduce_lastdim) {
     input_idx = blockIdx.y * blockDim.x + threadIdx.x;
-    left_idx = blockIdx.x;
+    left_idx = blockIdx.x * blockDim.y + threadIdx.y;
     stride = gridDim.y * blockDim.x;
   } else {
     input_idx = blockIdx.y * blockDim.y + threadIdx.y;
@@ -633,7 +621,7 @@ __device__ void ReduceAny(const Tx* x, Ty* y, ReduceOp reducer,
     stride = gridDim.y * blockDim.y;
   }
   // calculate the offset, i.e. the address where each thread really starts
-  int input_offset = left_index_calculator.Get(left_idx);
+  int input_offset = left_index_calculator(left_idx);
   const Tx* input = x + input_offset;
   Ty reduce_var = init;
@@ -646,7 +634,7 @@ __device__ void ReduceAny(const Tx* x, Ty* y, ReduceOp reducer,
 #pragma unroll
     for (int i = 0; i < REDUCE_VEC_SIZE; ++i) {
       int reduce_idx = input_idx + i * stride;
-      int idx_x = reduce_index_calculator.Get(reduce_idx);
+      int idx_x = reduce_index_calculator(reduce_idx);
       input_reg[i] = input[idx_x];
     }
 #pragma unroll
@@ -664,7 +652,7 @@ __device__ void ReduceAny(const Tx* x, Ty* y, ReduceOp reducer,
         break;
       }
       int reduce_idx = input_idx;
-      int idx_x = reduce_index_calculator.Get(reduce_idx);
+      int idx_x = reduce_index_calculator(reduce_idx);
       input_reg[i] = input[idx_x];
       input_idx += stride;
     }
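The two hunks above are the halves of one loading pattern: full REDUCE_VEC_SIZE tiles are read with a bounds-check-free unrolled loop so the loads can be issued back to back before any is consumed, and the leftover elements go through a checked tail. A standalone sketch of that pattern, with illustrative names (kVecSize is a stand-in, not the kernel's actual REDUCE_VEC_SIZE machinery):

template <typename T, typename ReduceOp, int kVecSize = 4>
__device__ T StridedTileReduce(const T* input, int n, int start, int stride,
                               T init, ReduceOp reducer) {
  T val = init;
  int idx = start;
  // 1. full tiles: the trip count is fixed, so loads are unrolled and
  //    gathered into registers first, then combined
  while (idx + (kVecSize - 1) * stride < n) {
    T reg[kVecSize];
#pragma unroll
    for (int i = 0; i < kVecSize; ++i) reg[i] = input[idx + i * stride];
#pragma unroll
    for (int i = 0; i < kVecSize; ++i) val = reducer(val, reg[i]);
    idx += kVecSize * stride;
  }
  // 2. bounds-checked tail for whatever is left over
  for (; idx < n; idx += stride) val = reducer(val, input[idx]);
  return val;
}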
@@ -680,7 +668,7 @@ __device__ void ReduceAny(const Tx* x, Ty* y, ReduceOp reducer,
   }
 
   // 2. reduce in block y
-  if (blockDim.y > 1) {
+  if (!reduce_lastdim && blockDim.y > 1) {
     reduce_var = BlockYReduce(reduce_var, reducer);
   }
   __syncthreads();
@@ -688,8 +676,8 @@ __device__ void ReduceAny(const Tx* x, Ty* y, ReduceOp reducer,
   if (reduce_lastdim) {
     // 3. reduce in block x
     reduce_var = BlockXReduce(reduce_var, reducer);
-    if (threadIdx.x == 0) {
-      y[blockIdx.x + blockIdx.y * gridDim.x] = reduce_var;
+    if (left_idx < left_num && threadIdx.x == 0) {
+      y[blockIdx.y * left_num + left_idx] = reduce_var;
     }
   } else {
     if (left_idx < left_num && threadIdx.y == 0) {
@@ -707,8 +695,10 @@ __device__ void ReduceModule(const Tx* x, Ty* y, ReduceOp reducer,
                              const IndexCalculator& reduce_index_calculator,
                              const IndexCalculator& left_index_calculator) {
   if (reduce_type == ReduceType::kReduceLastDim) {
-    ReduceLastDim<Tx, Ty, ReduceOp, TransformOp>(x, y, reducer, transformer,
-                                                 init, reduce_num);
+    ReduceAny<Tx, Ty, ReduceOp, TransformOp>(
+        x, y, reducer, transformer, init, reduce_num, left_num, reduce_lastdim,
+        [&](int idx) { return idx; },
+        [&](int idx) { return idx * reduce_num; });
 
     // reduce_rank == 1 && reduce_dim[0] != x_dim.size() - 1
   } else if (reduce_type == ReduceType::kReduceHigherDim) {
@@ -719,7 +709,8 @@ __device__ void ReduceModule(const Tx* x, Ty* y, ReduceOp reducer,
   } else {
     ReduceAny<Tx, Ty, ReduceOp, TransformOp>(
         x, y, reducer, transformer, init, reduce_num, left_num, reduce_lastdim,
-        reduce_index_calculator, left_index_calculator);
+        [&](int idx) { return reduce_index_calculator.Get(idx); },
+        [&](int idx) { return left_index_calculator.Get(idx); });
   }
 }
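This last pair of hunks is the heart of the unification: the dedicated ReduceLastDim kernel is deleted, and the kReduceLastDim branch now routes through ReduceAny with two trivial lambdas, because a contiguous last-dim reduction needs only the identity map for the reduce index and a row stride for the left offset. Since the index calculators became template parameters (ReduceIndexCal, LeftIndexCal), the identity case compiles down to a plain array index with no IndexCalculator indirection. A hypothetical CPU reference, only to show why those two lambdas describe last-dim reduction over a row-major [left_num, reduce_num] input:

#include <vector>

template <typename T, typename ReduceOp>
std::vector<T> ReduceLastDimRef(const std::vector<T>& x, int left_num,
                                int reduce_num, T init, ReduceOp reducer) {
  auto left_index = [&](int idx) { return idx * reduce_num; };  // row offset
  auto reduce_index = [&](int idx) { return idx; };  // contiguous within row
  std::vector<T> y(left_num, init);
  for (int l = 0; l < left_num; ++l) {
    const T* row = x.data() + left_index(l);
    for (int r = 0; r < reduce_num; ++r) {
      y[l] = reducer(y[l], row[reduce_index(r)]);
    }
  }
  return y;
}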
...