未验证 提交 29b55009 编写于 作者: J Jacek Czaja 提交者: GitHub

Persuading more efficient memory format to be preferred (#44078)

* - blind shot fix

* - workaround

* - compilation fix

* - Hack
上级 d90c39e9
...@@ -233,15 +233,21 @@ inline dnnl::memory::format_tag GetMKLDNNFormat(dnnl::memory::desc mem_desc) { ...@@ -233,15 +233,21 @@ inline dnnl::memory::format_tag GetMKLDNNFormat(dnnl::memory::desc mem_desc) {
if (inner_nblks == 0) { if (inner_nblks == 0) {
if (strides[0] >= strides[1] && strides[1] >= strides[2] && if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
strides[2] >= strides[3]) { strides[2] >= strides[3]) {
return dnnl::memory::format_tag::nchw; return dnnl::memory::format_tag::abcd;
} else if (strides[2] >= strides[3] && strides[3] >= strides[1] &&
strides[1] >= strides[0]) {
return dnnl::memory::format_tag::cdba;
} else if (strides[0] >= strides[2] && strides[2] >= strides[3] &&
strides[3] >= strides[1]) {
return dnnl::memory::format_tag::acdb;
} else if (strides[0] >= strides[1] && strides[1] >= strides[3] &&
strides[3] >= strides[2]) {
return dnnl::memory::format_tag::abdc;
} else if (strides[2] >= strides[3] && strides[3] >= strides[1] && } else if (strides[2] >= strides[3] && strides[3] >= strides[1] &&
strides[1] >= strides[0]) { strides[1] >= strides[0]) {
return dnnl::memory::format_tag::cdba; return dnnl::memory::format_tag::cdba;
} else if (strides[3] >= strides[2] && strides[2] >= strides[0] &&
strides[0] >= strides[1]) {
return dnnl::memory::format_tag::dcab;
} else { } else {
return dnnl::memory::format_tag::nhwc; return dnnl::memory::format_tag::dcab;
} }
} else if (inner_nblks == 1) { } else if (inner_nblks == 1) {
if (inner_blks[0] == 16 && inner_idxs[0] == 1) { if (inner_blks[0] == 16 && inner_idxs[0] == 1) {
......
...@@ -51,15 +51,21 @@ inline dnnl::memory::format_tag GetMKLDNNFormat(dnnl::memory::desc mem_desc) { ...@@ -51,15 +51,21 @@ inline dnnl::memory::format_tag GetMKLDNNFormat(dnnl::memory::desc mem_desc) {
if (inner_nblks == 0) { if (inner_nblks == 0) {
if (strides[0] >= strides[1] && strides[1] >= strides[2] && if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
strides[2] >= strides[3]) { strides[2] >= strides[3]) {
return dnnl::memory::format_tag::nchw; return dnnl::memory::format_tag::abcd;
} else if (strides[2] >= strides[3] && strides[3] >= strides[1] &&
strides[1] >= strides[0]) {
return dnnl::memory::format_tag::cdba;
} else if (strides[0] >= strides[2] && strides[2] >= strides[3] &&
strides[3] >= strides[1]) {
return dnnl::memory::format_tag::acdb;
} else if (strides[0] >= strides[1] && strides[1] >= strides[3] &&
strides[3] >= strides[2]) {
return dnnl::memory::format_tag::abdc;
} else if (strides[2] >= strides[3] && strides[3] >= strides[1] && } else if (strides[2] >= strides[3] && strides[3] >= strides[1] &&
strides[1] >= strides[0]) { strides[1] >= strides[0]) {
return dnnl::memory::format_tag::cdba; return dnnl::memory::format_tag::cdba;
} else if (strides[3] >= strides[2] && strides[2] >= strides[0] &&
strides[0] >= strides[1]) {
return dnnl::memory::format_tag::dcab;
} else { } else {
return dnnl::memory::format_tag::nhwc; return dnnl::memory::format_tag::dcab;
} }
} else if (inner_nblks == 1) { } else if (inner_nblks == 1) {
if (inner_blks[0] == 16 && inner_idxs[0] == 1) { if (inner_blks[0] == 16 && inner_idxs[0] == 1) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册