Unverified commit c7cc5ac2, authored by Zhang Zheng, committed by GitHub

Unify the block/grid strategy and implementation of ReduceLastDim and ReduceAny (#34436)

Parent 80f7f7ea
@@ -360,19 +360,26 @@ struct ReduceConfig {
     constexpr int max_num_threads = detail::kMaxThread;
 
     // set block size.
-    // 1. if reduce_lastdim == true, block is 1-D, no need reduction in block y;
-    // 2. if reduce_lastdim == false, block is 2-D, if it is necessary,
-    //    it should reduce in block y.
+    // 1. If reduce_lastdim == true, all the threads whose threadIdx.y are same
+    //    will process the reduction for one output.
+    //    The number of output for one block is blockDim.y;
+    // 2. If reduce_lastdim == false, different threadIdx.x will process
+    //    different reduction and gets the output separately. If it is
+    //    necessary, it should reduce in block y.
+    //    The number of output for one block is blockDim.x;
+    int block_x, block_y;
     int grid_num, reduce_num_per_thread;
     if (reduce_lastdim) {
-      block_dim->x = detail::GetBlockDim(reduce_num);
-      block_dim->y = 1;
-      grid_num = left_num;
-      reduce_num_per_thread =
-          detail::AlignUp(reduce_num, block_dim->x * block_dim->y);
+      block_x = detail::GetBlockDim(reduce_num);
+      block_y = detail::GetBlockDim(left_num);
+      block_dim->x = block_x;
+      block_dim->y =
+          std::min(block_y, static_cast<int>(max_num_threads / block_dim->x));
+      grid_num = detail::AlignUp(left_num, block_dim->y);
+      reduce_num_per_thread = detail::AlignUp(reduce_num, block_dim->x);
     } else {
-      int block_x = detail::GetBlockDim(left_num);
-      int block_y = detail::GetBlockDim(reduce_num);
+      block_x = detail::GetBlockDim(left_num);
+      block_y = detail::GetBlockDim(reduce_num);
       block_dim->x = std::min(block_x, 32);
       block_dim->y =
           std::min(block_y, static_cast<int>(max_num_threads / block_dim->x));
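A worked example of the new sizing arithmetic in the reduce_lastdim branch. This host-side sketch assumes that detail::GetBlockDim rounds its argument up to a power of two capped at detail::kMaxThread and that detail::AlignUp(a, b) computes (a + b - 1) / b; the kMaxThread value of 256 is illustrative only, not the library's actual constant.

```cuda
#include <algorithm>
#include <cstdio>

constexpr int kMaxThread = 256;  // assumption: stand-in for detail::kMaxThread

// assumption: rounds up to a power of two, capped at kMaxThread
int GetBlockDim(int n) {
  int dim = 1;
  while (dim < n && dim < kMaxThread) dim <<= 1;
  return dim;
}

// assumption: ceiling division, as used throughout the surrounding code
int AlignUp(int a, int b) { return (a + b - 1) / b; }

int main() {
  int reduce_num = 1000, left_num = 37;
  int block_x = GetBlockDim(reduce_num);  // threads along the reduced axis
  int block_y = std::min(GetBlockDim(left_num), kMaxThread / block_x);
  int grid_num = AlignUp(left_num, block_y);  // one block covers block_y outputs
  int reduce_num_per_thread = AlignUp(reduce_num, block_x);
  printf("block=(%d,%d) grid=%d per_thread=%d\n", block_x, block_y, grid_num,
         reduce_num_per_thread);
  return 0;
}
```

With reduce_num = 1000 and left_num = 37, this yields block = (256, 1), 37 blocks, and 4 elements of the reduced axis per thread; smaller reduce_num leaves room for block_y > 1, so one block then serves several outputs.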
@@ -467,7 +474,7 @@ struct ReduceConfig {
       grid_dim.x = (left_num + block_dim.x - 1) / block_dim.x;
       grid_dim.y = 1;
     }
-  } else if (reduce_type == ReduceType::kReduceAny) {
+  } else {
     SetBlockDimForReduceAny(&block_dim, &grid_dim);
   }
 
@@ -524,18 +531,20 @@ static __device__ T WarpReduce(T val, ReduceOp reducer) {
 template <typename T, typename ReduceOp>
 static __device__ T BlockXReduce(T val, ReduceOp reducer) {
   using detail::kWarpSize;
-  __shared__ T shared[kWarpSize];
+  __shared__ T shared[2 * kWarpSize];
   int block_dim_x = blockDim.x;
   if (blockDim.x > kWarpSize) {
     block_dim_x = blockDim.x / kWarpSize;
     int lane = threadIdx.x % kWarpSize;
-    int wid = threadIdx.x / kWarpSize;
+    int tid = threadIdx.y * blockDim.x + threadIdx.x;
+    int wid = tid / kWarpSize;
+    int bid = threadIdx.y;
     val = WarpReduce(val, reducer);
     if (lane == 0) {
       shared[wid] = val;
     }
     __syncthreads();
-    val = shared[lane];
+    val = shared[bid * block_dim_x + lane];
   }
 
   unsigned mask = 0u;
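To see why shared[] doubled and why wid is now derived from the flattened thread id: each threadIdx.y row must park its warp partials in a disjoint slice of shared memory before the cross-warp shuffle. A compilable sketch under two simplifying assumptions of ours: kWarpSize == 32, and full-warp shuffle masks where the real kernel computes a mask from blockDim; every Sketch-suffixed name is hypothetical.

```cuda
#include <cstdio>

constexpr int kWarpSize = 32;  // assumption: detail::kWarpSize

// Hypothetical stand-in for WarpReduce: shuffle-down within one warp.
template <typename T, typename ReduceOp>
__device__ T WarpReduceSketch(T val, ReduceOp reducer) {
  for (int offset = kWarpSize / 2; offset > 0; offset >>= 1) {
    val = reducer(val, __shfl_down_sync(0xffffffffu, val, offset));
  }
  return val;
}

// Mirrors the updated BlockXReduce: one partial per warp is parked in
// shared[], and each threadIdx.y row owns the disjoint slice
// [bid * block_dim_x, (bid + 1) * block_dim_x) -- hence 2 * kWarpSize
// slots instead of kWarpSize.
template <typename T, typename ReduceOp>
__device__ T BlockXReduceSketch(T val, ReduceOp reducer) {
  __shared__ T shared[2 * kWarpSize];
  int block_dim_x = blockDim.x;
  if (blockDim.x > kWarpSize) {
    block_dim_x = blockDim.x / kWarpSize;              // warps per row
    int lane = threadIdx.x % kWarpSize;
    int tid = threadIdx.y * blockDim.x + threadIdx.x;  // flattened id
    int wid = tid / kWarpSize;                         // block-wide warp id
    int bid = threadIdx.y;                             // row picks its slice
    val = WarpReduceSketch(val, reducer);
    if (lane == 0) shared[wid] = val;
    __syncthreads();
    val = shared[bid * block_dim_x + lane];  // reload this row's partials
  }
  // fold the surviving block_dim_x partials (full mask: our simplification)
  for (int offset = block_dim_x >> 1; offset > 0; offset >>= 1) {
    val = reducer(val, __shfl_down_sync(0xffffffffu, val, offset));
  }
  return val;
}

// Each threadIdx.y row sums one length-n row of `in`.
__global__ void RowSumKernel(const float* in, float* out, int n) {
  float v = 0.0f;
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    v += in[threadIdx.y * n + i];
  }
  v = BlockXReduceSketch(v, [](float a, float b) { return a + b; });
  if (threadIdx.x == 0) out[threadIdx.y] = v;
}

int main() {
  const int n = 64, rows = 2;
  float h_in[rows * n], h_out[rows];
  for (int i = 0; i < rows * n; ++i) h_in[i] = 1.0f;
  float *d_in, *d_out;
  cudaMalloc(&d_in, sizeof(h_in));
  cudaMalloc(&d_out, sizeof(h_out));
  cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);
  RowSumKernel<<<1, dim3(n, rows)>>>(d_in, d_out, n);
  cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);
  printf("%g %g\n", h_out[0], h_out[1]);  // expect 64 64
  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}
```

With block (64, 2) the kernel leaves four warp partials in shared[], rows 0 and 1 in disjoint two-slot slices, which the lane-indexed reload then folds in one shuffle step per row.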
@@ -562,29 +571,6 @@ static __device__ T BlockYReduce(T val, ReduceOp reducer) {
   return val;
 }
 
-// when reduce_dim.size() == 1 and reduce_dim[0] == x_dim.size() - 1, this
-// function will be used
-// blockId.x -> left_num, threadId.x -> reduce_num
-template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp>
-__device__ void ReduceLastDim(const Tx* x, Ty* y, ReduceOp reducer,
-                              TransformOp transformer, Ty init,
-                              int reduce_num) {
-  int idx_x = blockIdx.x * reduce_num;
-  int idx_y = threadIdx.x;
-  Ty reduce_var = init;
-  for (int idx_y = threadIdx.x; idx_y < reduce_num; idx_y += blockDim.x) {
-    reduce_var =
-        reducer(reduce_var, static_cast<Ty>(transformer(x[idx_x + idx_y])));
-  }
-  __syncthreads();
-
-  reduce_var = BlockXReduce(reduce_var, reducer);
-  if (threadIdx.x == 0) {
-    y[blockIdx.x] = reduce_var;
-  }
-}
-
 // when reduce_dim.size() == 1 and reduce_dim[0] != x_dim.size() - 1, this
 // function will be used
 // eg: x_dim = {nz, ny, nx}, nx != 1, axis can be 0 or 1
@@ -613,19 +599,21 @@ __device__ void ReduceHigherDim(const Tx* x, Ty* y, ReduceOp reducer,
     }
   }
 }
 
 // when reduce_dim.size() == 1 and reduce_dim[0] == x_dim.size() - 1, or
 // when reduce_dim.size() != 1 and reduce_dim.size() != x_dim.size(), this
 // function will be used
-template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp>
+template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp,
+          typename ReduceIndexCal, typename LeftIndexCal>
 __device__ void ReduceAny(const Tx* x, Ty* y, ReduceOp reducer,
                           TransformOp transformer, Ty init, int reduce_num,
                           int left_num, bool reduce_lastdim,
-                          const IndexCalculator& reduce_index_calculator,
-                          const IndexCalculator& left_index_calculator) {
+                          ReduceIndexCal reduce_index_calculator,
+                          LeftIndexCal left_index_calculator) {
   int input_idx, left_idx, stride;
   // the last dim gets involved in reduction
   if (reduce_lastdim) {
     input_idx = blockIdx.y * blockDim.x + threadIdx.x;
-    left_idx = blockIdx.x;
+    left_idx = blockIdx.x * blockDim.y + threadIdx.y;
     stride = gridDim.y * blockDim.x;
   } else {
     input_idx = blockIdx.y * blockDim.y + threadIdx.y;
@@ -633,7 +621,7 @@ __device__ void ReduceAny(const Tx* x, Ty* y, ReduceOp reducer,
     stride = gridDim.y * blockDim.y;
   }
 
   // calculate the offset, means the addr where each thread really start.
-  int input_offset = left_index_calculator.Get(left_idx);
+  int input_offset = left_index_calculator(left_idx);
   const Tx* input = x + input_offset;
   Ty reduce_var = init;
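The signature change above is the heart of the unification: the index mappings now arrive as template-typed callables instead of const IndexCalculator&, so the trivial lambdas passed for the last-dim case (see ReduceModule further down) can inline to plain arithmetic rather than going through IndexCalculator's general rank-N machinery. A minimal, host-runnable sketch of the pattern, with hypothetical names:

```cuda
#include <cstdio>

// Hypothetical name; shows the callable-as-template-parameter pattern the
// new ReduceAny signature adopts. index_fn is a plain lambda the compiler
// can inline at each call site.
template <typename T, typename ReduceOp, typename IndexFn>
T StridedReduceSketch(const T* input, ReduceOp reducer, T init,
                      IndexFn index_fn, int n, int start, int stride) {
  T acc = init;
  for (int i = start; i < n; i += stride) {
    acc = reducer(acc, input[index_fn(i)]);
  }
  return acc;
}

int main() {
  float x[6] = {1, 2, 3, 4, 5, 6};
  // identity map, like the kReduceLastDim reduce-index lambda in this diff
  float s = StridedReduceSketch(
      x, [](float a, float b) { return a + b; }, 0.0f,
      [](int idx) { return idx; }, 6, 0, 1);
  printf("%g\n", s);  // expect 21
  return 0;
}
```

From device code, a plain `[&](int idx) { return idx; }` suffices, exactly as the diff passes in ReduceModule.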
@@ -646,7 +634,7 @@ __device__ void ReduceAny(const Tx* x, Ty* y, ReduceOp reducer,
 #pragma unroll
     for (int i = 0; i < REDUCE_VEC_SIZE; ++i) {
       int reduce_idx = input_idx + i * stride;
-      int idx_x = reduce_index_calculator.Get(reduce_idx);
+      int idx_x = reduce_index_calculator(reduce_idx);
       input_reg[i] = input[idx_x];
     }
 #pragma unroll
#pragma unroll
......@@ -664,7 +652,7 @@ __device__ void ReduceAny(const Tx* x, Ty* y, ReduceOp reducer,
break;
}
int reduce_idx = input_idx;
int idx_x = reduce_index_calculator.Get(reduce_idx);
int idx_x = reduce_index_calculator(reduce_idx);
input_reg[i] = input[idx_x];
input_idx += stride;
}
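The loop fragments above are the register-blocking core of ReduceAny: each thread gathers REDUCE_VEC_SIZE strided elements into registers before combining them, with a break-guarded tail loop for the remainder. A host-side model of that shape; names are ours, and the transformer step plus the device version's #pragma unroll are omitted for brevity.

```cuda
#include <cstdio>

constexpr int kReduceVecSize = 4;  // mirrors REDUCE_VEC_SIZE in the diff

// Gather a tile of kReduceVecSize elements first, then combine, so the
// loads can issue independently of the reduction dependency chain.
template <typename T, typename ReduceOp, typename IndexFn>
T VectorizedReduceSketch(const T* input, ReduceOp reducer, T init,
                         IndexFn index_fn, int input_idx, int bound,
                         int stride) {
  T acc = init;
  T reg[kReduceVecSize];
  // full tiles: all loads of the tile are in bounds
  while (input_idx + (kReduceVecSize - 1) * stride < bound) {
    for (int i = 0; i < kReduceVecSize; ++i) {
      reg[i] = input[index_fn(input_idx + i * stride)];  // staged loads
    }
    for (int i = 0; i < kReduceVecSize; ++i) {
      acc = reducer(acc, reg[i]);  // combine after all loads issued
    }
    input_idx += kReduceVecSize * stride;
  }
  // tail, mirroring the diff's break-guarded remainder loop
  for (; input_idx < bound; input_idx += stride) {
    acc = reducer(acc, input[index_fn(input_idx)]);
  }
  return acc;
}

int main() {
  float x[10] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
  // one "thread" with input_idx = 0, stride = 1 over all 10 elements
  float s = VectorizedReduceSketch(
      x, [](float a, float b) { return a + b; }, 0.0f,
      [](int idx) { return idx; }, 0, 10, 1);
  printf("%g\n", s);  // expect 45: two full tiles, then a 2-element tail
  return 0;
}
```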
@@ -680,7 +668,7 @@ __device__ void ReduceAny(const Tx* x, Ty* y, ReduceOp reducer,
   }
 
   // 2. reduce in block y
-  if (blockDim.y > 1) {
+  if (!reduce_lastdim && blockDim.y > 1) {
     reduce_var = BlockYReduce(reduce_var, reducer);
   }
   __syncthreads();
@@ -688,8 +676,8 @@ __device__ void ReduceAny(const Tx* x, Ty* y, ReduceOp reducer,
   if (reduce_lastdim) {
     // 3. reduce in block x
     reduce_var = BlockXReduce(reduce_var, reducer);
-    if (threadIdx.x == 0) {
-      y[blockIdx.x + blockIdx.y * gridDim.x] = reduce_var;
+    if (left_idx < left_num && threadIdx.x == 0) {
+      y[blockIdx.y * left_num + left_idx] = reduce_var;
     }
   } else {
     if (left_idx < left_num && threadIdx.y == 0) {
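The rewritten store reflects the new block shape: each of the blockDim.y rows owns one output, and results are packed as gridDim.y slices of length left_num. A tiny index check of that layout; the assumption that a follow-up pass folds the gridDim.y partial slices when gridDim.y > 1 is ours, since that pass lies outside this diff.

```cuda
#include <cstdio>

// Flat index a last-dim block writes for its partial result, mirroring
// y[blockIdx.y * left_num + left_idx] with
// left_idx = blockIdx.x * blockDim.y + threadIdx.y.
int PartialIndex(int block_x, int block_y_id, int thread_y, int block_dim_y,
                 int left_num) {
  int left_idx = block_x * block_dim_y + thread_y;
  return block_y_id * left_num + left_idx;
}

int main() {
  // e.g. left_num = 37, blockDim.y = 2: block (3, 1), row 0 owns output 6
  // of slice 1, stored at 1 * 37 + 6 = 43.
  printf("%d\n", PartialIndex(/*block_x=*/3, /*block_y_id=*/1,
                              /*thread_y=*/0, /*block_dim_y=*/2,
                              /*left_num=*/37));
  return 0;
}
```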
@@ -707,8 +695,10 @@ __device__ void ReduceModule(const Tx* x, Ty* y, ReduceOp reducer,
                              const IndexCalculator& reduce_index_calculator,
                              const IndexCalculator& left_index_calculator) {
   if (reduce_type == ReduceType::kReduceLastDim) {
-    ReduceLastDim<Tx, Ty, ReduceOp, TransformOp>(x, y, reducer, transformer,
-                                                 init, reduce_num);
+    ReduceAny<Tx, Ty, ReduceOp, TransformOp>(
+        x, y, reducer, transformer, init, reduce_num, left_num, reduce_lastdim,
+        [&](int idx) { return idx; },
+        [&](int idx) { return idx * reduce_num; });
 
   // reduce_rank == 1 && reduce_dim[0] != x_dim.size() - 1
   } else if (reduce_type == ReduceType::kReduceHigherDim) {
@@ -719,7 +709,8 @@ __device__ void ReduceModule(const Tx* x, Ty* y, ReduceOp reducer,
   } else {
     ReduceAny<Tx, Ty, ReduceOp, TransformOp>(
         x, y, reducer, transformer, init, reduce_num, left_num, reduce_lastdim,
-        reduce_index_calculator, left_index_calculator);
+        [&](int idx) { return reduce_index_calculator.Get(idx); },
+        [&](int idx) { return left_index_calculator.Get(idx); });
   }
 }
 
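With the dispatch above, the deleted ReduceLastDim falls out as a special case of ReduceAny: for a row-major [left_num, reduce_num] input, row idx starts at offset idx * reduce_num and its elements are contiguous, which is exactly what the two kReduceLastDim lambdas encode. A host-side mock (names ours) showing those same lambdas doing the indexing the old kernel hard-coded:

```cuda
#include <cstdio>

// Host-side mock of the kReduceLastDim path in ReduceModule: reduce each
// row of a row-major [left_num, reduce_num] array using the same two
// lambdas the diff passes to ReduceAny.
template <typename T, typename ReduceOp, typename ReduceIndexCal,
          typename LeftIndexCal>
void ReduceRowsSketch(const T* x, T* y, ReduceOp reducer, T init,
                      int reduce_num, int left_num,
                      ReduceIndexCal reduce_index, LeftIndexCal left_index) {
  for (int row = 0; row < left_num; ++row) {
    const T* input = x + left_index(row);  // row start via the left lambda
    T acc = init;
    for (int j = 0; j < reduce_num; ++j) {
      acc = reducer(acc, input[reduce_index(j)]);  // identity within a row
    }
    y[row] = acc;
  }
}

int main() {
  const int left_num = 2, reduce_num = 3;
  float x[left_num * reduce_num] = {1, 2, 3, 4, 5, 6};
  float y[left_num];
  ReduceRowsSketch(
      x, y, [](float a, float b) { return a + b; }, 0.0f, reduce_num,
      left_num, [&](int idx) { return idx; },
      [&](int idx) { return idx * reduce_num; });
  printf("%g %g\n", y[0], y[1]);  // expect 6 15
  return 0;
}
```

The device version distributes the inner loop across threadIdx.x and finishes with BlockXReduce, but the index algebra is identical.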