From ae4d1ec1565bca9a182124e125cbb178788792ef Mon Sep 17 00:00:00 2001 From: niuliling123 <51102941+niuliling123@users.noreply.github.com> Date: Mon, 9 May 2022 15:09:08 +0800 Subject: [PATCH] Modified reduce for xpu2 (#42439) --- paddle/phi/kernels/funcs/reduce_function.h | 6 +++++- paddle/phi/kernels/primitive/compute_primitives_xpu2.h | 8 +++----- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/paddle/phi/kernels/funcs/reduce_function.h b/paddle/phi/kernels/funcs/reduce_function.h index 42fee144883..df14b0a21f2 100644 --- a/paddle/phi/kernels/funcs/reduce_function.h +++ b/paddle/phi/kernels/funcs/reduce_function.h @@ -473,7 +473,11 @@ struct ReduceConfig { bool not_higher = x_dim[0] >= max_grid_z; #endif if (reduce_last_dim && (reduce_rank == 1)) { +#ifdef PADDLE_WITH_XPU_KP + reduce_type = static_cast(ReduceType::kReduceAny); +#else reduce_type = static_cast(ReduceType::kReduceLastDim); +#endif } else if (reduce_rank == 1) { reduce_type = static_cast(ReduceType::kReduceHigherDim); if (rank == 3 && not_higher) { @@ -588,7 +592,7 @@ struct ReduceConfig { void SetBlockDim() { // init should_reduce_again = false; - dim3 block_dim; + dim3 block_dim(1, 1, 1); dim3 grid_dim(left_num, 1, 1); blocking_size = reduce_num; diff --git a/paddle/phi/kernels/primitive/compute_primitives_xpu2.h b/paddle/phi/kernels/primitive/compute_primitives_xpu2.h index 4d65dd6dd5d..0e77b11988e 100644 --- a/paddle/phi/kernels/primitive/compute_primitives_xpu2.h +++ b/paddle/phi/kernels/primitive/compute_primitives_xpu2.h @@ -329,14 +329,12 @@ __device__ __forceinline__ void Reduce(T* out, ReduceFunctor reducer, bool reduce_last_dim) { if (Mode == details::kGlobalMode) { + if (reduce_last_dim) { #pragma unroll - for (int i = 0; i < NY; ++i) { -#pragma unroll - for (int j = 0; j < NX; ++j) { - out[i] = reducer(out[i], in[i * NX + j]); + for (int i = 0; i < NY * NX; i++) { // reduce along blockDim.x + details::BlockXReduce(&out[i], reducer); } } - details::BlockXReduce(out, reducer); } else { // else kLocalMode #pragma unroll for (int i = 0; i < NY; ++i) { -- GitLab