From 730ccd9e8e1c8a03c82af9000fed133f3b4f4a16 Mon Sep 17 00:00:00 2001 From: jakpiase Date: Fri, 31 Dec 2021 10:14:17 +0100 Subject: [PATCH] Fix for undefined format for 6 dim tensor (#38553) * 6 dims fix * removed limitations of max dims --- paddle/fluid/platform/mkldnn_helper.h | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/platform/mkldnn_helper.h b/paddle/fluid/platform/mkldnn_helper.h index b98ca33285a..7a528cf8d6b 100644 --- a/paddle/fluid/platform/mkldnn_helper.h +++ b/paddle/fluid/platform/mkldnn_helper.h @@ -269,9 +269,13 @@ inline dnnl::memory::format_tag GetMKLDNNFormat(dnnl::memory::desc mem_desc) { if (inner_nblks == 0) { if (strides[0] >= strides[1] && strides[1] >= strides[2] && strides[2] >= strides[3] && strides[3] >= strides[4]) { - return dnnl::memory::format_tag::ncdhw; - } else { - return dnnl::memory::format_tag::ndhwc; + return dnnl::memory::format_tag::abcde; + } else if (strides[0] >= strides[2] && strides[2] >= strides[1] && + strides[1] >= strides[3] && strides[3] >= strides[4]) { + return dnnl::memory::format_tag::acbde; + } else if (strides[0] >= strides[2] && strides[2] >= strides[3] && + strides[3] >= strides[4] && strides[4] >= strides[1]) { + return dnnl::memory::format_tag::acdeb; } } else if (inner_nblks == 1) { if (inner_blks[0] == 8 && inner_idxs[0] == 0) { @@ -310,6 +314,10 @@ inline dnnl::memory::format_tag GetMKLDNNFormat(dnnl::memory::desc mem_desc) { strides[2] >= strides[3] && strides[3] >= strides[4] && strides[4] >= strides[5]) { return dnnl::memory::format_tag::abcdef; + } else if (strides[0] >= strides[2] && strides[2] >= strides[1] && + strides[1] >= strides[3] && strides[3] >= strides[4] && + strides[4] >= strides[5]) { + return dnnl::memory::format_tag::acbdef; } } } @@ -397,7 +405,9 @@ inline MKLDNNMemoryFormat MKLDNNFormatForSize(size_t dims_size, return MKLDNNMemoryFormat::ndhwc; } } else if (dims_size == 6) { - return MKLDNNMemoryFormat::abcdef; + if (data_format == MKLDNNMemoryFormat::nchw) { + return MKLDNNMemoryFormat::abcdef; + } } return data_format; } -- GitLab