未验证 提交 a8da1625 编写于 作者: A Aurelius84 提交者: GitHub

[OpAttr]Fix complation error of XPU from Pool2dGradKernel (#45727)

* [OpAttr]Fix complation error of XPU from Pool2dGradKernel

* test=kunlun
上级 24a2bedb
......@@ -24,7 +24,7 @@ void Pool2dGradKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& out,
const DenseTensor& dout,
const std::vector<int>& kernel_size_t,
const IntArray& kernel_size_t,
const std::vector<int>& strides_t,
const std::vector<int>& paddings_t,
bool ceil_mode,
......@@ -38,7 +38,8 @@ void Pool2dGradKernel(const Context& ctx,
using XPUType = typename XPUTypeTrait<T>::Type;
std::vector<int> paddings(paddings_t);
std::vector<int> kernel_size(kernel_size_t);
std::vector<int> kernel_size(kernel_size_t.GetData().begin(),
kernel_size_t.GetData().end());
std::vector<int> strides(strides_t);
PADDLE_ENFORCE_EQ(
......
......@@ -22,7 +22,7 @@ namespace phi {
template <typename T, typename Context>
void Pool2dKernel(const Context& ctx,
const DenseTensor& x,
const IntArray& kernel_size,
const IntArray& kernel_size_t,
const std::vector<int>& strides,
const std::vector<int>& paddings_t,
bool ceil_mode,
......@@ -36,8 +36,8 @@ void Pool2dKernel(const Context& ctx,
using XPUType = typename XPUTypeTrait<T>::Type;
std::vector<int> paddings(paddings_t);
std::vector<int> kernel_size_val(kernel_size.GetData().begin(),
kernel_size.GetData().end());
std::vector<int> kernel_size(kernel_size_t.GetData().begin(),
kernel_size_t.GetData().end());
PADDLE_ENFORCE_EQ(kernel_size.size(),
2,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册