diff --git a/paddle/phi/kernels/gpu/depthwise_conv.h b/paddle/phi/kernels/gpu/depthwise_conv.h
index 62a8df72c281f37ccd2b78a9052a73b7cdd0d0e8..12c988f1ecfc88c4ddd530897b17d8d2f2bb9605 100644
--- a/paddle/phi/kernels/gpu/depthwise_conv.h
+++ b/paddle/phi/kernels/gpu/depthwise_conv.h
@@ -87,43 +87,36 @@ class DepthwiseConvFilterGradFunctor {
       const DataLayout data_layout = DataLayout::kNCHW);
 };
 
+#define FINAL_MASK 0xffffffff
+#define HALF_WARP 16
+#define WARP_SIZE 32
+
 template <typename T>
-static __forceinline__ __device__ T WarpReduceSum(T val, int warp_size) {
-  typedef cub::WarpReduce<T> WarpReduce;
-  typename WarpReduce::TempStorage temp_storage;
-  val = WarpReduce(temp_storage).Sum(val, warp_size);
+__forceinline__ __device__ T WarpReduceSum(T val, unsigned lane_mask) {
+  for (int mask = HALF_WARP; mask > 0; mask >>= 1)
+    val += platform::CudaShuffleDownSync(lane_mask, val, mask);
   return val;
 }
 
 template <typename T>
-__forceinline__ __device__ T BlockReduceSum(T val) {
-  static __shared__ T shared[32];
-  int thread_id = threadIdx.x + threadIdx.y * blockDim.x +
-                  threadIdx.z * blockDim.x * blockDim.y;
-  int warp_size = min(blockDim.x * blockDim.y * blockDim.z, warpSize);
-  int lane = thread_id % warp_size;
-  int wid = thread_id / warp_size;
-
-  val = WarpReduceSum(val, warp_size);  // Each warp performs partial reduction
-
-  if (lane == 0) shared[wid] = val;  // Write reduced value to shared memory
-  __syncthreads();                   // Wait for all partial reductions
-
-  // read from shared memory only if that warp existed
-  int block_size = blockDim.x * blockDim.y * blockDim.z;
-  if (thread_id < (block_size - 1) / warp_size + 1) {
-    val = shared[lane];
-  } else {
-    val = static_cast<T>(0);
-  }
+__forceinline__ __device__ T BlockReduceSum(T val, unsigned mask = FINAL_MASK) {
+  static __shared__ T shared[WARP_SIZE];
+  int tid = threadIdx.y * blockDim.x + threadIdx.x;
+  int lane = tid & 0x1f;
+  int wid = tid >> 5;
+
+  val = WarpReduceSum(val, mask);
 
-  if (wid == 0) {
-    val = WarpReduceSum(val, warp_size);  // Final reduce within first warp
-  }
   __syncthreads();
-  if (thread_id != 0) {
-    val = static_cast<T>(0);
-  }
+  if (lane == 0) shared[wid] = val;
+
+  __syncthreads();
+
+  // align block_span to WARP_SIZE
+  int block_span = (blockDim.x * blockDim.y + WARP_SIZE - 1) >> 5;
+  val = (lane < block_span) ? shared[lane] : static_cast<T>(0.0f);
+  val = WarpReduceSum(val, mask);
+
   return val;
 }
 
@@ -858,45 +851,81 @@ __device__ __inline__ void KernelDepthwiseConvFilterGradNCHW(
     const int dilate_height,
     const int dilate_width,
     T* filter_grad_data) {
-  T s(0);
-  int gbid = ((blockIdx.z * gridDim.y) + blockIdx.y) * gridDim.x + blockIdx.x;
-
-  for (int image_w = threadIdx.x; image_w < output_width;
-       image_w += blockDim.x) {
-    for (int bid = 0; bid < num; bid++) {
-      for (int image_h = threadIdx.y; image_h < output_height;
-           image_h += blockDim.y) {
-        int kernel_id = blockIdx.z;
-        int kernel_h = blockIdx.y * dilate_height - padding_height;
-        int kernel_w = blockIdx.x * dilate_width - padding_width;
-
-        int image_hk = image_h * stride_height + kernel_h;
-        int image_wk = image_w * stride_width + kernel_w;
-        if (image_hk < 0 || image_hk >= input_height) continue;
-        if (image_wk < 0 || image_wk >= input_width) continue;
-#define gaid(N, C, H, W) \
-  ((((N)*gridDim.z + (C)) * output_height + (H)) * output_width + (W))
-        int input_id = ((bid * (gridDim.z / filter_multiplier) +
-                         kernel_id / filter_multiplier) *
-                            input_height +
-                        image_hk) *
-                           input_width +
-                       image_wk;
+  T f_grad(0);
+  const bool loop_batch = output_height * output_width >= WARP_SIZE;
+
+  int kw_id = blockIdx.x;
+  int kh_id = blockIdx.y;
+  int oc_id = blockIdx.z;
+  int ic_id = oc_id / filter_multiplier;
+  int idx = ((blockIdx.z * gridDim.y) + blockIdx.y) * gridDim.x + blockIdx.x;
+
+  const int ohw = output_height * output_width;
+  const int onhw = num * ohw;
+  const int h_offset = kh_id * dilate_height - padding_height;
+  const int w_offset = kw_id * dilate_width - padding_width;
+
+  if (loop_batch) {
+    for (int og_w = threadIdx.x; og_w < output_width; og_w += blockDim.x) {
+      for (int bid = 0; bid < num; ++bid) {
+        for (int og_h = threadIdx.y; og_h < output_height; og_h += blockDim.y) {
+          int i_h = og_h * stride_height + h_offset;
+          int i_w = og_w * stride_width + w_offset;
+
+          if (i_w >= 0 && i_w < input_width && i_h >= 0 && i_h < input_height) {
+            int input_offset =
+                ((bid * input_channels + ic_id) * input_height + i_h) *
+                    input_width +
+                i_w;
+            int output_grad_offset =
+                ((bid * output_channels + oc_id) * output_height + og_h) *
+                    output_width +
+                og_w;
+            if (fuse_relu_before_conv) {
+              f_grad +=
+                  output_grad_data[output_grad_offset] *
+                  static_cast<T>(
+                      max(0.0f, static_cast<double>(input_data[input_offset])));
+            } else {
+              f_grad += output_grad_data[output_grad_offset] *
+                        input_data[input_offset];
+            }
+          }
+        }
+      }
+    }
+  } else {
+    for (int id = threadIdx.x; id < onhw; id += blockDim.x) {
+      int bid = id / ohw;
+      int og_hw = id - bid * ohw;
+      int og_h = og_hw / output_width;
+      int og_w = og_hw - og_h * output_width;
+
+      int i_h = og_h * stride_height + h_offset;
+      int i_w = og_w * stride_width + w_offset;
+
+      if (i_w >= 0 && i_w < input_width && i_h >= 0 && i_h < input_height) {
+        int input_offset =
+            ((bid * input_channels + ic_id) * input_height + i_h) *
+                input_width +
+            i_w;
+        int output_grad_offset = (bid * output_channels + oc_id) * ohw + og_hw;
         if (fuse_relu_before_conv) {
-          s += output_grad_data[gaid(bid, kernel_id, image_h, image_w)] *
-               static_cast<T>(
-                   max(0.0f, static_cast<double>(input_data[input_id])));
+          f_grad += output_grad_data[output_grad_offset] *
+                    static_cast<T>(max(
+                        0.0f, static_cast<double>(input_data[input_offset])));
         } else {
-          s += output_grad_data[gaid(bid, kernel_id, image_h, image_w)] *
-               input_data[input_id];
+          f_grad +=
+              output_grad_data[output_grad_offset] * input_data[input_offset];
         }
-#undef gaid
       }
     }
   }
-  T val = BlockReduceSum(s);
-  if (threadIdx.y == 0 && threadIdx.x == 0) filter_grad_data[gbid] = val;
+  T val = BlockReduceSum(f_grad);
+  if (threadIdx.x == 0 && threadIdx.y == 0) {
+    filter_grad_data[idx] = val;
+  }
 }
 
 template <typename T, bool fuse_relu_before_conv>
@@ -1572,6 +1601,10 @@ class DepthwiseConvFilterGradFunctor
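
Note (not part of the patch): the reduction above drops cub::WarpReduce in favor of explicit warp shuffles. The standalone sketch below shows the same pattern using the raw __shfl_down_sync intrinsic that platform::CudaShuffleDownSync wraps; the kernel, names, and launch parameters are illustrative only, and it assumes the block size is a multiple of 32 so every lane covered by the mask is active.

#include <cstdio>
#include <vector>

#include <cuda_runtime.h>

#define FULL_MASK 0xffffffffu
#define WARP_SIZE 32

// Warp-level tree reduction: after log2(32) shuffle steps, lane 0 of the
// warp holds the sum of all 32 lanes.
template <typename T>
__forceinline__ __device__ T WarpSum(T val, unsigned mask = FULL_MASK) {
  for (int offset = WARP_SIZE / 2; offset > 0; offset >>= 1)
    val += __shfl_down_sync(mask, val, offset);
  return val;
}

// Block-level reduction: each warp stages its partial sum in shared memory,
// then the partials are reduced by a second warp-level pass. Thread (0,0)
// of the block ends up holding the block total.
template <typename T>
__forceinline__ __device__ T BlockSum(T val, unsigned mask = FULL_MASK) {
  __shared__ T partial[WARP_SIZE];
  int tid = threadIdx.y * blockDim.x + threadIdx.x;
  int lane = tid & (WARP_SIZE - 1);
  int wid = tid >> 5;

  val = WarpSum(val, mask);
  if (lane == 0) partial[wid] = val;
  __syncthreads();

  // Only warps that actually exist in this block contribute a partial sum.
  int num_warps = (blockDim.x * blockDim.y + WARP_SIZE - 1) >> 5;
  val = (lane < num_warps) ? partial[lane] : static_cast<T>(0);
  return WarpSum(val, mask);
}

// Illustrative kernel: each block reduces its slice of the input, and
// thread 0 accumulates the per-block total into a single output value.
__global__ void ReduceKernel(const float* in, int n, float* out) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  float v = (i < n) ? in[i] : 0.0f;  // keep all lanes active for the shuffles
  float block_total = BlockSum(v);
  if (threadIdx.x == 0) atomicAdd(out, block_total);
}

int main() {
  const int n = 1 << 20;
  std::vector<float> host(n, 1.0f);
  float *d_in = nullptr, *d_out = nullptr, result = 0.0f;
  cudaMalloc(&d_in, n * sizeof(float));
  cudaMalloc(&d_out, sizeof(float));
  cudaMemcpy(d_in, host.data(), n * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemset(d_out, 0, sizeof(float));
  ReduceKernel<<<(n + 255) / 256, 256>>>(d_in, n, d_out);
  cudaMemcpy(&result, d_out, sizeof(float), cudaMemcpyDeviceToHost);
  printf("sum = %.0f (expected %d)\n", result, n);
  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}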