diff --git a/paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc index 837d4357737a265db9311c99ac5e79a3064fcf3c..b16576505dfd3f8023f039614ccf35b14364244e 100644 --- a/paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc @@ -89,7 +89,8 @@ class ConcatMKLDNNHandler // formats are being set in inputs. In that scenario we are enforcing using // a dense format, because it is the most common one and should be the best // in terms of the performance - if (dst_dims[concat_axis] == static_cast(srcs_md.size())) { + const auto src0_tz = srcs_md[0].dims(); + if (std::find(src0_tz.begin(), src0_tz.end(), 1) != src0_tz.end()) { dst_md = memory::desc( dst_dims, dt, platform::GetPlainMKLDNNFormat(dst_dims.size())); } else {