未验证 提交 730ccd9e 编写于 作者: J jakpiase 提交者: GitHub

Fix for undefined format for 6 dim tensor (#38553)

* 6 dims fix

* removed limitations of max dims
上级 31efec53
......@@ -269,9 +269,13 @@ inline dnnl::memory::format_tag GetMKLDNNFormat(dnnl::memory::desc mem_desc) {
if (inner_nblks == 0) {
if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
strides[2] >= strides[3] && strides[3] >= strides[4]) {
return dnnl::memory::format_tag::ncdhw;
} else {
return dnnl::memory::format_tag::ndhwc;
return dnnl::memory::format_tag::abcde;
} else if (strides[0] >= strides[2] && strides[2] >= strides[1] &&
strides[1] >= strides[3] && strides[3] >= strides[4]) {
return dnnl::memory::format_tag::acbde;
} else if (strides[0] >= strides[2] && strides[2] >= strides[3] &&
strides[3] >= strides[4] && strides[4] >= strides[1]) {
return dnnl::memory::format_tag::acdeb;
}
} else if (inner_nblks == 1) {
if (inner_blks[0] == 8 && inner_idxs[0] == 0) {
......@@ -310,6 +314,10 @@ inline dnnl::memory::format_tag GetMKLDNNFormat(dnnl::memory::desc mem_desc) {
strides[2] >= strides[3] && strides[3] >= strides[4] &&
strides[4] >= strides[5]) {
return dnnl::memory::format_tag::abcdef;
} else if (strides[0] >= strides[2] && strides[2] >= strides[1] &&
strides[1] >= strides[3] && strides[3] >= strides[4] &&
strides[4] >= strides[5]) {
return dnnl::memory::format_tag::acbdef;
}
}
}
......@@ -397,7 +405,9 @@ inline MKLDNNMemoryFormat MKLDNNFormatForSize(size_t dims_size,
return MKLDNNMemoryFormat::ndhwc;
}
} else if (dims_size == 6) {
return MKLDNNMemoryFormat::abcdef;
if (data_format == MKLDNNMemoryFormat::nchw) {
return MKLDNNMemoryFormat::abcdef;
}
}
return data_format;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册