diff --git a/paddle/pten/kernels/primitive/compute_primitives.h b/paddle/pten/kernels/primitive/compute_primitives.h index 449c81b915e6d6de2eecd033fbcd374c5c426292..a8ed0816227635e3f0159bf00c3c3c6243c7daa9 100644 --- a/paddle/pten/kernels/primitive/compute_primitives.h +++ b/paddle/pten/kernels/primitive/compute_primitives.h @@ -118,7 +118,7 @@ __device__ __forceinline__ T BlockXReduce(T val, ReduceOp reducer) { */ template __device__ __forceinline__ T BlockYReduce(T val, ReduceOp reducer) { - __shared__ T shared_memory[details::kReduceMaxThread]; + __shared__ T shared_memory[1024]; shared_memory[SharedMemoryIndex(0)] = val; for (int stride = blockDim.y / 2; stride > 0; stride >>= 1) { __syncthreads(); @@ -128,7 +128,8 @@ __device__ __forceinline__ T BlockYReduce(T val, ReduceOp reducer) { } shared_memory[SharedMemoryIndex(0)] = val; } - return val; + __syncthreads(); + return shared_memory[threadIdx.x]; } } // namespace details