diff --git a/dnn/src/x86/conv_bias/int8/algos.cpp b/dnn/src/x86/conv_bias/int8/algos.cpp index 1e2a059bb504fff5643c66eaac46a4d05c889715..ba3aeb2c56bf637746a97d7b233cb1c5e9f625f3 100644 --- a/dnn/src/x86/conv_bias/int8/algos.cpp +++ b/dnn/src/x86/conv_bias/int8/algos.cpp @@ -267,18 +267,13 @@ WorkspaceBundle ConvBiasImpl::AlgoMkldnnQint8::get_bundle( } } -#define REORDER_MEMORY(megdnn_memory, reorder_memory) \ - do { \ - if (megdnn_memory.get_desc() != conv_prim_desc.src_desc()) { \ - reorder_memory = memory(conv_prim_desc.src_desc(), eng_mkldnn); \ - auto reorder_pd = reorder::primitive_desc( \ - eng_mkldnn, megdnn_memory.get_desc(), eng_mkldnn, \ - reorder_memory.get_desc()); \ - auto reorder_exe = reorder(reorder_pd); \ - reorder_exe.execute(stream_mkldnn, megdnn_memory, reorder_memory); \ - } else { \ - reorder_memory = megdnn_memory; \ - } \ +#define REORDER_MEMORY(megdnn_memory, reorder_memory) \ + do { \ + auto reorder_pd = reorder::primitive_desc( \ + eng_mkldnn, megdnn_memory.get_desc(), eng_mkldnn, \ + reorder_memory.get_desc()); \ + auto reorder_exe = reorder(reorder_pd); \ + reorder_exe.execute(stream_mkldnn, megdnn_memory, reorder_memory); \ } while (0) void ConvBiasImpl::AlgoMkldnnQint8::kern_mkldnn_s8x8x32( @@ -340,7 +335,10 @@ void ConvBiasImpl::AlgoMkldnnQint8::kern_mkldnn_s8x8x32( auto conv = convolution_forward(conv_prim_desc); - memory conv_src_memory, conv_weight_memory, conv_dst_memory; + memory conv_src_memory = memory(conv_prim_desc.src_desc(), eng_mkldnn); + memory conv_weight_memory = + memory(conv_prim_desc.weights_desc(), eng_mkldnn); + memory conv_dst_memory; REORDER_MEMORY(megdnn_src_memory, conv_src_memory); REORDER_MEMORY(megdnn_weight_memory, conv_weight_memory); @@ -354,7 +352,7 @@ void ConvBiasImpl::AlgoMkldnnQint8::kern_mkldnn_s8x8x32( conv.execute(stream_mkldnn, {{DNNL_ARG_SRC, conv_src_memory}, {DNNL_ARG_WEIGHTS, conv_weight_memory}, {DNNL_ARG_DST, conv_dst_memory}}); - REORDER_MEMORY(megdnn_dst_memory, conv_dst_memory); + REORDER_MEMORY(conv_dst_memory, megdnn_dst_memory); stream_mkldnn.wait(); } else { std::vector net;