From aa0098f604f7327f8485f8944ba6c7b1e8be901f Mon Sep 17 00:00:00 2001 From: Zhang Zheng <32410583+ZzSean@users.noreply.github.com> Date: Thu, 22 Dec 2022 11:15:02 +0800 Subject: [PATCH] Optimize performance of avgpool2d with NHWC layout (#49231) --- paddle/phi/kernels/impl/pool_kernel_impl.h | 26 ++++++++++++---------- 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/paddle/phi/kernels/impl/pool_kernel_impl.h b/paddle/phi/kernels/impl/pool_kernel_impl.h index 931a14d9fd..7f08f0bd79 100644 --- a/paddle/phi/kernels/impl/pool_kernel_impl.h +++ b/paddle/phi/kernels/impl/pool_kernel_impl.h @@ -29,20 +29,22 @@ namespace phi { inline int GetReduceNum(const DenseTensor& input, const DenseTensor* output, - const std::string data_format, + const bool channel_last, std::vector* reduce_dim) { - // data_format only can be NCHW - bool channel_last = (data_format == "NHWC"); - if (channel_last) { - return 0; - } int reduce_num = 0; - const int output_height = output->dims()[2]; - const int output_width = output->dims()[3]; + const int output_height = + channel_last ? output->dims()[1] : output->dims()[2]; + const int output_width = channel_last ? output->dims()[2] : output->dims()[3]; if ((output_height == 1) && (output_width == 1)) { - reduce_dim->push_back(2); - reduce_dim->push_back(3); - reduce_num = input.dims()[2] * input.dims()[3]; + if (channel_last) { + reduce_dim->push_back(1); + reduce_dim->push_back(2); + reduce_num = input.dims()[1] * input.dims()[2]; + } else { + reduce_dim->push_back(2); + reduce_dim->push_back(3); + reduce_num = input.dims()[2] * input.dims()[3]; + } } return reduce_num; } @@ -109,7 +111,7 @@ void PoolRawKernel(const Context& ctx, } else if (pooling_type == "avg") { std::vector reduce_dim; - int reduce_num = GetReduceNum(x, out, data_format, &reduce_dim); + int reduce_num = GetReduceNum(x, out, channel_last, &reduce_dim); if (reduce_num > 0 && adaptive) { // for adaptive_avg_pool2d && output_size == 1 #if defined(__HIPCC__) || defined(__NVCC__) -- GitLab