未验证 提交 80b08609 编写于 作者: J Jacek Czaja 提交者: GitHub

- Fix to concat oneDNN overwritting data (#27274)

test=develop
上级 bcf4a49b
......@@ -84,8 +84,10 @@ class ConcatPrimitiveFactory {
concat CreateConcatPrimitive(const concat::primitive_desc& concat_pd,
Tensor* output, platform::CPUPlace place,
const mkldnn::engine& mkldnn_engine) {
dst_mem = mkldnn::memory(concat_pd.dst_desc(), mkldnn_engine,
output->mutable_data<T>(place));
dst_mem = mkldnn::memory(
concat_pd.dst_desc(), mkldnn_engine,
output->mutable_data<T>(place, concat_pd.dst_desc().get_size()));
return concat(concat_pd);
}
......@@ -182,7 +184,9 @@ class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
prim_creator.SetSrcDataHandleByIndex(
*srcs, i, to_void_cast<T>(multi_input[i]->data<T>()));
}
prim_creator.SetDstDataHandle(*dst_mem, output->mutable_data<T>(place));
prim_creator.SetDstDataHandle(
*dst_mem,
output->mutable_data<T>(place, concat_pd->dst_desc().get_size()));
}
mkldnn::stream astream(mkldnn_engine);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册