From 80b086096098d190f30d7c47aa62a0ca0b342e93 Mon Sep 17 00:00:00 2001 From: Jacek Czaja Date: Mon, 14 Sep 2020 05:49:23 +0200 Subject: [PATCH] - Fix to concat oneDNN overwritting data (#27274) test=develop --- paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc index 436bbdc553..61a9591143 100644 --- a/paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc @@ -84,8 +84,10 @@ class ConcatPrimitiveFactory { concat CreateConcatPrimitive(const concat::primitive_desc& concat_pd, Tensor* output, platform::CPUPlace place, const mkldnn::engine& mkldnn_engine) { - dst_mem = mkldnn::memory(concat_pd.dst_desc(), mkldnn_engine, - output->mutable_data(place)); + dst_mem = mkldnn::memory( + concat_pd.dst_desc(), mkldnn_engine, + output->mutable_data(place, concat_pd.dst_desc().get_size())); + return concat(concat_pd); } @@ -182,7 +184,9 @@ class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel { prim_creator.SetSrcDataHandleByIndex( *srcs, i, to_void_cast(multi_input[i]->data())); } - prim_creator.SetDstDataHandle(*dst_mem, output->mutable_data(place)); + prim_creator.SetDstDataHandle( + *dst_mem, + output->mutable_data(place, concat_pd->dst_desc().get_size())); } mkldnn::stream astream(mkldnn_engine); -- GitLab