未验证 提交 a8da1625 编写于 作者: A Aurelius84 提交者: GitHub

[OpAttr]Fix complation error of XPU from Pool2dGradKernel (#45727)

* [OpAttr]Fix complation error of XPU from Pool2dGradKernel

* test=kunlun
上级 24a2bedb
...@@ -24,7 +24,7 @@ void Pool2dGradKernel(const Context& ctx, ...@@ -24,7 +24,7 @@ void Pool2dGradKernel(const Context& ctx,
const DenseTensor& x, const DenseTensor& x,
const DenseTensor& out, const DenseTensor& out,
const DenseTensor& dout, const DenseTensor& dout,
const std::vector<int>& kernel_size_t, const IntArray& kernel_size_t,
const std::vector<int>& strides_t, const std::vector<int>& strides_t,
const std::vector<int>& paddings_t, const std::vector<int>& paddings_t,
bool ceil_mode, bool ceil_mode,
...@@ -38,7 +38,8 @@ void Pool2dGradKernel(const Context& ctx, ...@@ -38,7 +38,8 @@ void Pool2dGradKernel(const Context& ctx,
using XPUType = typename XPUTypeTrait<T>::Type; using XPUType = typename XPUTypeTrait<T>::Type;
std::vector<int> paddings(paddings_t); std::vector<int> paddings(paddings_t);
std::vector<int> kernel_size(kernel_size_t); std::vector<int> kernel_size(kernel_size_t.GetData().begin(),
kernel_size_t.GetData().end());
std::vector<int> strides(strides_t); std::vector<int> strides(strides_t);
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
......
...@@ -22,7 +22,7 @@ namespace phi { ...@@ -22,7 +22,7 @@ namespace phi {
template <typename T, typename Context> template <typename T, typename Context>
void Pool2dKernel(const Context& ctx, void Pool2dKernel(const Context& ctx,
const DenseTensor& x, const DenseTensor& x,
const IntArray& kernel_size, const IntArray& kernel_size_t,
const std::vector<int>& strides, const std::vector<int>& strides,
const std::vector<int>& paddings_t, const std::vector<int>& paddings_t,
bool ceil_mode, bool ceil_mode,
...@@ -36,8 +36,8 @@ void Pool2dKernel(const Context& ctx, ...@@ -36,8 +36,8 @@ void Pool2dKernel(const Context& ctx,
using XPUType = typename XPUTypeTrait<T>::Type; using XPUType = typename XPUTypeTrait<T>::Type;
std::vector<int> paddings(paddings_t); std::vector<int> paddings(paddings_t);
std::vector<int> kernel_size_val(kernel_size.GetData().begin(), std::vector<int> kernel_size(kernel_size_t.GetData().begin(),
kernel_size.GetData().end()); kernel_size_t.GetData().end());
PADDLE_ENFORCE_EQ(kernel_size.size(), PADDLE_ENFORCE_EQ(kernel_size.size(),
2, 2,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册