未验证 提交 6d067860 编写于 作者: O Ouyang Chao 提交者: GitHub

【PFCC算子性能优化】为Paddle优化adaptive_pooling_op性能 (#45959)

* optimize adaptive_pooling_op (forward)

* fix bug of AdaptiveKernelMaxPool2dWithIdx

* fix bug of AdaptiveKernelPool2D
上级 3ad6994d
...@@ -19,6 +19,7 @@ limitations under the License. */ ...@@ -19,6 +19,7 @@ limitations under the License. */
#include "paddle/fluid/platform/fast_divmod.h" #include "paddle/fluid/platform/fast_divmod.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h" #include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/kernels/funcs/pooling.h" #include "paddle/phi/kernels/funcs/pooling.h"
#include "paddle/phi/kernels/funcs/reduce_function.h"
namespace phi { namespace phi {
namespace funcs { namespace funcs {
...@@ -134,7 +135,6 @@ __global__ void KernelPool2D(const int nthreads, ...@@ -134,7 +135,6 @@ __global__ void KernelPool2D(const int nthreads,
FastDivModForPooling divmods, FastDivModForPooling divmods,
PoolProcess pool_process, PoolProcess pool_process,
bool exclusive, bool exclusive,
bool adaptive,
T* output_data, T* output_data,
bool channel_last = false) { bool channel_last = false) {
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
...@@ -154,19 +154,12 @@ __global__ void KernelPool2D(const int nthreads, ...@@ -154,19 +154,12 @@ __global__ void KernelPool2D(const int nthreads,
&input_offset); &input_offset);
input_data += input_offset; input_data += input_offset;
if (adaptive) { hstart = h_offset * stride_height - padding_height;
hstart = AdaptStartIndex(h_offset, input_height, output_height); hend = min(hstart + ksize_height, input_height);
hend = AdaptEndIndex(h_offset, input_height, output_height); hstart = max(hstart, 0);
wstart = AdaptStartIndex(w_offset, input_width, output_width); wstart = w_offset * stride_width - padding_width;
wend = AdaptEndIndex(w_offset, input_width, output_width); wend = min(wstart + ksize_width, input_width);
} else { wstart = max(wstart, 0);
hstart = h_offset * stride_height - padding_height;
hend = min(hstart + ksize_height, input_height);
hstart = max(hstart, 0);
wstart = w_offset * stride_width - padding_width;
wend = min(wstart + ksize_width, input_width);
wstart = max(wstart, 0);
}
T ele = pool_process.initial(); T ele = pool_process.initial();
for (int h = hstart; h < hend; ++h) { for (int h = hstart; h < hend; ++h) {
...@@ -177,13 +170,74 @@ __global__ void KernelPool2D(const int nthreads, ...@@ -177,13 +170,74 @@ __global__ void KernelPool2D(const int nthreads,
pool_process.compute(input_data[input_idx], &ele); pool_process.compute(input_data[input_idx], &ele);
} }
} }
int pool_size = (exclusive || adaptive) ? (hend - hstart) * (wend - wstart) int pool_size = exclusive ? (hend - hstart) * (wend - wstart)
: ksize_height * ksize_width; : ksize_height * ksize_width;
pool_process.finalize(static_cast<T>(pool_size), &ele); pool_process.finalize(static_cast<T>(pool_size), &ele);
output_data[index] = ele; output_data[index] = ele;
} }
} }
template <typename PoolProcess, typename T>
__global__ void AdaptiveKernelPool2D(const int nthreads,
const T* input_data,
const int channels,
const int input_height,
const int input_width,
const int output_height,
const int output_width,
const int ksize_height,
const int ksize_width,
const int stride_height,
const int stride_width,
const int padding_height,
const int padding_width,
FastDivModForPooling divmods,
PoolProcess pool_process,
bool exclusive,
T* output_data,
bool channel_last = false) {
const int n_offset = blockIdx.y;
const int c_offset = blockIdx.x * blockDim.y + threadIdx.y;
if (c_offset >= channels) {
return;
}
int hstart, hend, wstart, wend;
int input_offset =
channel_last
? n_offset * input_height * input_width * channels
: (n_offset * channels + c_offset) * input_height * input_width;
int output_offset =
channel_last
? n_offset * output_height * output_width * channels
: (n_offset * channels + c_offset) * output_height * output_width;
for (int hw_offset = threadIdx.x; hw_offset < output_height * output_width;
hw_offset += blockDim.x) {
int w_offset = hw_offset % output_width;
int h_offset = hw_offset / output_width;
hstart = AdaptStartIndex(h_offset, input_height, output_height);
hend = AdaptEndIndex(h_offset, input_height, output_height);
wstart = AdaptStartIndex(w_offset, input_width, output_width);
wend = AdaptEndIndex(w_offset, input_width, output_width);
T ele = pool_process.initial();
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
auto input_idx = channel_last
? (h * input_width + w) * channels + c_offset
: h * input_width + w;
pool_process.compute(input_data[input_offset + input_idx], &ele);
}
}
int pool_size = (hend - hstart) * (wend - wstart);
pool_process.finalize(static_cast<T>(pool_size), &ele);
int output_idx =
channel_last
? (h_offset * output_width + w_offset) * channels + c_offset
: h_offset * output_width + w_offset;
output_data[output_offset + output_idx] = ele;
}
}
template <typename T, typename PoolProcess> template <typename T, typename PoolProcess>
__global__ void KernelPool2DGrad(const int nthreads, __global__ void KernelPool2DGrad(const int nthreads,
const T* __restrict__ input_data, const T* __restrict__ input_data,
...@@ -408,35 +462,62 @@ void Pool2dDirectCUDAFunctor<PoolProcess, T>::operator()( ...@@ -408,35 +462,62 @@ void Pool2dDirectCUDAFunctor<PoolProcess, T>::operator()(
const int padding_width = paddings[1]; const int padding_width = paddings[1];
int nthreads = batch_size * output_channels * output_height * output_width; int nthreads = batch_size * output_channels * output_height * output_width;
int thread_num = 1024;
#ifdef WITH_NV_JETSON
// backends::gpu::ChangeThreadNum(context, &thread_num);
thread_num = 512;
#endif
int blocks = (nthreads + thread_num - 1) / thread_num;
dim3 threads(thread_num, 1);
dim3 grid(blocks, 1);
auto pool_divmods = auto pool_divmods =
FastDivModForPooling(input_channels, output_width, output_height); FastDivModForPooling(input_channels, output_width, output_height);
KernelPool2D<PoolProcess, T><<<grid, threads, 0, stream>>>(nthreads, if (adaptive) {
input, int max_threads = 512;
input_channels, int thread_num =
input_height, std::min(phi::funcs::details::GetLastPow2(output_height * output_width),
input_width, max_threads);
output_height, int blocks = std::min(max_threads / thread_num, output_channels);
output_width, dim3 threads(thread_num, blocks, 1);
ksize_height, dim3 grid(
ksize_width, std::max((output_channels + blocks - 1) / blocks, 1), batch_size, 1);
stride_height, AdaptiveKernelPool2D<PoolProcess, T>
stride_width, <<<grid, threads, 0, stream>>>(nthreads,
padding_height, input,
padding_width, input_channels,
pool_divmods, input_height,
pool_compute, input_width,
exclusive, output_height,
adaptive, output_width,
output); ksize_height,
ksize_width,
stride_height,
stride_width,
padding_height,
padding_width,
pool_divmods,
pool_compute,
exclusive,
output);
} else {
int thread_num = 1024;
#ifdef WITH_NV_JETSON
// backends::gpu::ChangeThreadNum(context, &thread_num);
thread_num = 512;
#endif
int blocks = (nthreads + thread_num - 1) / thread_num;
dim3 threads(thread_num, 1);
dim3 grid(blocks, 1);
KernelPool2D<PoolProcess, T><<<grid, threads, 0, stream>>>(nthreads,
input,
input_channels,
input_height,
input_width,
output_height,
output_width,
ksize_height,
ksize_width,
stride_height,
stride_width,
padding_height,
padding_width,
pool_divmods,
pool_compute,
exclusive,
output);
}
} }
/* /*
...@@ -476,35 +557,62 @@ class Pool2dFunctor<phi::GPUContext, PoolProcess, T> { ...@@ -476,35 +557,62 @@ class Pool2dFunctor<phi::GPUContext, PoolProcess, T> {
T* output_data = context.template Alloc<T>(output); T* output_data = context.template Alloc<T>(output);
int nthreads = batch_size * output_channels * output_height * output_width; int nthreads = batch_size * output_channels * output_height * output_width;
int thread_num = 1024;
#ifdef WITH_NV_JETSON
backends::gpu::ChangeThreadNum(context, &thread_num);
#endif
int blocks = (nthreads + thread_num - 1) / thread_num;
dim3 threads(thread_num, 1);
dim3 grid(blocks, 1);
auto pool_divmods = auto pool_divmods =
FastDivModForPooling(input_channels, output_width, output_height); FastDivModForPooling(input_channels, output_width, output_height);
KernelPool2D<PoolProcess, T> if (adaptive) {
<<<grid, threads, 0, context.stream()>>>(nthreads, int max_threads = 512;
input_data, int thread_num = std::min(
input_channels, phi::funcs::details::GetLastPow2(output_height * output_width),
input_height, max_threads);
input_width, int blocks = std::min(max_threads / thread_num, output_channels);
output_height, dim3 threads(thread_num, blocks, 1);
output_width, dim3 grid(
ksize_height, std::max((output_channels + blocks - 1) / blocks, 1), batch_size, 1);
ksize_width, AdaptiveKernelPool2D<PoolProcess, T>
stride_height, <<<grid, threads, 0, context.stream()>>>(nthreads,
stride_width, input_data,
padding_height, input_channels,
padding_width, input_height,
pool_divmods, input_width,
pool_process, output_height,
exclusive, output_width,
adaptive, ksize_height,
output_data); ksize_width,
stride_height,
stride_width,
padding_height,
padding_width,
pool_divmods,
pool_process,
exclusive,
output_data);
} else {
int thread_num = 1024;
#ifdef WITH_NV_JETSON
backends::gpu::ChangeThreadNum(context, &thread_num);
#endif
int blocks = (nthreads + thread_num - 1) / thread_num;
dim3 threads(thread_num, 1);
dim3 grid(blocks, 1);
KernelPool2D<PoolProcess, T>
<<<grid, threads, 0, context.stream()>>>(nthreads,
input_data,
input_channels,
input_height,
input_width,
output_height,
output_width,
ksize_height,
ksize_width,
stride_height,
stride_width,
padding_height,
padding_width,
pool_divmods,
pool_process,
exclusive,
output_data);
}
} }
void operator()(const phi::GPUContext& context, void operator()(const phi::GPUContext& context,
const DenseTensor& input, const DenseTensor& input,
...@@ -543,36 +651,64 @@ class Pool2dFunctor<phi::GPUContext, PoolProcess, T> { ...@@ -543,36 +651,64 @@ class Pool2dFunctor<phi::GPUContext, PoolProcess, T> {
T* output_data = context.template Alloc<T>(output); T* output_data = context.template Alloc<T>(output);
int nthreads = batch_size * output_channels * output_height * output_width; int nthreads = batch_size * output_channels * output_height * output_width;
int thread_num = 1024;
#ifdef WITH_NV_JETSON
backends::gpu::ChangeThreadNum(context, &thread_num);
#endif
int blocks = (nthreads + thread_num - 1) / thread_num;
dim3 threads(thread_num, 1);
dim3 grid(blocks, 1);
auto pool_divmods = auto pool_divmods =
FastDivModForPooling(input_channels, output_width, output_height); FastDivModForPooling(input_channels, output_width, output_height);
KernelPool2D<PoolProcess, T> if (adaptive) {
<<<grid, threads, 0, context.stream()>>>(nthreads, int max_threads = 512;
input_data, int thread_num = std::min(
input_channels, phi::funcs::details::GetLastPow2(output_height * output_width),
input_height, max_threads);
input_width, int blocks = std::min(max_threads / thread_num, output_channels);
output_height, dim3 threads(thread_num, blocks, 1);
output_width, dim3 grid(
ksize_height, std::max((output_channels + blocks - 1) / blocks, 1), batch_size, 1);
ksize_width, AdaptiveKernelPool2D<PoolProcess, T>
stride_height, <<<grid, threads, 0, context.stream()>>>(nthreads,
stride_width, input_data,
padding_height, input_channels,
padding_width, input_height,
pool_divmods, input_width,
pool_process, output_height,
exclusive, output_width,
adaptive, ksize_height,
output_data, ksize_width,
channel_last); stride_height,
stride_width,
padding_height,
padding_width,
pool_divmods,
pool_process,
exclusive,
output_data,
channel_last);
} else {
int thread_num = 1024;
#ifdef WITH_NV_JETSON
backends::gpu::ChangeThreadNum(context, &thread_num);
#endif
int blocks = (nthreads + thread_num - 1) / thread_num;
dim3 threads(thread_num, 1);
dim3 grid(blocks, 1);
KernelPool2D<PoolProcess, T>
<<<grid, threads, 0, context.stream()>>>(nthreads,
input_data,
input_channels,
input_height,
input_width,
output_height,
output_width,
ksize_height,
ksize_width,
stride_height,
stride_width,
padding_height,
padding_width,
pool_divmods,
pool_process,
exclusive,
output_data,
channel_last);
}
} }
}; };
/* /*
...@@ -1818,6 +1954,59 @@ __global__ void KernelMaxPool2dWithIdx(const int nthreads, ...@@ -1818,6 +1954,59 @@ __global__ void KernelMaxPool2dWithIdx(const int nthreads,
} }
} }
template <typename T1, typename T2>
__global__ void AdaptiveKernelMaxPool2dWithIdx(const int nthreads,
const T1* input_data,
const int channels,
const int input_height,
const int input_width,
const int output_height,
const int output_width,
const int ksize_height,
const int ksize_width,
const int stride_height,
const int stride_width,
const int padding_height,
const int padding_width,
T1* output_data,
T2* mask_data,
FastDivModForPooling divmods) {
const int n_offset = blockIdx.y;
const int c_offset = blockIdx.x * blockDim.y + threadIdx.y;
if (c_offset >= channels) {
return;
}
int hstart, hend, wstart, wend;
int input_offset =
(n_offset * channels + c_offset) * input_height * input_width;
int output_offset =
(n_offset * channels + c_offset) * output_height * output_width;
for (int hw_offset = threadIdx.x; hw_offset < output_height * output_width;
hw_offset += blockDim.x) {
int w_offset = hw_offset % output_width;
int h_offset = hw_offset / output_width;
hstart = AdaptStartIndex(h_offset, input_height, output_height);
hend = AdaptEndIndex(h_offset, input_height, output_height);
wstart = AdaptStartIndex(w_offset, input_width, output_width);
wend = AdaptEndIndex(w_offset, input_width, output_width);
T1 ele = -FLT_MAX;
int max_index = -1;
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
int input_index = h * input_width + w;
if (ele < input_data[input_offset + input_index]) {
max_index = input_index;
ele = input_data[input_offset + input_index];
}
}
}
int output_idx = output_offset + h_offset * output_width + w_offset;
output_data[output_idx] = ele;
mask_data[output_idx] = max_index;
}
}
template <typename T1, typename T2> template <typename T1, typename T2>
__global__ void KernelMaxPool2DWithIdxGrad(const int nthreads, __global__ void KernelMaxPool2DWithIdxGrad(const int nthreads,
const T1* output_grad, const T1* output_grad,
...@@ -1922,35 +2111,61 @@ class MaxPool2dWithIndexFunctor<phi::GPUContext, T1, T2> { ...@@ -1922,35 +2111,61 @@ class MaxPool2dWithIndexFunctor<phi::GPUContext, T1, T2> {
T2* mask_data = context.template Alloc<T2>(mask); T2* mask_data = context.template Alloc<T2>(mask);
int nthreads = batch_size * output_channels * output_height * output_width; int nthreads = batch_size * output_channels * output_height * output_width;
int thread_num = 1024;
#ifdef WITH_NV_JETSON
backends::gpu::ChangeThreadNum(context, &thread_num);
#endif
int blocks = (nthreads + thread_num - 1) / thread_num;
dim3 threads(thread_num, 1);
dim3 grid(blocks, 1);
auto pool_divmods = auto pool_divmods =
FastDivModForPooling(input_channels, output_width, output_height); FastDivModForPooling(input_channels, output_width, output_height);
KernelMaxPool2dWithIdx<T1, T2> if (adaptive && output_height > 1 && output_width > 1) {
<<<grid, threads, 0, context.stream()>>>(nthreads, int max_threads = 512;
input_data, int thread_num = std::min(
input_channels, phi::funcs::details::GetLastPow2(output_height * output_width),
input_height, max_threads);
input_width, int blocks = std::min(max_threads / thread_num, output_channels);
output_height, dim3 threads(thread_num, blocks, 1);
output_width, dim3 grid(
ksize_height, std::max((output_channels + blocks - 1) / blocks, 1), batch_size, 1);
ksize_width, AdaptiveKernelMaxPool2dWithIdx<T1, T2>
stride_height, <<<grid, threads, 0, context.stream()>>>(nthreads,
stride_width, input_data,
padding_height, input_channels,
padding_width, input_height,
adaptive, input_width,
output_data, output_height,
mask_data, output_width,
pool_divmods); ksize_height,
ksize_width,
stride_height,
stride_width,
padding_height,
padding_width,
output_data,
mask_data,
pool_divmods);
} else {
int thread_num = 1024;
#ifdef WITH_NV_JETSON
backends::gpu::ChangeThreadNum(context, &thread_num);
#endif
int blocks = (nthreads + thread_num - 1) / thread_num;
dim3 threads(thread_num, 1);
dim3 grid(blocks, 1);
KernelMaxPool2dWithIdx<T1, T2>
<<<grid, threads, 0, context.stream()>>>(nthreads,
input_data,
input_channels,
input_height,
input_width,
output_height,
output_width,
ksize_height,
ksize_width,
stride_height,
stride_width,
padding_height,
padding_width,
adaptive,
output_data,
mask_data,
pool_divmods);
}
} }
}; };
......
...@@ -92,12 +92,12 @@ class AvgPoolGrad { ...@@ -92,12 +92,12 @@ class AvgPoolGrad {
*/ */
HOSTDEVICE inline int AdaptStartIndex(int ph, int input_size, int output_size) { HOSTDEVICE inline int AdaptStartIndex(int ph, int input_size, int output_size) {
return static_cast<int>( return static_cast<int>(
floor(static_cast<double>(ph * input_size) / output_size)); floor(static_cast<float>(ph * input_size) / output_size));
} }
HOSTDEVICE inline int AdaptEndIndex(int ph, int input_size, int output_size) { HOSTDEVICE inline int AdaptEndIndex(int ph, int input_size, int output_size) {
return static_cast<int>( return static_cast<int>(
ceil(static_cast<double>((ph + 1) * input_size) / output_size)); ceil(static_cast<float>((ph + 1) * input_size) / output_size));
} }
/* /*
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册