diff --git a/paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc index 3cafb0e9fc6147626f066bbeba1b10d074a37b87..b2815cbdc65b53beba9cdb1864d10875d5db5e62 100644 --- a/paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc @@ -86,8 +86,10 @@ class ConcatPrimitiveFactory { concat CreateConcatPrimitive(const concat::primitive_desc& concat_pd, Tensor* output, platform::CPUPlace place, const mkldnn::engine& mkldnn_engine) { - dst_mem = mkldnn::memory(concat_pd.dst_desc(), mkldnn_engine, - output->mutable_data(place)); + dst_mem = mkldnn::memory( + concat_pd.dst_desc(), mkldnn_engine, + output->mutable_data(place, concat_pd.dst_desc().get_size())); + return concat(concat_pd); } @@ -193,7 +195,9 @@ class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel { prim_creator.SetSrcDataHandleByIndex( *srcs, i, to_void_cast(multi_input[i]->data())); } - prim_creator.SetDstDataHandle(*dst_mem, output->mutable_data(place)); + prim_creator.SetDstDataHandle( + *dst_mem, + output->mutable_data(place, concat_pd->dst_desc().get_size())); } mkldnn::stream astream(mkldnn_engine);