From a8da1625b6b5046138ca343a21ce0d8f46f961ff Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Mon, 5 Sep 2022 18:58:34 +0800 Subject: [PATCH] [OpAttr]Fix complation error of XPU from Pool2dGradKernel (#45727) * [OpAttr]Fix complation error of XPU from Pool2dGradKernel * test=kunlun --- paddle/phi/kernels/xpu/pool_grad_kernel.cc | 5 +++-- paddle/phi/kernels/xpu/pool_kernel.cc | 6 +++--- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/paddle/phi/kernels/xpu/pool_grad_kernel.cc b/paddle/phi/kernels/xpu/pool_grad_kernel.cc index 312c7be34f..349fe1a0f1 100644 --- a/paddle/phi/kernels/xpu/pool_grad_kernel.cc +++ b/paddle/phi/kernels/xpu/pool_grad_kernel.cc @@ -24,7 +24,7 @@ void Pool2dGradKernel(const Context& ctx, const DenseTensor& x, const DenseTensor& out, const DenseTensor& dout, - const std::vector& kernel_size_t, + const IntArray& kernel_size_t, const std::vector& strides_t, const std::vector& paddings_t, bool ceil_mode, @@ -38,7 +38,8 @@ void Pool2dGradKernel(const Context& ctx, using XPUType = typename XPUTypeTrait::Type; std::vector paddings(paddings_t); - std::vector kernel_size(kernel_size_t); + std::vector kernel_size(kernel_size_t.GetData().begin(), + kernel_size_t.GetData().end()); std::vector strides(strides_t); PADDLE_ENFORCE_EQ( diff --git a/paddle/phi/kernels/xpu/pool_kernel.cc b/paddle/phi/kernels/xpu/pool_kernel.cc index 2eb850b9a7..9278484378 100644 --- a/paddle/phi/kernels/xpu/pool_kernel.cc +++ b/paddle/phi/kernels/xpu/pool_kernel.cc @@ -22,7 +22,7 @@ namespace phi { template void Pool2dKernel(const Context& ctx, const DenseTensor& x, - const IntArray& kernel_size, + const IntArray& kernel_size_t, const std::vector& strides, const std::vector& paddings_t, bool ceil_mode, @@ -36,8 +36,8 @@ void Pool2dKernel(const Context& ctx, using XPUType = typename XPUTypeTrait::Type; std::vector paddings(paddings_t); - std::vector kernel_size_val(kernel_size.GetData().begin(), - kernel_size.GetData().end()); + std::vector kernel_size(kernel_size_t.GetData().begin(), + kernel_size_t.GetData().end()); PADDLE_ENFORCE_EQ(kernel_size.size(), 2, -- GitLab