diff --git a/paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc index 436bbdc55362766fe3793c1670c41a03fc8cd6ff..61a95911437fc441ea784add867c3041ff2a6885 100644 --- a/paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc @@ -84,8 +84,10 @@ class ConcatPrimitiveFactory { concat CreateConcatPrimitive(const concat::primitive_desc& concat_pd, Tensor* output, platform::CPUPlace place, const mkldnn::engine& mkldnn_engine) { - dst_mem = mkldnn::memory(concat_pd.dst_desc(), mkldnn_engine, - output->mutable_data(place)); + dst_mem = mkldnn::memory( + concat_pd.dst_desc(), mkldnn_engine, + output->mutable_data(place, concat_pd.dst_desc().get_size())); + return concat(concat_pd); } @@ -182,7 +184,9 @@ class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel { prim_creator.SetSrcDataHandleByIndex( *srcs, i, to_void_cast(multi_input[i]->data())); } - prim_creator.SetDstDataHandle(*dst_mem, output->mutable_data(place)); + prim_creator.SetDstDataHandle( + *dst_mem, + output->mutable_data(place, concat_pd->dst_desc().get_size())); } mkldnn::stream astream(mkldnn_engine);