diff --git a/paddle/fluid/platform/mkldnn_helper.h b/paddle/fluid/platform/mkldnn_helper.h index 83fd353f54dd6c2ad32a11bf36d13f7fe040f366..0e97a68edfc9df45f247373d3eecbae029ac4fe2 100644 --- a/paddle/fluid/platform/mkldnn_helper.h +++ b/paddle/fluid/platform/mkldnn_helper.h @@ -233,15 +233,21 @@ inline dnnl::memory::format_tag GetMKLDNNFormat(dnnl::memory::desc mem_desc) { if (inner_nblks == 0) { if (strides[0] >= strides[1] && strides[1] >= strides[2] && strides[2] >= strides[3]) { - return dnnl::memory::format_tag::nchw; + return dnnl::memory::format_tag::abcd; + } else if (strides[2] >= strides[3] && strides[3] >= strides[1] && + strides[1] >= strides[0]) { + return dnnl::memory::format_tag::cdba; + } else if (strides[0] >= strides[2] && strides[2] >= strides[3] && + strides[3] >= strides[1]) { + return dnnl::memory::format_tag::acdb; + } else if (strides[0] >= strides[1] && strides[1] >= strides[3] && + strides[3] >= strides[2]) { + return dnnl::memory::format_tag::abdc; } else if (strides[2] >= strides[3] && strides[3] >= strides[1] && strides[1] >= strides[0]) { return dnnl::memory::format_tag::cdba; - } else if (strides[3] >= strides[2] && strides[2] >= strides[0] && - strides[0] >= strides[1]) { - return dnnl::memory::format_tag::dcab; } else { - return dnnl::memory::format_tag::nhwc; + return dnnl::memory::format_tag::dcab; } } else if (inner_nblks == 1) { if (inner_blks[0] == 16 && inner_idxs[0] == 1) { diff --git a/paddle/fluid/platform/mkldnn_utils.h b/paddle/fluid/platform/mkldnn_utils.h index 12c48ed412428776a0377f964239c4eb4be9562d..38470d18f46235526931ea6b4032caf66d881e31 100644 --- a/paddle/fluid/platform/mkldnn_utils.h +++ b/paddle/fluid/platform/mkldnn_utils.h @@ -51,15 +51,21 @@ inline dnnl::memory::format_tag GetMKLDNNFormat(dnnl::memory::desc mem_desc) { if (inner_nblks == 0) { if (strides[0] >= strides[1] && strides[1] >= strides[2] && strides[2] >= strides[3]) { - return dnnl::memory::format_tag::nchw; + return dnnl::memory::format_tag::abcd; + } else if (strides[2] >= strides[3] && strides[3] >= strides[1] && + strides[1] >= strides[0]) { + return dnnl::memory::format_tag::cdba; + } else if (strides[0] >= strides[2] && strides[2] >= strides[3] && + strides[3] >= strides[1]) { + return dnnl::memory::format_tag::acdb; + } else if (strides[0] >= strides[1] && strides[1] >= strides[3] && + strides[3] >= strides[2]) { + return dnnl::memory::format_tag::abdc; } else if (strides[2] >= strides[3] && strides[3] >= strides[1] && strides[1] >= strides[0]) { return dnnl::memory::format_tag::cdba; - } else if (strides[3] >= strides[2] && strides[2] >= strides[0] && - strides[0] >= strides[1]) { - return dnnl::memory::format_tag::dcab; } else { - return dnnl::memory::format_tag::nhwc; + return dnnl::memory::format_tag::dcab; } } else if (inner_nblks == 1) { if (inner_blks[0] == 16 && inner_idxs[0] == 1) {