diff --git a/paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc
index b16576505dfd3f8023f039614ccf35b14364244e..7318e0410cec2be043240522759ddf0fc1dedb25 100644
--- a/paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc
+++ b/paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc
@@ -77,25 +77,8 @@ class ConcatMKLDNNHandler
     }
 
     auto dst_dims = phi::vectorize(output->dims());
-
-    dnnl::memory::desc dst_md;
-
-    // if concat is being used as a stack op(all source memories dims on
-    // concat_axis are equal to 1), then it may choose a non-optimal memory
-    // format tag for destination, because concat primitive is chosing it based
-    // on source memory descriptors and f.e.200x1x10 can be described as both
-    // abc and bac and both would be using exact same physical layout, but in
-    // that scenario bac will be chosen for destination no matter which
-    // formats are being set in inputs. In that scenario we are enforcing using
-    // a dense format, because it is the most common one and should be the best
-    // in terms of the performance
-    const auto src0_tz = srcs_md[0].dims();
-    if (std::find(src0_tz.begin(), src0_tz.end(), 1) != src0_tz.end()) {
-      dst_md = memory::desc(
-          dst_dims, dt, platform::GetPlainMKLDNNFormat(dst_dims.size()));
-    } else {
-      dst_md = memory::desc(dst_dims, dt, MKLDNNMemoryFormat::any);
-    }
+    dnnl::memory::desc dst_md =
+        memory::desc(dst_dims, dt, MKLDNNMemoryFormat::any);
 
     this->AcquireForwardPrimitiveDescriptor(dst_md, concat_axis, srcs_md);
   }
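
Note (not part of the patch): after this change the destination layout is always left to the concat primitive, since MKLDNNMemoryFormat::any in Paddle maps onto dnnl::memory::format_tag::any. The following is a minimal standalone sketch, assuming the oneDNN 2.x C++ API rather than the Paddle wrappers, of what the simplified handler now relies on: the primitive descriptor picks the destination layout from the source descriptors, and the caller can query the chosen descriptor afterwards. The 200x1x10 "stack" shapes are taken from the removed comment; everything else (variable names, the printed check) is illustrative only.

    #include <iostream>
    #include <vector>
    #include "oneapi/dnnl/dnnl.hpp"

    int main() {
      dnnl::engine eng(dnnl::engine::kind::cpu, 0);

      // Two 200x1x10 f32 sources concatenated along axis 1 -- the
      // stack-like case discussed in the removed comment.
      dnnl::memory::dims src_dims = {200, 1, 10};
      auto src_md = dnnl::memory::desc(src_dims, dnnl::memory::data_type::f32,
                                       dnnl::memory::format_tag::abc);
      std::vector<dnnl::memory::desc> srcs_md = {src_md, src_md};

      // Destination described with format_tag::any: no layout is forced
      // here, the concat primitive descriptor selects one itself.
      auto dst_md = dnnl::memory::desc({200, 2, 10},
                                       dnnl::memory::data_type::f32,
                                       dnnl::memory::format_tag::any);

      auto concat_pd = dnnl::concat::primitive_desc(
          dst_md, /*concat_dimension=*/1, srcs_md, eng);

      // Query the destination descriptor oneDNN actually chose and compare
      // it against a plain (dense, abc) layout of the same shape.
      auto chosen = concat_pd.dst_desc();
      auto plain_md = dnnl::memory::desc({200, 2, 10},
                                         dnnl::memory::data_type::f32,
                                         dnnl::memory::format_tag::abc);
      std::cout << "picked plain abc layout: " << (chosen == plain_md)
                << std::endl;
      return 0;
    }

This mirrors the call the handler makes via AcquireForwardPrimitiveDescriptor(dst_md, concat_axis, srcs_md); the sketch simply constructs the equivalent dnnl::concat::primitive_desc directly.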