diff --git a/paddle/phi/kernels/primitive/compute_primitives.h b/paddle/phi/kernels/primitive/compute_primitives.h index 2d9a7522515d0109a4bb4f268e48ab2f6285154b..4f3c069f3b249161eb83698c4ded150b8f003b14 100644 --- a/paddle/phi/kernels/primitive/compute_primitives.h +++ b/paddle/phi/kernels/primitive/compute_primitives.h @@ -110,7 +110,11 @@ __device__ __forceinline__ T BlockXReduce(T val, ReduceOp reducer) { T temp = paddle::platform::CudaShuffleDownSync(mask, val, stride); val = reducer(val, temp); } - return val; + if (threadIdx.x == 0) { + shared[threadIdx.y] = val; + } + __syncthreads(); + return shared[threadIdx.y]; } /**