From 5f3882216915945819b842cc2a12b6f0b4ee6119 Mon Sep 17 00:00:00 2001 From: zhangyikun02 <48021248+zhangyk0314@users.noreply.github.com> Date: Thu, 23 Mar 2023 15:23:21 +0800 Subject: [PATCH] pool2d and pool2d_grad support case of kernel_size > kh/kw for xpu (#51870) --- paddle/phi/kernels/xpu/pool_grad_kernel.cc | 6 ++++++ paddle/phi/kernels/xpu/pool_kernel.cc | 6 ++++++ 2 files changed, 12 insertions(+) diff --git a/paddle/phi/kernels/xpu/pool_grad_kernel.cc b/paddle/phi/kernels/xpu/pool_grad_kernel.cc index dfea5723156..a94a757dc8b 100644 --- a/paddle/phi/kernels/xpu/pool_grad_kernel.cc +++ b/paddle/phi/kernels/xpu/pool_grad_kernel.cc @@ -157,6 +157,12 @@ void Pool2dGradKernel(const Context& ctx, PADDLE_ENFORCE_XDNN_SUCCESS(r, "adaptive_pool2d_grad"); } else { + if (kernel_size[0] > in_h) { + kernel_size[0] = in_h; + } + if (kernel_size[1] > in_w) { + kernel_size[1] = in_w; + } if (pooling_type == "max") { // TODO(zhanghuan05) to bind max_pool2d_grad_indices xpu api r = xpu::max_pool2d_grad( diff --git a/paddle/phi/kernels/xpu/pool_kernel.cc b/paddle/phi/kernels/xpu/pool_kernel.cc index ad09a5ed371..ff324746939 100644 --- a/paddle/phi/kernels/xpu/pool_kernel.cc +++ b/paddle/phi/kernels/xpu/pool_kernel.cc @@ -90,6 +90,12 @@ void Pool2dKernel(const Context& ctx, int* index_data = nullptr; int r = xpu::Error_t::SUCCESS; if (!adaptive) { + if (kernel_size[0] > in_h) { + kernel_size[0] = in_h; + } + if (kernel_size[1] > in_w) { + kernel_size[1] = in_w; + } if (pooling_type == "max") { r = xpu::max_pool2d( ctx.x_context(), -- GitLab