diff --git a/paddle/phi/kernels/impl/pool_kernel_impl.h b/paddle/phi/kernels/impl/pool_kernel_impl.h index 931a14d9fd872cc1af56b2ce23f500e906f08089..7f08f0bd793921ee6c8cc69698adebbf91aaa839 100644 --- a/paddle/phi/kernels/impl/pool_kernel_impl.h +++ b/paddle/phi/kernels/impl/pool_kernel_impl.h @@ -29,20 +29,22 @@ namespace phi { inline int GetReduceNum(const DenseTensor& input, const DenseTensor* output, - const std::string data_format, + const bool channel_last, std::vector* reduce_dim) { - // data_format only can be NCHW - bool channel_last = (data_format == "NHWC"); - if (channel_last) { - return 0; - } int reduce_num = 0; - const int output_height = output->dims()[2]; - const int output_width = output->dims()[3]; + const int output_height = + channel_last ? output->dims()[1] : output->dims()[2]; + const int output_width = channel_last ? output->dims()[2] : output->dims()[3]; if ((output_height == 1) && (output_width == 1)) { - reduce_dim->push_back(2); - reduce_dim->push_back(3); - reduce_num = input.dims()[2] * input.dims()[3]; + if (channel_last) { + reduce_dim->push_back(1); + reduce_dim->push_back(2); + reduce_num = input.dims()[1] * input.dims()[2]; + } else { + reduce_dim->push_back(2); + reduce_dim->push_back(3); + reduce_num = input.dims()[2] * input.dims()[3]; + } } return reduce_num; } @@ -109,7 +111,7 @@ void PoolRawKernel(const Context& ctx, } else if (pooling_type == "avg") { std::vector reduce_dim; - int reduce_num = GetReduceNum(x, out, data_format, &reduce_dim); + int reduce_num = GetReduceNum(x, out, channel_last, &reduce_dim); if (reduce_num > 0 && adaptive) { // for adaptive_avg_pool2d && output_size == 1 #if defined(__HIPCC__) || defined(__NVCC__)