未验证 提交 ec1376ae 编写于 作者: J jakpiase 提交者: GitHub

reverted changes (#46254)

上级 00f32af3
...@@ -77,25 +77,8 @@ class ConcatMKLDNNHandler ...@@ -77,25 +77,8 @@ class ConcatMKLDNNHandler
} }
auto dst_dims = phi::vectorize<int64_t>(output->dims()); auto dst_dims = phi::vectorize<int64_t>(output->dims());
dnnl::memory::desc dst_md =
dnnl::memory::desc dst_md; memory::desc(dst_dims, dt, MKLDNNMemoryFormat::any);
// if concat is being used as a stack op(all source memories dims on
// concat_axis are equal to 1), then it may choose a non-optimal memory
// format tag for destination, because concat primitive is chosing it based
// on source memory descriptors and f.e.200x1x10 can be described as both
// abc and bac and both would be using exact same physical layout, but in
// that scenario bac will be chosen for destination no matter which
// formats are being set in inputs. In that scenario we are enforcing using
// a dense format, because it is the most common one and should be the best
// in terms of the performance
const auto src0_tz = srcs_md[0].dims();
if (std::find(src0_tz.begin(), src0_tz.end(), 1) != src0_tz.end()) {
dst_md = memory::desc(
dst_dims, dt, platform::GetPlainMKLDNNFormat(dst_dims.size()));
} else {
dst_md = memory::desc(dst_dims, dt, MKLDNNMemoryFormat::any);
}
this->AcquireForwardPrimitiveDescriptor(dst_md, concat_axis, srcs_md); this->AcquireForwardPrimitiveDescriptor(dst_md, concat_axis, srcs_md);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册