From 0e563da6298b5f6397fbe3ab235515e5642cb0a1 Mon Sep 17 00:00:00 2001 From: 5u13 <39851894+5u13@users.noreply.github.com> Date: Tue, 20 Sep 2022 21:15:32 +0800 Subject: [PATCH] optimization of max_pool3d grad (#45934) --- paddle/phi/kernels/funcs/pooling.cu | 145 ++++++++++++---------------- 1 file changed, 61 insertions(+), 84 deletions(-) diff --git a/paddle/phi/kernels/funcs/pooling.cu b/paddle/phi/kernels/funcs/pooling.cu index b7b5dbd5b0d..875fa92002a 100644 --- a/paddle/phi/kernels/funcs/pooling.cu +++ b/paddle/phi/kernels/funcs/pooling.cu @@ -2319,87 +2319,52 @@ __global__ void KernelMaxPool3DWithIdx(const int ncd, } template -__global__ void KernelMaxPool3DWithIdxGrad(const int nthreads, - const T1* output_grad, - const T2* mask, - const int channels, - const int input_depth, - const int input_height, - const int input_width, - const int output_depth, - const int output_height, - const int output_width, - const int ksize_depth, - const int ksize_height, - const int ksize_width, - const int stride_depth, - const int stride_height, - const int stride_width, - const int padding_depth, - const int padding_height, - const int padding_width, - bool adaptive, - T1* input_grad) { - for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; - index += blockDim.x * gridDim.x) { - int w_offset = index % input_width; - int h_offset = (index / input_width) % input_height; - int d_offset = (index / input_width / input_height) % input_depth; - int c_offset = - (index / input_width / input_height / input_depth) % channels; - int batch_idx = index / input_width / input_height / input_depth / channels; - - int pdstart, pdend; - int phstart, phend; - int pwstart, pwend; - if (adaptive) { - pdstart = d_offset * output_depth / input_depth; - pdend = - min((d_offset + 1) * output_depth / input_depth + 1, output_depth); - phstart = h_offset * output_height / input_height; - phend = - min((h_offset + 1) * output_height / input_height + 1, output_height); - pwstart = w_offset * output_width / input_width; - pwend = - min((w_offset + 1) * output_width / input_width + 1, output_width); - } else { - pdstart = - (d_offset + padding_depth < ksize_depth) - ? 0 - : (d_offset + padding_depth - ksize_depth) / stride_depth + 1; - phstart = - (h_offset + padding_height < ksize_height) - ? 0 - : (h_offset + padding_height - ksize_height) / stride_height + 1; - pwstart = - (w_offset + padding_width < ksize_width) - ? 0 - : (w_offset + padding_width - ksize_width) / stride_width + 1; - pdend = min((d_offset + padding_depth) / stride_depth + 1, output_depth); - phend = - min((h_offset + padding_height) / stride_height + 1, output_height); - pwend = min((w_offset + padding_width) / stride_width + 1, output_width); - } +__global__ void KernelMaxPool3DWithIdxGrad( + const int ncd, + const T1* output_grad, + const T2* mask, + const int channels, + const int input_depth, + const int input_height, + const int input_width, + const int output_depth, + const int output_height, + const int output_width, + const int ksize_depth, + const int ksize_height, + const int ksize_width, + const int stride_depth, + const int stride_height, + const int stride_width, + const int padding_depth, + const int padding_height, + const int padding_width, + bool adaptive, + T1* input_grad, + FastDivModForPooling3D divmods_output) { + int w_offset, h_offset, d_offset, nc_offset; - T1 input_grad_data = 0; - int input_current_feature_map_idx = - (d_offset * input_height + h_offset) * input_width + w_offset; - int output_idx = (batch_idx * channels + c_offset) * output_depth * - output_height * output_width; - mask += output_idx; - output_grad += output_idx; + w_offset = blockIdx.x * blockDim.x + threadIdx.x; + h_offset = blockIdx.y * blockDim.y + threadIdx.y; - for (int pd = pdstart; pd < pdend; ++pd) { - for (int ph = phstart; ph < phend; ++ph) { - for (int pw = pwstart; pw < pwend; ++pw) { - if (mask[(pd * output_height + ph) * output_width + pw] == - input_current_feature_map_idx) - input_grad_data += - output_grad[(pd * output_height + ph) * output_width + pw]; - } + if (w_offset < output_width && h_offset < output_height) { + for (int index_z = blockIdx.z * blockDim.z + threadIdx.z; index_z < ncd; + index_z += gridDim.z * blockDim.z) { + auto output_depth_divmod = divmods_output.depth.Divmod(index_z); + d_offset = output_depth_divmod.val[1]; + nc_offset = output_depth_divmod.val[0]; + int output_index = + nc_offset * output_depth * output_height * output_width + + d_offset * output_height * output_width + h_offset * output_width + + w_offset; + int max_index = mask[output_index]; + if (max_index != -1) { + paddle::platform::CudaAtomicAdd( + &input_grad[nc_offset * input_depth * input_height * input_width + + max_index], + output_grad[output_index]); } } - input_grad[index] = input_grad_data; } } @@ -2523,14 +2488,25 @@ class MaxPool3dWithIndexGradFunctor { const T2* mask_data = mask.data(); T1* input_grad_data = context.template Alloc(input_grad); - int nthreads = - batch_size * input_channels * input_depth * input_height * input_width; - int blocks = (nthreads + 1024 - 1) / 1024; - dim3 threads(1024, 1); - dim3 grid(blocks, 1); + int ncd = batch_size * input_channels * output_depth; + + int thread_x = 32; + int thread_y = 8; + int thread_z = 1; + dim3 threads(thread_x, thread_y, thread_z); + std::array max_grid_dim = context.GetCUDAMaxGridDimSize(); + int block_x = (output_width + threads.x - 1) / threads.x; + int block_y = (output_height + threads.y - 1) / threads.y; + int block_z = (ncd > max_grid_dim[2] * threads.z) + ? max_grid_dim[2] + : (ncd + threads.z - 1) / threads.z; + dim3 grid(block_x, block_y, block_z); + + auto pool_divmods_output = FastDivModForPooling3D( + input_channels, output_width, output_height, output_depth); KernelMaxPool3DWithIdxGrad - <<>>(nthreads, + <<>>(ncd, output_grad_data, mask_data, input_channels, @@ -2550,7 +2526,8 @@ class MaxPool3dWithIndexGradFunctor { padding_height, padding_width, adaptive, - input_grad_data); + input_grad_data, + pool_divmods_output); } }; -- GitLab