未验证 提交 4582f697 编写于 作者: J Jacek Czaja 提交者: GitHub

- Fix to concat oneDNN overwritting data (#27273)

test=develop
上级 c23f09fe
...@@ -86,8 +86,10 @@ class ConcatPrimitiveFactory { ...@@ -86,8 +86,10 @@ class ConcatPrimitiveFactory {
concat CreateConcatPrimitive(const concat::primitive_desc& concat_pd, concat CreateConcatPrimitive(const concat::primitive_desc& concat_pd,
Tensor* output, platform::CPUPlace place, Tensor* output, platform::CPUPlace place,
const mkldnn::engine& mkldnn_engine) { const mkldnn::engine& mkldnn_engine) {
dst_mem = mkldnn::memory(concat_pd.dst_desc(), mkldnn_engine, dst_mem = mkldnn::memory(
output->mutable_data<T>(place)); concat_pd.dst_desc(), mkldnn_engine,
output->mutable_data<T>(place, concat_pd.dst_desc().get_size()));
return concat(concat_pd); return concat(concat_pd);
} }
...@@ -193,7 +195,9 @@ class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -193,7 +195,9 @@ class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
prim_creator.SetSrcDataHandleByIndex( prim_creator.SetSrcDataHandleByIndex(
*srcs, i, to_void_cast<T>(multi_input[i]->data<T>())); *srcs, i, to_void_cast<T>(multi_input[i]->data<T>()));
} }
prim_creator.SetDstDataHandle(*dst_mem, output->mutable_data<T>(place)); prim_creator.SetDstDataHandle(
*dst_mem,
output->mutable_data<T>(place, concat_pd->dst_desc().get_size()));
} }
mkldnn::stream astream(mkldnn_engine); mkldnn::stream astream(mkldnn_engine);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册