未验证 提交 730ccd9e 编写于 作者: J jakpiase 提交者: GitHub

Fix for undefined format for 6 dim tensor (#38553)

* 6 dims fix

* removed limitations of max dims
上级 31efec53
...@@ -269,9 +269,13 @@ inline dnnl::memory::format_tag GetMKLDNNFormat(dnnl::memory::desc mem_desc) { ...@@ -269,9 +269,13 @@ inline dnnl::memory::format_tag GetMKLDNNFormat(dnnl::memory::desc mem_desc) {
if (inner_nblks == 0) { if (inner_nblks == 0) {
if (strides[0] >= strides[1] && strides[1] >= strides[2] && if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
strides[2] >= strides[3] && strides[3] >= strides[4]) { strides[2] >= strides[3] && strides[3] >= strides[4]) {
return dnnl::memory::format_tag::ncdhw; return dnnl::memory::format_tag::abcde;
} else { } else if (strides[0] >= strides[2] && strides[2] >= strides[1] &&
return dnnl::memory::format_tag::ndhwc; strides[1] >= strides[3] && strides[3] >= strides[4]) {
return dnnl::memory::format_tag::acbde;
} else if (strides[0] >= strides[2] && strides[2] >= strides[3] &&
strides[3] >= strides[4] && strides[4] >= strides[1]) {
return dnnl::memory::format_tag::acdeb;
} }
} else if (inner_nblks == 1) { } else if (inner_nblks == 1) {
if (inner_blks[0] == 8 && inner_idxs[0] == 0) { if (inner_blks[0] == 8 && inner_idxs[0] == 0) {
...@@ -310,6 +314,10 @@ inline dnnl::memory::format_tag GetMKLDNNFormat(dnnl::memory::desc mem_desc) { ...@@ -310,6 +314,10 @@ inline dnnl::memory::format_tag GetMKLDNNFormat(dnnl::memory::desc mem_desc) {
strides[2] >= strides[3] && strides[3] >= strides[4] && strides[2] >= strides[3] && strides[3] >= strides[4] &&
strides[4] >= strides[5]) { strides[4] >= strides[5]) {
return dnnl::memory::format_tag::abcdef; return dnnl::memory::format_tag::abcdef;
} else if (strides[0] >= strides[2] && strides[2] >= strides[1] &&
strides[1] >= strides[3] && strides[3] >= strides[4] &&
strides[4] >= strides[5]) {
return dnnl::memory::format_tag::acbdef;
} }
} }
} }
...@@ -397,8 +405,10 @@ inline MKLDNNMemoryFormat MKLDNNFormatForSize(size_t dims_size, ...@@ -397,8 +405,10 @@ inline MKLDNNMemoryFormat MKLDNNFormatForSize(size_t dims_size,
return MKLDNNMemoryFormat::ndhwc; return MKLDNNMemoryFormat::ndhwc;
} }
} else if (dims_size == 6) { } else if (dims_size == 6) {
if (data_format == MKLDNNMemoryFormat::nchw) {
return MKLDNNMemoryFormat::abcdef; return MKLDNNMemoryFormat::abcdef;
} }
}
return data_format; return data_format;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册