From f17a73e9de8b0ab953a73fc77f41a0b04fe14072 Mon Sep 17 00:00:00 2001
From: Zhang Zheng <32410583+ZzSean@users.noreply.github.com>
Date: Fri, 30 Sep 2022 15:55:09 +0800
Subject: [PATCH] Optimize performance of depthwise_conv_bwd (#46362)

* Optimize performance of depthwise_conv_bwd

* fix
---
 paddle/phi/kernels/gpu/depthwise_conv.h | 107 +++++++++++++-----------
 1 file changed, 60 insertions(+), 47 deletions(-)

diff --git a/paddle/phi/kernels/gpu/depthwise_conv.h b/paddle/phi/kernels/gpu/depthwise_conv.h
index 7e8ca384bc2..62a8df72c28 100644
--- a/paddle/phi/kernels/gpu/depthwise_conv.h
+++ b/paddle/phi/kernels/gpu/depthwise_conv.h
@@ -469,60 +469,62 @@ __global__ void KernelDepthwiseConvSp(ARG_DEFINE_KernelDepthwiseConv) {
       const int dilate_height, const int dilate_width,                     \
       T *const input_grad_data
 
-template <typename T, bool fuse_relu_before_conv>
+template <typename T, int c_filter, bool fuse_relu_before_conv>
 __device__ __inline__ void KernelDepthwiseConvInputGradNCHW(
     ARG_DEFINE_KernelDepthwiseConvInputGrad) {
-  const int batch = blockIdx.y;
-  const int c_in = blockIdx.x;
-  for (int w_in = threadIdx.x; w_in < input_width; w_in += blockDim.x) {
-    for (int h_in = threadIdx.y; h_in < input_height; h_in += blockDim.y) {
-      const int c_out_start = c_in * filter_multiplier;
-      int h_out_start =
-          h_in - (filter_height - 1) * dilate_height + padding_height;
-      int h_out_end = h_in + padding_height;
-      int w_out_start =
-          w_in - (filter_width - 1) * dilate_width + padding_width;
-      int w_out_end = w_in + padding_width;
+  const int fw_size = c_filter != -1 ? c_filter : filter_width;
+  const int fh_size = c_filter != -1 ? c_filter : filter_height;
+  int idx = blockIdx.x * blockDim.x + threadIdx.x;
+  if (idx >= batch_size * input_channels * input_height * input_width) {
+    return;
+  }
+  if (fuse_relu_before_conv) {
+    if (input_data[idx] <= static_cast<T>(0.0f)) {
+      input_grad_data[idx] = 0;
+      return;
+    }
+  }
 
-      T value(0);
-      int index =
-          ((batch * gridDim.x + c_in) * input_height + h_in) * input_width +
-          w_in;
+  int tmp_1 = idx / input_width;
+  const int w_in = idx - tmp_1 * input_width;
+  int tmp_2 = tmp_1 / input_height;
+  const int h_in = tmp_1 - tmp_2 * input_height;
+  tmp_1 = tmp_2;
+  tmp_2 = tmp_1 / input_channels;
+  const int c_in = tmp_1 - tmp_2 * input_channels;
+  const int batch = tmp_2;
 
-      if (fuse_relu_before_conv) {
-        if (input_data[index] <= T(0)) {
-          input_grad_data[index] = 0;
-          continue;
-        }
-      }
+  T value(0);
+  for (int c_mul = 0; c_mul < filter_multiplier; ++c_mul) {
+    int c_out = c_in * filter_multiplier + c_mul;
+    int filter_offset = c_out * filter_height * filter_width;
 
-      for (int c_out = c_out_start; c_out < c_out_start + filter_multiplier;
-           c_out++) {
-        int filter_offset = (c_out + 1) * filter_height * filter_width;
-        for (int h_out = h_out_start; h_out <= h_out_end;
-             h_out += dilate_height) {
-          for (int w_out = w_out_start; w_out <= w_out_end;
-               w_out += dilate_width) {
-            filter_offset--;
-            int s_h_out = h_out / stride_height;
-            int s_w_out = w_out / stride_width;
-            if (h_out % stride_height == 0 && w_out % stride_width == 0 &&
-                s_h_out >= 0 && s_h_out < output_height && s_w_out >= 0 &&
-                s_w_out < output_width) {
-              int output_grad_offset =
-                  ((batch * output_channels + c_out) * output_height +
-                   s_h_out) *
-                      output_width +
-                  s_w_out;
-              value += output_grad_data[output_grad_offset] *
-                       filter_data[filter_offset];
-            }
+#pragma unroll
+    for (int fh = 0; fh < fh_size; ++fh) {
+#pragma unroll
+      for (int fw = 0; fw < fw_size; ++fw) {
+        int h_out = h_in + padding_height - fh * dilate_height;
+        int w_out = w_in + padding_width - fw * dilate_width;
+        if ((h_out - h_out / stride_height * stride_height == 0) &&
+            (w_out - w_out / stride_width * stride_width == 0)) {
+          h_out /= stride_height;
+          w_out /= stride_width;
+
+          if (h_out >= 0 && h_out < output_height && w_out >= 0 &&
+              w_out < output_width) {
+            int output_grad_offset =
+                ((batch * output_channels + c_out) * output_height + h_out) *
+                    output_width +
+                w_out;
+            value += output_grad_data[output_grad_offset] *
+                     filter_data[filter_offset];
           }
         }
+        filter_offset++;
       }
-      input_grad_data[index] = value;
     }
   }
+  input_grad_data[idx] = value;
 }
 
 template <typename T, bool fuse_relu_before_conv>
@@ -735,7 +737,7 @@ __global__ void KernelDepthwiseConvInputGradSp(
 
   if (c_filter_multiplier == 0 || c_filter == -1) {
     if (data_layout != DataLayout::kNHWC) {
-      KernelDepthwiseConvInputGradNCHW<T, fuse_relu_before_conv>(
+      KernelDepthwiseConvInputGradNCHW<T, c_filter, fuse_relu_before_conv>(
           input_data,
          output_grad_data,
          filter_data,
@@ -1247,8 +1249,7 @@ class DepthwiseConvFunctor<phi::GPUContext, T, fuse_relu_before_conv> {
                 batch_size);
     }
     int filter_multiplier = output_channels / input_channels;
-    int nums_output =
-        batch_size * output_channels * output_height * output_width;
+    int nums_output = output->numel();
 #ifdef __HIPCC__
     int block_size = 256;
 #else
@@ -1421,6 +1422,13 @@ class DepthwiseConvInputGradFunctor<phi::GPUContext, T, fuse_relu_before_conv> {
                 batch_size);
     }
     int filter_multiplier = output_channels / input_channels;
+    int nums_input = input_grad->numel();
+#ifdef __HIPCC__
+    int block_size = 256;
+#else
+    int block_size = 512;
+#endif
+    int grid_size = (nums_input + block_size - 1) / block_size;
 
 #define check_case(c_filter_multiplier, c_stride, c_filter)                  \
   if (c_filter_multiplier == 0 ||                                            \
@@ -1429,6 +1437,11 @@
       (ksize_height == ksize_width && ksize_height == c_filter ||            \
        c_filter == -1)) {                                                    \
     if (data_layout != DataLayout::kNHWC) {                                  \
+      if (c_filter == -1) {                                                  \
+        threads.x = block_size;                                              \
+        grid.x = grid_size;                                                  \
+        threads.y = threads.z = grid.y = grid.z = 1;                         \
+      }                                                                      \
       KernelDepthwiseConvInputGradSp
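
Note for reviewers, summarizing the rewrite: the old NCHW input-grad kernel launched one block per (batch, input channel) pair and had each thread stride over the (h_in, w_in) plane. The new kernel assigns one thread per input element: each thread decodes its flat index back into (batch, c_in, h_in, w_in), returns early under the fused-ReLU mask, and gathers the matching output_grad * filter products. The added c_filter template parameter makes the filter-loop bounds compile-time constants in the specialized instantiations, so the #pragma unroll directives can fully unroll; c_filter == -1 stays the generic path, which is why the launch code switches threads/grid to the flat block_size/grid_size configuration only in that case. Element counts are now taken from numel() rather than recomputed from the four extents. The test h_out - h_out / stride_height * stride_height == 0 is identical to h_out % stride_height == 0 (C++ defines a % b as a - (a / b) * b), so the meaning is unchanged while the quotient can be reused for the following h_out /= stride_height.

The index decomposition is the subtle part, so below is a minimal standalone sketch (plain C++, no Paddle dependencies; the extents are made-up example values, not from the patch) that checks the kernel's arithmetic against the canonical NCHW loop order:

  // Verifies the divide-and-multiply-subtract decomposition used by the
  // rewritten KernelDepthwiseConvInputGradNCHW.
  #include <cassert>
  #include <cstdio>

  int main() {
    const int batch_size = 2;
    const int input_channels = 3;
    const int input_height = 5;
    const int input_width = 7;

    int idx = 0;
    for (int b = 0; b < batch_size; ++b) {
      for (int c = 0; c < input_channels; ++c) {
        for (int h = 0; h < input_height; ++h) {
          for (int w = 0; w < input_width; ++w, ++idx) {
            // Same arithmetic as the kernel: peel off the fastest-varying
            // dimension with one divide, then recover the remainder with a
            // multiply-subtract so the quotient is reused.
            int tmp_1 = idx / input_width;
            const int w_in = idx - tmp_1 * input_width;
            int tmp_2 = tmp_1 / input_height;
            const int h_in = tmp_1 - tmp_2 * input_height;
            tmp_1 = tmp_2;
            tmp_2 = tmp_1 / input_channels;
            const int c_in = tmp_1 - tmp_2 * input_channels;
            const int batch = tmp_2;

            // Must match the canonical NCHW loop coordinates.
            assert(batch == b && c_in == c && h_in == h && w_in == w);
          }
        }
      }
    }
    std::printf("all %d linear indices decode correctly\n", idx);
    return 0;
  }

One thread per element also means a single 1-D launch shape covers any spatial size, whereas the old 2-D block shape could leave threads idle whenever input_width or input_height was smaller than the block dimensions.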