提交 52854e88 编写于 作者: X xiaolil1

fix int8 init for src data reorder

上级 04804674
...@@ -686,7 +686,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -686,7 +686,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
handler.reset(new ConvMKLDNNHandler(conv_pd, dev_ctx, mkldnn_engine, key)); handler.reset(new ConvMKLDNNHandler(conv_pd, dev_ctx, mkldnn_engine, key));
// create mkldnn memory from input tensors (data/weights) // create mkldnn memory from input tensors (data/weights)
auto user_src_memory_p = user_src_memory_p =
handler->AcquireSrcMemory(user_src_md, to_void_cast<T>(input_data)); handler->AcquireSrcMemory(user_src_md, to_void_cast<T>(input_data));
auto user_weights_memory_p = handler->AcquireWeightsMemory( auto user_weights_memory_p = handler->AcquireWeightsMemory(
user_weights_md, to_void_cast<float>(filter_data)); user_weights_md, to_void_cast<float>(filter_data));
...@@ -773,6 +773,9 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -773,6 +773,9 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
output->set_layout(DataLayout::kMKLDNN); output->set_layout(DataLayout::kMKLDNN);
output->set_format(GetMKLDNNFormat(*dst_memory_p)); output->set_format(GetMKLDNNFormat(*dst_memory_p));
} else { } else {
if(src_memory_reorder_p){
pipeline.push_back(*src_memory_reorder_p);
}
pipeline.push_back(*conv_p); pipeline.push_back(*conv_p);
stream(stream::kind::eager).submit(pipeline).wait(); stream(stream::kind::eager).submit(pipeline).wait();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册