未验证 提交 b1d3ec16 编写于 作者: Z zhangyikun02 提交者: GitHub

fix bug for pool2d and pool2d_grad when kernel_size > in_h/in_w in xpu (#53043)

上级 28492bf6
......@@ -154,11 +154,11 @@ void Pool2dGradKernel(const Context& ctx,
PADDLE_ENFORCE_XDNN_SUCCESS(r, "adaptive_pool2d_grad");
} else {
if (kernel_size[0] > in_h) {
kernel_size[0] = in_h;
if (kernel_size[0] > (in_h + paddings[0] + paddings[1])) {
kernel_size[0] = in_h + paddings[0] + paddings[1];
}
if (kernel_size[1] > in_w) {
kernel_size[1] = in_w;
if (kernel_size[1] > (in_w + paddings[2] + paddings[3])) {
kernel_size[1] = in_w + paddings[2] + paddings[3];
}
if (pooling_type == "max") {
// TODO(zhanghuan05) to bind max_pool2d_grad_indices xpu api
......
......@@ -90,11 +90,11 @@ void Pool2dKernel(const Context& ctx,
int* index_data = nullptr;
int r = xpu::Error_t::SUCCESS;
if (!adaptive) {
if (kernel_size[0] > in_h) {
kernel_size[0] = in_h;
if (kernel_size[0] > (in_h + paddings[0] + paddings[1])) {
kernel_size[0] = in_h + paddings[0] + paddings[1];
}
if (kernel_size[1] > in_w) {
kernel_size[1] = in_w;
if (kernel_size[1] > (in_w + paddings[2] + paddings[3])) {
kernel_size[1] = in_w + paddings[2] + paddings[3];
}
if (pooling_type == "max") {
r = xpu::max_pool2d<XPUType>(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册