diff --git a/paddle/fluid/operators/math/depthwise_conv.cu b/paddle/fluid/operators/math/depthwise_conv.cu index 7c5f59fab0d280587e15b6e1353c9dd2bda9270a..6da1bfb964f24ff94aa3137053cdcab8a57726de 100644 --- a/paddle/fluid/operators/math/depthwise_conv.cu +++ b/paddle/fluid/operators/math/depthwise_conv.cu @@ -31,18 +31,43 @@ namespace operators { namespace math { template -__device__ __inline__ void CudaAtomicAddWithWarp(T* sum, T value) { +static __forceinline__ __device__ T WarpReduceSum(T val, int warp_size) { typedef cub::WarpReduce WarpReduce; typename WarpReduce::TempStorage temp_storage; + val = WarpReduce(temp_storage).Sum(val, warp_size); + return val; +} -#ifdef __HIPCC__ - int block_size = min(blockDim.x * blockDim.y * blockDim.z, warpSize); - value = WarpReduce(temp_storage).Sum(value, block_size); -#else - value = WarpReduce(temp_storage).Sum(value); -#endif +template +__forceinline__ __device__ T BlockReduceSum(T val) { + static __shared__ T shared[32]; + int thread_id = threadIdx.x + threadIdx.y * blockDim.x + + threadIdx.z * blockDim.x * blockDim.y; + int warp_size = min(blockDim.x * blockDim.y * blockDim.z, warpSize); + int lane = thread_id % warp_size; + int wid = thread_id / warp_size; + + val = WarpReduceSum(val, warp_size); // Each warp performs partial reduction + + if (lane == 0) shared[wid] = val; // Write reduced value to shared memory + __syncthreads(); // Wait for all partial reductions + + // read from shared memory only if that warp existed + int block_size = blockDim.x * blockDim.y * blockDim.z; + if (thread_id < (block_size - 1) / warp_size + 1) { + val = shared[lane]; + } else { + val = static_cast(0); + } - if (cub::LaneId() == 0) platform::CudaAtomicAdd(sum, value); + if (wid == 0) { + val = WarpReduceSum(val, warp_size); // Final reduce within first warp + } + __syncthreads(); + if (thread_id != 0) { + val = static_cast(0); + } + return val; } #define ARG_DEFINE_KernelDepthwiseConv \ @@ -665,7 +690,9 @@ __device__ __inline__ void KernelDepthwiseConvFilterGradNCHW( } } } - CudaAtomicAddWithWarp(&filter_grad_data[gbid], s); + + T val = BlockReduceSum(s); + platform::CudaAtomicAdd(&filter_grad_data[gbid], val); } template @@ -892,6 +919,7 @@ class DepthwiseConvFunctor 1024 && output_width <= 2048) thread = (output_width - 1) / 2 + 1; @@ -1034,6 +1062,7 @@ class DepthwiseConvInputGradFunctor 1024 && input_width <= 2048) { thread = (input_width - 1) / 2 + 1;