From 6d06786078cdc8469bd15cff80994051dc7bf718 Mon Sep 17 00:00:00 2001
From: Ouyang Chao
Date: Tue, 20 Sep 2022 21:13:13 +0800
Subject: [PATCH] [PFCC operator performance optimization] Optimize the
 performance of adaptive_pooling_op for Paddle (#45959)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* optimize adaptive_pooling_op (forward)

* fix bug of AdaptiveKernelMaxPool2dWithIdx

* fix bug of AdaptiveKernelPool2D
---
 paddle/phi/kernels/funcs/pooling.cu | 465 ++++++++++++++++++++--------
 paddle/phi/kernels/funcs/pooling.h  |   4 +-
 2 files changed, 342 insertions(+), 127 deletions(-)

diff --git a/paddle/phi/kernels/funcs/pooling.cu b/paddle/phi/kernels/funcs/pooling.cu
index d8cc11e02e..b7b5dbd5b0 100644
--- a/paddle/phi/kernels/funcs/pooling.cu
+++ b/paddle/phi/kernels/funcs/pooling.cu
@@ -19,6 +19,7 @@ limitations under the License. */
 #include "paddle/fluid/platform/fast_divmod.h"
 #include "paddle/phi/backends/gpu/gpu_launch_config.h"
 #include "paddle/phi/kernels/funcs/pooling.h"
+#include "paddle/phi/kernels/funcs/reduce_function.h"
 
 namespace phi {
 namespace funcs {
@@ -134,7 +135,6 @@ __global__ void KernelPool2D(const int nthreads,
                              FastDivModForPooling divmods,
                              PoolProcess pool_process,
                              bool exclusive,
-                             bool adaptive,
                              T* output_data,
                              bool channel_last = false) {
   for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
@@ -154,19 +154,12 @@ __global__ void KernelPool2D(const int nthreads,
                                                          &input_offset);
     input_data += input_offset;
 
-    if (adaptive) {
-      hstart = AdaptStartIndex(h_offset, input_height, output_height);
-      hend = AdaptEndIndex(h_offset, input_height, output_height);
-      wstart = AdaptStartIndex(w_offset, input_width, output_width);
-      wend = AdaptEndIndex(w_offset, input_width, output_width);
-    } else {
-      hstart = h_offset * stride_height - padding_height;
-      hend = min(hstart + ksize_height, input_height);
-      hstart = max(hstart, 0);
-      wstart = w_offset * stride_width - padding_width;
-      wend = min(wstart + ksize_width, input_width);
-      wstart = max(wstart, 0);
-    }
+    hstart = h_offset * stride_height - padding_height;
+    hend = min(hstart + ksize_height, input_height);
+    hstart = max(hstart, 0);
+    wstart = w_offset * stride_width - padding_width;
+    wend = min(wstart + ksize_width, input_width);
+    wstart = max(wstart, 0);
 
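The hunks above strip the adaptive branch out of the generic KernelPool2D, so the common non-adaptive path no longer carries that divergent branch per output element; adaptive requests are routed to the dedicated kernel added below. The window math the adaptive path relies on, AdaptStartIndex and AdaptEndIndex (touched by the pooling.h hunk at the end of this patch), splits the input extent into output_size near-equal, possibly overlapping windows. A standalone sketch of that math with a worked 5-to-3 example; this snippet is illustrative and not part of the patch:

    #include <cmath>
    #include <cstdio>

    // Window i of an adaptive pool covers [floor(i*in/out), ceil((i+1)*in/out)),
    // mirroring AdaptStartIndex/AdaptEndIndex in pooling.h.
    int AdaptStart(int i, int in, int out) {
      return static_cast<int>(std::floor(static_cast<double>(i * in) / out));
    }
    int AdaptEnd(int i, int in, int out) {
      return static_cast<int>(std::ceil(static_cast<double>((i + 1) * in) / out));
    }

    int main() {
      // Pooling 5 input rows down to 3 output rows gives the windows
      // [0,2), [1,4), [3,5): every input row is covered, sizes differ by one.
      for (int i = 0; i < 3; ++i)
        std::printf("window %d: [%d, %d)\n", i, AdaptStart(i, 5, 3), AdaptEnd(i, 5, 3));
      return 0;
    }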
     T ele = pool_process.initial();
     for (int h = hstart; h < hend; ++h) {
@@ -177,13 +170,74 @@ __global__ void KernelPool2D(const int nthreads,
         pool_process.compute(input_data[input_idx], &ele);
       }
     }
-    int pool_size = (exclusive || adaptive) ? (hend - hstart) * (wend - wstart)
-                                            : ksize_height * ksize_width;
+    int pool_size = exclusive ? (hend - hstart) * (wend - wstart)
+                              : ksize_height * ksize_width;
     pool_process.finalize(static_cast<T>(pool_size), &ele);
     output_data[index] = ele;
   }
 }
 
+template <typename PoolProcess, typename T>
+__global__ void AdaptiveKernelPool2D(const int nthreads,
+                                     const T* input_data,
+                                     const int channels,
+                                     const int input_height,
+                                     const int input_width,
+                                     const int output_height,
+                                     const int output_width,
+                                     const int ksize_height,
+                                     const int ksize_width,
+                                     const int stride_height,
+                                     const int stride_width,
+                                     const int padding_height,
+                                     const int padding_width,
+                                     FastDivModForPooling divmods,
+                                     PoolProcess pool_process,
+                                     bool exclusive,
+                                     T* output_data,
+                                     bool channel_last = false) {
+  const int n_offset = blockIdx.y;
+  const int c_offset = blockIdx.x * blockDim.y + threadIdx.y;
+  if (c_offset >= channels) {
+    return;
+  }
+  int hstart, hend, wstart, wend;
+  int input_offset =
+      channel_last
+          ? n_offset * input_height * input_width * channels
+          : (n_offset * channels + c_offset) * input_height * input_width;
+  int output_offset =
+      channel_last
+          ? n_offset * output_height * output_width * channels
+          : (n_offset * channels + c_offset) * output_height * output_width;
+  for (int hw_offset = threadIdx.x; hw_offset < output_height * output_width;
+       hw_offset += blockDim.x) {
+    int w_offset = hw_offset % output_width;
+    int h_offset = hw_offset / output_width;
+    hstart = AdaptStartIndex(h_offset, input_height, output_height);
+    hend = AdaptEndIndex(h_offset, input_height, output_height);
+    wstart = AdaptStartIndex(w_offset, input_width, output_width);
+    wend = AdaptEndIndex(w_offset, input_width, output_width);
+
+    T ele = pool_process.initial();
+    for (int h = hstart; h < hend; ++h) {
+      for (int w = wstart; w < wend; ++w) {
+        auto input_idx = channel_last
+                             ? (h * input_width + w) * channels + c_offset
+                             : h * input_width + w;
+        pool_process.compute(input_data[input_offset + input_idx], &ele);
+      }
+    }
+    int pool_size = (hend - hstart) * (wend - wstart);
+    pool_process.finalize(static_cast<T>(pool_size), &ele);
+    int output_idx =
+        channel_last
+            ? (h_offset * output_width + w_offset) * channels + c_offset
+            : h_offset * output_width + w_offset;
+    output_data[output_offset + output_idx] = ele;
+  }
+}
+
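AdaptiveKernelPool2D changes the work decomposition: gridDim.y indexes the batch, gridDim.x times blockDim.y covers channels, and the blockDim.x threads of a block stride over the output pixels of one channel, so the reads of one adaptive window stay within one thread. The functors below size the launch with at most 512 threads per block. A host-side sketch of that arithmetic; GetLastPow2 here is a local stand-in for phi::funcs::details::GetLastPow2 (made available by the reduce_function.h include added above), which is assumed to return the largest power of two not exceeding its argument:

    #include <algorithm>
    #include <cstdio>

    // Stand-in for phi::funcs::details::GetLastPow2 (assumed semantics).
    int GetLastPow2(int n) {
      int p = 1;
      while (p * 2 <= n) p *= 2;
      return p;
    }

    int main() {
      int batch_size = 8, channels = 64, out_h = 7, out_w = 7;  // example shape
      int max_threads = 512;
      int thread_num = std::min(GetLastPow2(out_h * out_w), max_threads);  // 32
      int blocks = std::min(max_threads / thread_num, channels);           // 16
      int grid_x = std::max((channels + blocks - 1) / blocks, 1);          // 4
      // block = (32, 16, 1): 16 channels per block, 32 threads per channel.
      // grid  = (4, batch_size, 1): 4 channel groups, one grid row per sample.
      std::printf("block=(%d,%d,1) grid=(%d,%d,1)\n",
                  thread_num, blocks, grid_x, batch_size);
      return 0;
    }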
 template <typename T, typename PoolProcess>
 __global__ void KernelPool2DGrad(const int nthreads,
                                  const T* __restrict__ input_data,
@@ -408,35 +462,62 @@ void Pool2dDirectCUDAFunctor<PoolProcess, T>::operator()(
   const int padding_width = paddings[1];
 
   int nthreads = batch_size * output_channels * output_height * output_width;
-  int thread_num = 1024;
-#ifdef WITH_NV_JETSON
-  // backends::gpu::ChangeThreadNum(context, &thread_num);
-  thread_num = 512;
-#endif
-  int blocks = (nthreads + thread_num - 1) / thread_num;
-  dim3 threads(thread_num, 1);
-  dim3 grid(blocks, 1);
-
   auto pool_divmods =
       FastDivModForPooling(input_channels, output_width, output_height);
-  KernelPool2D<PoolProcess, T><<<grid, threads, 0, stream>>>(nthreads,
-                                                             input,
-                                                             input_channels,
-                                                             input_height,
-                                                             input_width,
-                                                             output_height,
-                                                             output_width,
-                                                             ksize_height,
-                                                             ksize_width,
-                                                             stride_height,
-                                                             stride_width,
-                                                             padding_height,
-                                                             padding_width,
-                                                             pool_divmods,
-                                                             pool_compute,
-                                                             exclusive,
-                                                             adaptive,
-                                                             output);
+  if (adaptive) {
+    int max_threads = 512;
+    int thread_num =
+        std::min(phi::funcs::details::GetLastPow2(output_height * output_width),
+                 max_threads);
+    int blocks = std::min(max_threads / thread_num, output_channels);
+    dim3 threads(thread_num, blocks, 1);
+    dim3 grid(
+        std::max((output_channels + blocks - 1) / blocks, 1), batch_size, 1);
+    AdaptiveKernelPool2D<PoolProcess, T>
+        <<<grid, threads, 0, stream>>>(nthreads,
+                                       input,
+                                       input_channels,
+                                       input_height,
+                                       input_width,
+                                       output_height,
+                                       output_width,
+                                       ksize_height,
+                                       ksize_width,
+                                       stride_height,
+                                       stride_width,
+                                       padding_height,
+                                       padding_width,
+                                       pool_divmods,
+                                       pool_compute,
+                                       exclusive,
+                                       output);
+  } else {
+    int thread_num = 1024;
+#ifdef WITH_NV_JETSON
+    // backends::gpu::ChangeThreadNum(context, &thread_num);
+    thread_num = 512;
+#endif
+    int blocks = (nthreads + thread_num - 1) / thread_num;
+    dim3 threads(thread_num, 1);
+    dim3 grid(blocks, 1);
+    KernelPool2D<PoolProcess, T><<<grid, threads, 0, stream>>>(nthreads,
+                                                               input,
+                                                               input_channels,
+                                                               input_height,
+                                                               input_width,
+                                                               output_height,
+                                                               output_width,
+                                                               ksize_height,
+                                                               ksize_width,
+                                                               stride_height,
+                                                               stride_width,
+                                                               padding_height,
+                                                               padding_width,
+                                                               pool_divmods,
+                                                               pool_compute,
+                                                               exclusive,
+                                                               output);
+  }
 }
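The PoolProcess object threaded through these launches decides what the kernels accumulate; the calls above imply a three-step protocol (initial, compute, finalize). A minimal sketch of an average-pooling process matching that protocol; the struct and its names are illustrative, Paddle's actual AvgPool/MaxPool functors live in pooling.h:

    // Sketch of the PoolProcess concept used by KernelPool2D and
    // AdaptiveKernelPool2D (illustrative, not the Paddle implementation).
    template <typename T>
    struct AvgProcessSketch {
      // Seed the accumulator before scanning the window.
      __device__ T initial() { return static_cast<T>(0); }
      // Fold one input element into the accumulator.
      __device__ void compute(const T& x, T* y) { *y += x; }
      // Normalize by the pool size the kernel passes in: the true window
      // area for the adaptive/exclusive case, ksize_h * ksize_w otherwise.
      __device__ void finalize(const T& pool_size, T* y) { *y /= pool_size; }
    };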
 
 /*
@@ -476,35 +557,62 @@ class Pool2dFunctor<phi::GPUContext, PoolProcess, T> {
     T* output_data = context.template Alloc<T>(output);
 
     int nthreads = batch_size * output_channels * output_height * output_width;
-    int thread_num = 1024;
-#ifdef WITH_NV_JETSON
-    backends::gpu::ChangeThreadNum(context, &thread_num);
-#endif
-    int blocks = (nthreads + thread_num - 1) / thread_num;
-    dim3 threads(thread_num, 1);
-    dim3 grid(blocks, 1);
-
     auto pool_divmods =
         FastDivModForPooling(input_channels, output_width, output_height);
-    KernelPool2D<PoolProcess, T>
-        <<<grid, threads, 0, context.stream()>>>(nthreads,
-                                                 input_data,
-                                                 input_channels,
-                                                 input_height,
-                                                 input_width,
-                                                 output_height,
-                                                 output_width,
-                                                 ksize_height,
-                                                 ksize_width,
-                                                 stride_height,
-                                                 stride_width,
-                                                 padding_height,
-                                                 padding_width,
-                                                 pool_divmods,
-                                                 pool_process,
-                                                 exclusive,
-                                                 adaptive,
-                                                 output_data);
+    if (adaptive) {
+      int max_threads = 512;
+      int thread_num = std::min(
+          phi::funcs::details::GetLastPow2(output_height * output_width),
+          max_threads);
+      int blocks = std::min(max_threads / thread_num, output_channels);
+      dim3 threads(thread_num, blocks, 1);
+      dim3 grid(
+          std::max((output_channels + blocks - 1) / blocks, 1), batch_size, 1);
+      AdaptiveKernelPool2D<PoolProcess, T>
+          <<<grid, threads, 0, context.stream()>>>(nthreads,
+                                                   input_data,
+                                                   input_channels,
+                                                   input_height,
+                                                   input_width,
+                                                   output_height,
+                                                   output_width,
+                                                   ksize_height,
+                                                   ksize_width,
+                                                   stride_height,
+                                                   stride_width,
+                                                   padding_height,
+                                                   padding_width,
+                                                   pool_divmods,
+                                                   pool_process,
+                                                   exclusive,
+                                                   output_data);
+    } else {
+      int thread_num = 1024;
+#ifdef WITH_NV_JETSON
+      backends::gpu::ChangeThreadNum(context, &thread_num);
+#endif
+      int blocks = (nthreads + thread_num - 1) / thread_num;
+      dim3 threads(thread_num, 1);
+      dim3 grid(blocks, 1);
+      KernelPool2D<PoolProcess, T>
+          <<<grid, threads, 0, context.stream()>>>(nthreads,
+                                                   input_data,
+                                                   input_channels,
+                                                   input_height,
+                                                   input_width,
+                                                   output_height,
+                                                   output_width,
+                                                   ksize_height,
+                                                   ksize_width,
+                                                   stride_height,
+                                                   stride_width,
+                                                   padding_height,
+                                                   padding_width,
+                                                   pool_divmods,
+                                                   pool_process,
+                                                   exclusive,
+                                                   output_data);
+    }
   }
   void operator()(const phi::GPUContext& context,
                   const DenseTensor& input,
@@ -543,36 +651,64 @@ class Pool2dFunctor<phi::GPUContext, PoolProcess, T> {
     T* output_data = context.template Alloc<T>(output);
 
     int nthreads = batch_size * output_channels * output_height * output_width;
-    int thread_num = 1024;
-#ifdef WITH_NV_JETSON
-    backends::gpu::ChangeThreadNum(context, &thread_num);
-#endif
-    int blocks = (nthreads + thread_num - 1) / thread_num;
-    dim3 threads(thread_num, 1);
-    dim3 grid(blocks, 1);
-
     auto pool_divmods =
         FastDivModForPooling(input_channels, output_width, output_height);
-    KernelPool2D<PoolProcess, T>
-        <<<grid, threads, 0, context.stream()>>>(nthreads,
-                                                 input_data,
-                                                 input_channels,
-                                                 input_height,
-                                                 input_width,
-                                                 output_height,
-                                                 output_width,
-                                                 ksize_height,
-                                                 ksize_width,
-                                                 stride_height,
-                                                 stride_width,
-                                                 padding_height,
-                                                 padding_width,
-                                                 pool_divmods,
-                                                 pool_process,
-                                                 exclusive,
-                                                 adaptive,
-                                                 output_data,
-                                                 channel_last);
+    if (adaptive) {
+      int max_threads = 512;
+      int thread_num = std::min(
+          phi::funcs::details::GetLastPow2(output_height * output_width),
+          max_threads);
+      int blocks = std::min(max_threads / thread_num, output_channels);
+      dim3 threads(thread_num, blocks, 1);
+      dim3 grid(
+          std::max((output_channels + blocks - 1) / blocks, 1), batch_size, 1);
+      AdaptiveKernelPool2D<PoolProcess, T>
+          <<<grid, threads, 0, context.stream()>>>(nthreads,
+                                                   input_data,
+                                                   input_channels,
+                                                   input_height,
+                                                   input_width,
+                                                   output_height,
+                                                   output_width,
+                                                   ksize_height,
+                                                   ksize_width,
+                                                   stride_height,
+                                                   stride_width,
+                                                   padding_height,
+                                                   padding_width,
+                                                   pool_divmods,
+                                                   pool_process,
+                                                   exclusive,
+                                                   output_data,
+                                                   channel_last);
+    } else {
+      int thread_num = 1024;
+#ifdef WITH_NV_JETSON
+      backends::gpu::ChangeThreadNum(context, &thread_num);
+#endif
+      int blocks = (nthreads + thread_num - 1) / thread_num;
+      dim3 threads(thread_num, 1);
+      dim3 grid(blocks, 1);
+      KernelPool2D<PoolProcess, T>
+          <<<grid, threads, 0, context.stream()>>>(nthreads,
+                                                   input_data,
+                                                   input_channels,
+                                                   input_height,
+                                                   input_width,
+                                                   output_height,
+                                                   output_width,
+                                                   ksize_height,
+                                                   ksize_width,
+                                                   stride_height,
+                                                   stride_width,
+                                                   padding_height,
+                                                   padding_width,
+                                                   pool_divmods,
+                                                   pool_process,
+                                                   exclusive,
+                                                   output_data,
+                                                   channel_last);
+    }
   }
 };
 
 /*
@@ -1818,6 +1954,59 @@ __global__ void KernelMaxPool2dWithIdx(const int nthreads,
   }
 }
 
+template <typename T1, typename T2>
+__global__ void AdaptiveKernelMaxPool2dWithIdx(const int nthreads,
+                                               const T1* input_data,
+                                               const int channels,
+                                               const int input_height,
+                                               const int input_width,
+                                               const int output_height,
+                                               const int output_width,
+                                               const int ksize_height,
+                                               const int ksize_width,
+                                               const int stride_height,
+                                               const int stride_width,
+                                               const int padding_height,
+                                               const int padding_width,
+                                               T1* output_data,
+                                               T2* mask_data,
+                                               FastDivModForPooling divmods) {
+  const int n_offset = blockIdx.y;
+  const int c_offset = blockIdx.x * blockDim.y + threadIdx.y;
+  if (c_offset >= channels) {
+    return;
+  }
+  int hstart, hend, wstart, wend;
+  int input_offset =
+      (n_offset * channels + c_offset) * input_height * input_width;
+  int output_offset =
+      (n_offset * channels + c_offset) * output_height * output_width;
+  for (int hw_offset = threadIdx.x; hw_offset < output_height * output_width;
+       hw_offset += blockDim.x) {
+    int w_offset = hw_offset % output_width;
+    int h_offset = hw_offset / output_width;
+    hstart = AdaptStartIndex(h_offset, input_height, output_height);
+    hend = AdaptEndIndex(h_offset, input_height, output_height);
+    wstart = AdaptStartIndex(w_offset, input_width, output_width);
+    wend = AdaptEndIndex(w_offset, input_width, output_width);
+
+    T1 ele = -FLT_MAX;
+    int max_index = -1;
+    for (int h = hstart; h < hend; ++h) {
+      for (int w = wstart; w < wend; ++w) {
+        int input_index = h * input_width + w;
+        if (ele < input_data[input_offset + input_index]) {
+          max_index = input_index;
+          ele = input_data[input_offset + input_index];
+        }
+      }
+    }
+    int output_idx = output_offset + h_offset * output_width + w_offset;
+    output_data[output_idx] = ele;
+    mask_data[output_idx] = max_index;
+  }
+}
+
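Both Pool2dFunctor overloads above emit the same launches; only the flattened memory offsets differ between NCHW and channel-last (NHWC) tensors, as in the channel_last ternaries inside the kernels (note the new max-pool kernel handles NCHW only). A host-side sketch of the two index formulas; the helper names are hypothetical:

    #include <cassert>

    // Flat offset of element (n, c, h, w) in an NCHW tensor.
    int OffsetNCHW(int n, int c, int h, int w, int C, int H, int W) {
      return ((n * C + c) * H + h) * W + w;
    }

    // The same element when stored channel-last (NHWC).
    int OffsetNHWC(int n, int c, int h, int w, int C, int H, int W) {
      return ((n * H + h) * W + w) * C + c;
    }

    int main() {
      int C = 3, H = 4, W = 5;
      // Mirrors the kernels: a per-(n, c) base offset plus a per-pixel index,
      // h * W + w for NCHW and (h * W + w) * C + c for NHWC.
      assert(OffsetNCHW(1, 2, 3, 4, C, H, W) == (1 * C + 2) * H * W + 3 * W + 4);
      assert(OffsetNHWC(1, 2, 3, 4, C, H, W) == 1 * H * W * C + (3 * W + 4) * C + 2);
      return 0;
    }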
 template <typename T1, typename T2>
 __global__ void KernelMaxPool2DWithIdxGrad(const int nthreads,
                                            const T1* output_grad,
@@ -1922,35 +2111,61 @@ class MaxPool2dWithIndexFunctor<phi::GPUContext, T1, T2> {
     T2* mask_data = context.template Alloc<T2>(mask);
 
     int nthreads = batch_size * output_channels * output_height * output_width;
-    int thread_num = 1024;
-#ifdef WITH_NV_JETSON
-    backends::gpu::ChangeThreadNum(context, &thread_num);
-#endif
-
-    int blocks = (nthreads + thread_num - 1) / thread_num;
-    dim3 threads(thread_num, 1);
-    dim3 grid(blocks, 1);
-
     auto pool_divmods =
         FastDivModForPooling(input_channels, output_width, output_height);
-    KernelMaxPool2dWithIdx<T1, T2>
-        <<<grid, threads, 0, context.stream()>>>(nthreads,
-                                                 input_data,
-                                                 input_channels,
-                                                 input_height,
-                                                 input_width,
-                                                 output_height,
-                                                 output_width,
-                                                 ksize_height,
-                                                 ksize_width,
-                                                 stride_height,
-                                                 stride_width,
-                                                 padding_height,
-                                                 padding_width,
-                                                 adaptive,
-                                                 output_data,
-                                                 mask_data,
-                                                 pool_divmods);
+    if (adaptive && output_height > 1 && output_width > 1) {
+      int max_threads = 512;
+      int thread_num = std::min(
+          phi::funcs::details::GetLastPow2(output_height * output_width),
+          max_threads);
+      int blocks = std::min(max_threads / thread_num, output_channels);
+      dim3 threads(thread_num, blocks, 1);
+      dim3 grid(
+          std::max((output_channels + blocks - 1) / blocks, 1), batch_size, 1);
+      AdaptiveKernelMaxPool2dWithIdx<T1, T2>
+          <<<grid, threads, 0, context.stream()>>>(nthreads,
+                                                   input_data,
+                                                   input_channels,
+                                                   input_height,
+                                                   input_width,
+                                                   output_height,
+                                                   output_width,
+                                                   ksize_height,
+                                                   ksize_width,
+                                                   stride_height,
+                                                   stride_width,
+                                                   padding_height,
+                                                   padding_width,
+                                                   output_data,
+                                                   mask_data,
+                                                   pool_divmods);
+    } else {
+      int thread_num = 1024;
+#ifdef WITH_NV_JETSON
+      backends::gpu::ChangeThreadNum(context, &thread_num);
+#endif
+      int blocks = (nthreads + thread_num - 1) / thread_num;
+      dim3 threads(thread_num, 1);
+      dim3 grid(blocks, 1);
+      KernelMaxPool2dWithIdx<T1, T2>
+          <<<grid, threads, 0, context.stream()>>>(nthreads,
+                                                   input_data,
+                                                   input_channels,
+                                                   input_height,
+                                                   input_width,
+                                                   output_height,
+                                                   output_width,
+                                                   ksize_height,
+                                                   ksize_width,
+                                                   stride_height,
+                                                   stride_width,
+                                                   padding_height,
+                                                   padding_width,
+                                                   adaptive,
+                                                   output_data,
+                                                   mask_data,
+                                                   pool_divmods);
+    }
   }
 };
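Note the guard in the hunk above: the specialized max-pool kernel is taken only when the adaptive output is larger than 1x1, and degenerate shapes fall back to the generic kernel; reading the "fix bug of AdaptiveKernelMaxPool2dWithIdx" line in the commit message, this guard appears to be that fix, though the patch does not say so explicitly. Inside the specialized kernels each thread walks its slice with a block-stride loop; a generic CUDA sketch of that pattern (kernel and names illustrative):

    // Block-stride loop as used by AdaptiveKernelPool2D and
    // AdaptiveKernelMaxPool2dWithIdx: the blockDim.x threads of one block
    // cooperatively cover all `total` output pixels of one slice.
    __global__ void BlockStrideCopySketch(const float* in, float* out, int total) {
      int slice = blockIdx.x * total;  // one block per slice
      for (int i = threadIdx.x; i < total; i += blockDim.x) {
        out[slice + i] = in[slice + i];  // stand-in for the per-pixel pooling work
      }
    }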
diff --git a/paddle/phi/kernels/funcs/pooling.h b/paddle/phi/kernels/funcs/pooling.h
index 0eebfc8568..1d1eacd0d5 100644
--- a/paddle/phi/kernels/funcs/pooling.h
+++ b/paddle/phi/kernels/funcs/pooling.h
@@ -92,12 +92,12 @@ class AvgPoolGrad {
  */
 HOSTDEVICE inline int AdaptStartIndex(int ph, int input_size, int output_size) {
   return static_cast<int>(
-      floor(static_cast<float>(ph * input_size) / output_size));
+      floor(static_cast<double>(ph * input_size) / output_size));
 }
 
 HOSTDEVICE inline int AdaptEndIndex(int ph, int input_size, int output_size) {
   return static_cast<int>(
-      ceil(static_cast<float>((ph + 1) * input_size) / output_size));
+      ceil(static_cast<double>((ph + 1) * input_size) / output_size));
 }
 
 /*
-- 
GitLab
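The pooling.h change widens the intermediate cast in the adaptive index helpers. A float carries a 24-bit mantissa, so once ph * input_size passes 2^24 the quotient can land one row off and shift a window boundary, while double keeps the product exact for any realistic tensor extent. A standalone check of that failure mode (illustrative, not part of the patch):

    #include <cmath>
    #include <cstdio>

    int main() {
      // 2^24 + 1 is not representable as a float, so the float path can
      // misplace an adaptive window boundary for very large ph * input_size.
      int prod = (1 << 24) + 1;  // stand-in for ph * input_size
      int out_size = 1;
      int f = static_cast<int>(std::floor(static_cast<float>(prod) / out_size));
      int d = static_cast<int>(std::floor(static_cast<double>(prod) / out_size));
      std::printf("float: %d  double: %d\n", f, d);  // 16777216 vs 16777217
      return 0;
    }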