From 29b55009f78eb812fd13223975fd1a1aa7fafad0 Mon Sep 17 00:00:00 2001 From: Jacek Czaja Date: Tue, 5 Jul 2022 15:19:42 +0200 Subject: [PATCH] Persuading more efficient memory format to be preferred (#44078) * - blind shot fix * - workaround * - compilation fix * - Hack --- paddle/fluid/platform/mkldnn_helper.h | 16 +++++++++++----- paddle/fluid/platform/mkldnn_utils.h | 16 +++++++++++----- 2 files changed, 22 insertions(+), 10 deletions(-) diff --git a/paddle/fluid/platform/mkldnn_helper.h b/paddle/fluid/platform/mkldnn_helper.h index 83fd353f54..0e97a68edf 100644 --- a/paddle/fluid/platform/mkldnn_helper.h +++ b/paddle/fluid/platform/mkldnn_helper.h @@ -233,15 +233,21 @@ inline dnnl::memory::format_tag GetMKLDNNFormat(dnnl::memory::desc mem_desc) { if (inner_nblks == 0) { if (strides[0] >= strides[1] && strides[1] >= strides[2] && strides[2] >= strides[3]) { - return dnnl::memory::format_tag::nchw; + return dnnl::memory::format_tag::abcd; + } else if (strides[2] >= strides[3] && strides[3] >= strides[1] && + strides[1] >= strides[0]) { + return dnnl::memory::format_tag::cdba; + } else if (strides[0] >= strides[2] && strides[2] >= strides[3] && + strides[3] >= strides[1]) { + return dnnl::memory::format_tag::acdb; + } else if (strides[0] >= strides[1] && strides[1] >= strides[3] && + strides[3] >= strides[2]) { + return dnnl::memory::format_tag::abdc; } else if (strides[2] >= strides[3] && strides[3] >= strides[1] && strides[1] >= strides[0]) { return dnnl::memory::format_tag::cdba; - } else if (strides[3] >= strides[2] && strides[2] >= strides[0] && - strides[0] >= strides[1]) { - return dnnl::memory::format_tag::dcab; } else { - return dnnl::memory::format_tag::nhwc; + return dnnl::memory::format_tag::dcab; } } else if (inner_nblks == 1) { if (inner_blks[0] == 16 && inner_idxs[0] == 1) { diff --git a/paddle/fluid/platform/mkldnn_utils.h b/paddle/fluid/platform/mkldnn_utils.h index 12c48ed412..38470d18f4 100644 --- a/paddle/fluid/platform/mkldnn_utils.h +++ b/paddle/fluid/platform/mkldnn_utils.h @@ -51,15 +51,21 @@ inline dnnl::memory::format_tag GetMKLDNNFormat(dnnl::memory::desc mem_desc) { if (inner_nblks == 0) { if (strides[0] >= strides[1] && strides[1] >= strides[2] && strides[2] >= strides[3]) { - return dnnl::memory::format_tag::nchw; + return dnnl::memory::format_tag::abcd; + } else if (strides[2] >= strides[3] && strides[3] >= strides[1] && + strides[1] >= strides[0]) { + return dnnl::memory::format_tag::cdba; + } else if (strides[0] >= strides[2] && strides[2] >= strides[3] && + strides[3] >= strides[1]) { + return dnnl::memory::format_tag::acdb; + } else if (strides[0] >= strides[1] && strides[1] >= strides[3] && + strides[3] >= strides[2]) { + return dnnl::memory::format_tag::abdc; } else if (strides[2] >= strides[3] && strides[3] >= strides[1] && strides[1] >= strides[0]) { return dnnl::memory::format_tag::cdba; - } else if (strides[3] >= strides[2] && strides[2] >= strides[0] && - strides[0] >= strides[1]) { - return dnnl::memory::format_tag::dcab; } else { - return dnnl::memory::format_tag::nhwc; + return dnnl::memory::format_tag::dcab; } } else if (inner_nblks == 1) { if (inner_blks[0] == 16 && inner_idxs[0] == 1) { -- GitLab