diff --git a/paddle/fluid/operators/pool_op.cc b/paddle/fluid/operators/pool_op.cc
index 30ead84d1a9871c12d773b419c24f710719f69a7..9aa68881e44a047327be14a27e53ce5621e62c2b 100644
--- a/paddle/fluid/operators/pool_op.cc
+++ b/paddle/fluid/operators/pool_op.cc
@@ -35,27 +35,8 @@ bool CanMKLDNNSupportPool(const framework::ExecutionContext& ctx) {
   auto src_tz = phi::vectorize(ctx.Input<Tensor>("X")->dims());
   std::vector<int> ksize = ctx.Attr<std::vector<int>>("ksize");
   // Fast but not exhustive check
-  if ((src_tz[src_tz.size() - 1] % ksize[1] == 0) &&
-      (src_tz[src_tz.size() - 2] % ksize[0] == 0))
-    return true;
-
-  // Exhustive check
-  auto IH = static_cast<double>(src_tz[src_tz.size() - 2]);
-  auto IW = static_cast<double>(src_tz[src_tz.size() - 1]);
-  auto OH = static_cast<double>(ksize[0]);
-  auto OW = static_cast<double>(ksize[1]);
-
-  auto SH = static_cast<int>(floor((IH * 2.0) / OH) - floor(IH / OH));
-  auto SW = static_cast<int>(floor((IW * 2.0) / OW) - floor(IW / OW));
-  auto KH = static_cast<int>(ceil((IH * 2.0) / OH) - floor(IH / OH));
-  auto KW = static_cast<int>(ceil((IW * 2.0) / OW) - floor(IW / OW));
-
-  auto PH = (SH * (static_cast<int>(OH) - 1) + KH - static_cast<int>(IH));
-  auto PW = (SW * (static_cast<int>(OW) - 1) + KW - static_cast<int>(IW));
-  // If there is additional padding needed then
-  // this is situation that oneDNN cannot comply with
-  // paddlepaddle reference implementation
-  return (PH == 0) && (PW == 0);
+  return ((src_tz[src_tz.size() - 1] % ksize[1] == 0) &&
+          (src_tz[src_tz.size() - 2] % ksize[0] == 0));
 }
 
 framework::OpKernelType PoolOp::GetExpectedKernelType(