提交 1d860f4d 编写于 作者: M Megvii Engine Team 提交者: Xu Xinran

fix(dnn/x86): fix dnnl int8 algo on vnni

GitOrigin-RevId: 2384e095584c1720585d974231b47131ae6b18e6
上级 871e6a51
......@@ -269,16 +269,11 @@ WorkspaceBundle ConvBiasImpl::AlgoMkldnnQint8::get_bundle(
#define REORDER_MEMORY(megdnn_memory, reorder_memory) \
do { \
if (megdnn_memory.get_desc() != conv_prim_desc.src_desc()) { \
reorder_memory = memory(conv_prim_desc.src_desc(), eng_mkldnn); \
auto reorder_pd = reorder::primitive_desc( \
eng_mkldnn, megdnn_memory.get_desc(), eng_mkldnn, \
reorder_memory.get_desc()); \
auto reorder_exe = reorder(reorder_pd); \
reorder_exe.execute(stream_mkldnn, megdnn_memory, reorder_memory); \
} else { \
reorder_memory = megdnn_memory; \
} \
} while (0)
void ConvBiasImpl::AlgoMkldnnQint8::kern_mkldnn_s8x8x32(
......@@ -340,7 +335,10 @@ void ConvBiasImpl::AlgoMkldnnQint8::kern_mkldnn_s8x8x32(
auto conv = convolution_forward(conv_prim_desc);
memory conv_src_memory, conv_weight_memory, conv_dst_memory;
memory conv_src_memory = memory(conv_prim_desc.src_desc(), eng_mkldnn);
memory conv_weight_memory =
memory(conv_prim_desc.weights_desc(), eng_mkldnn);
memory conv_dst_memory;
REORDER_MEMORY(megdnn_src_memory, conv_src_memory);
REORDER_MEMORY(megdnn_weight_memory, conv_weight_memory);
......@@ -354,7 +352,7 @@ void ConvBiasImpl::AlgoMkldnnQint8::kern_mkldnn_s8x8x32(
conv.execute(stream_mkldnn, {{DNNL_ARG_SRC, conv_src_memory},
{DNNL_ARG_WEIGHTS, conv_weight_memory},
{DNNL_ARG_DST, conv_dst_memory}});
REORDER_MEMORY(megdnn_dst_memory, conv_dst_memory);
REORDER_MEMORY(conv_dst_memory, megdnn_dst_memory);
stream_mkldnn.wait();
} else {
std::vector<primitive> net;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册