未验证 提交 ae4d1ec1 编写于 作者: N niuliling123 提交者: GitHub

Modified reduce for xpu2 (#42439)

上级 8b546f1c
@@ -473,7 +473,11 @@ struct ReduceConfig {
bool not_higher = x_dim[0] >= max_grid_z; bool not_higher = x_dim[0] >= max_grid_z;
#endif #endif
if (reduce_last_dim && (reduce_rank == 1)) { if (reduce_last_dim && (reduce_rank == 1)) {
#ifdef PADDLE_WITH_XPU_KP
reduce_type = static_cast<int>(ReduceType::kReduceAny);
#else
reduce_type = static_cast<int>(ReduceType::kReduceLastDim); reduce_type = static_cast<int>(ReduceType::kReduceLastDim);
#endif
} else if (reduce_rank == 1) { } else if (reduce_rank == 1) {
reduce_type = static_cast<int>(ReduceType::kReduceHigherDim); reduce_type = static_cast<int>(ReduceType::kReduceHigherDim);
if (rank == 3 && not_higher) { if (rank == 3 && not_higher) {
@@ -588,7 +592,7 @@ struct ReduceConfig {
void SetBlockDim() { void SetBlockDim() {
// init // init
should_reduce_again = false; should_reduce_again = false;
dim3 block_dim; dim3 block_dim(1, 1, 1);
dim3 grid_dim(left_num, 1, 1); dim3 grid_dim(left_num, 1, 1);
blocking_size = reduce_num; blocking_size = reduce_num;
......
@@ -329,14 +329,12 @@ __device__ __forceinline__ void Reduce(T* out,
ReduceFunctor reducer, ReduceFunctor reducer,
bool reduce_last_dim) { bool reduce_last_dim) {
if (Mode == details::kGlobalMode) { if (Mode == details::kGlobalMode) {
if (reduce_last_dim) {
#pragma unroll #pragma unroll
for (int i = 0; i < NY; ++i) { for (int i = 0; i < NY * NX; i++) { // reduce along blockDim.x
#pragma unroll details::BlockXReduce<T, ReduceFunctor, 1>(&out[i], reducer);
for (int j = 0; j < NX; ++j) {
out[i] = reducer(out[i], in[i * NX + j]);
} }
} }
details::BlockXReduce<T, ReduceFunctor, NY>(out, reducer);
} else { // else kLocalMode } else { // else kLocalMode
#pragma unroll #pragma unroll
for (int i = 0; i < NY; ++i) { for (int i = 0; i < NY; ++i) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册