未验证 提交 4af7ebf4 编写于 作者: P piotrekobi 提交者: GitHub

Disable oneDNN adaptive pooling exhaustive check (#43236)

上级 4c2c2148
......@@ -35,27 +35,8 @@ bool CanMKLDNNSupportPool(const framework::ExecutionContext& ctx) {
auto src_tz = phi::vectorize(ctx.Input<Tensor>("X")->dims());
std::vector<int> ksize = ctx.Attr<std::vector<int>>("ksize");
// Fast but not exhustive check
if ((src_tz[src_tz.size() - 1] % ksize[1] == 0) &&
(src_tz[src_tz.size() - 2] % ksize[0] == 0))
return true;
// Exhustive check
auto IH = static_cast<double>(src_tz[src_tz.size() - 2]);
auto IW = static_cast<double>(src_tz[src_tz.size() - 1]);
auto OH = static_cast<double>(ksize[0]);
auto OW = static_cast<double>(ksize[1]);
auto SH = static_cast<int>(floor((IH * 2.0) / OH) - floor(IH / OH));
auto SW = static_cast<int>(floor((IW * 2.0) / OW) - floor(IW / OW));
auto KH = static_cast<int>(ceil((IH * 2.0) / OH) - floor(IH / OH));
auto KW = static_cast<int>(ceil((IW * 2.0) / OW) - floor(IW / OW));
auto PH = (SH * (static_cast<int>(OH) - 1) + KH - static_cast<int>(IH));
auto PW = (SW * (static_cast<int>(OW) - 1) + KW - static_cast<int>(IW));
// If there is additional padding needed then
// this is situation that oneDNN cannot comply with
// paddlepaddle reference implementation
return (PH == 0) && (PW == 0);
return ((src_tz[src_tz.size() - 1] % ksize[1] == 0) &&
(src_tz[src_tz.size() - 2] % ksize[0] == 0));
}
framework::OpKernelType PoolOp::GetExpectedKernelType(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册