未验证 提交 5f388221 编写于 作者: Z zhangyikun02 提交者: GitHub

pool2d and pool2d_grad support case of kernel_size > kh/kw for xpu (#51870)

上级 4dfbdb04
...@@ -157,6 +157,12 @@ void Pool2dGradKernel(const Context& ctx, ...@@ -157,6 +157,12 @@ void Pool2dGradKernel(const Context& ctx,
PADDLE_ENFORCE_XDNN_SUCCESS(r, "adaptive_pool2d_grad"); PADDLE_ENFORCE_XDNN_SUCCESS(r, "adaptive_pool2d_grad");
} else { } else {
if (kernel_size[0] > in_h) {
kernel_size[0] = in_h;
}
if (kernel_size[1] > in_w) {
kernel_size[1] = in_w;
}
if (pooling_type == "max") { if (pooling_type == "max") {
// TODO(zhanghuan05) to bind max_pool2d_grad_indices xpu api // TODO(zhanghuan05) to bind max_pool2d_grad_indices xpu api
r = xpu::max_pool2d_grad<XPUType>( r = xpu::max_pool2d_grad<XPUType>(
......
...@@ -90,6 +90,12 @@ void Pool2dKernel(const Context& ctx, ...@@ -90,6 +90,12 @@ void Pool2dKernel(const Context& ctx,
int* index_data = nullptr; int* index_data = nullptr;
int r = xpu::Error_t::SUCCESS; int r = xpu::Error_t::SUCCESS;
if (!adaptive) { if (!adaptive) {
if (kernel_size[0] > in_h) {
kernel_size[0] = in_h;
}
if (kernel_size[1] > in_w) {
kernel_size[1] = in_w;
}
if (pooling_type == "max") { if (pooling_type == "max") {
r = xpu::max_pool2d<XPUType>( r = xpu::max_pool2d<XPUType>(
ctx.x_context(), ctx.x_context(),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册