diff --git a/paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc index 91e9d7bbafea5d652d3dd67c682d1b480314a868..436bbdc55362766fe3793c1670c41a03fc8cd6ff 100644 --- a/paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc @@ -142,9 +142,12 @@ class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel { paddle::framework::ToMKLDNNDataType(multi_input[0]->type()); ConcatPrimitiveFactory prim_creator; + // If one of the multiple inputs of concat has an input size of 0, the + // actual size of the multi_input will change std::string key = platform::CreateKey( paddle::framework::vectorize(multi_input[0]->dims()), - ctx.OutputName("Out"), dt, platform::ThreadIDasStr()); + multi_input.size(), ctx.OutputName("Out"), dt, + platform::ThreadIDasStr()); const std::string key_prim = key + "@concat_p"; const std::string key_concat_pd = key + "@concat_pd";