From f56f187f6e2b3b8e15bcddf81a3f57fccb5688ae Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Thu, 4 Jun 2020 20:23:04 +0800
Subject: [PATCH] fix(mgb/gopt): fix nchw44-dot channel wise trans to nchw44

GitOrigin-RevId: aa2059a79601822131a29f8e23f6808847ca62c3
---
 .../arm_common/conv_bias_multi_thread.cpp | 25 ++++---
 src/gopt/impl/tensor_reformat.cpp         | 69 ++++++++++---------
 src/gopt/test/inference.cpp               |  2 +-
 3 files changed, 56 insertions(+), 40 deletions(-)

diff --git a/dnn/test/arm_common/conv_bias_multi_thread.cpp b/dnn/test/arm_common/conv_bias_multi_thread.cpp
index 2ca1c8309..6f4fae34b 100644
--- a/dnn/test/arm_common/conv_bias_multi_thread.cpp
+++ b/dnn/test/arm_common/conv_bias_multi_thread.cpp
@@ -168,10 +168,10 @@ std::vector<conv_bias::TestArg> get_nchw44_conv_bias_args(
     for (auto nlmode : nonlinemode)
         for (size_t n : {1, 2})
             for (size_t kernel : kernel_vec)
-                for (size_t oc : {4, 12, 32})
+                for (size_t oc : {4, 32})
                     for (size_t ic : {1, 3, 4, 12, 32})
                         for (size_t h : {3, 5, 12})
-                            for (size_t w : {7, 16, 23}) {
+                            for (size_t w : {7, 23}) {
                                 for (size_t group = 1;
                                      group <= std::min(oc, ic); ++group) {
                                     pack(n, oc, ic, h, w, kernel, stride,
@@ -350,13 +350,18 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_SMALL_GROUP) {
     check_conv_bias(
             get_conv_bias_args({1, 2, 3, 4, 5, 6, 7}, 1, false, false, false),
             handle(), "F32DIRECT_SMALL_GROUP");
 }
-TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S1_1) {
-    check_conv_bias(get_nchw44_conv_bias_args({2, 7}, 1, false, false, false,
+TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S1_K7) {
+    check_conv_bias(get_nchw44_conv_bias_args({7}, 1, false, false, false,
+                                              false, false, false),
+                    handle(), "F32_CONV_NCHW44_DIRECT");
+}
+TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S1_K2K3) {
+    check_conv_bias(get_nchw44_conv_bias_args({2, 3}, 1, false, false, false,
                                               false, true, true),
                     handle(), "F32_CONV_NCHW44_DIRECT");
 }
-TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S1_2) {
-    check_conv_bias(get_nchw44_conv_bias_args({3, 5}, 1, false, false, false,
+TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S1_K5) {
+    check_conv_bias(get_nchw44_conv_bias_args({5}, 1, false, false, false,
                                               false, true, true),
                     handle(), "F32_CONV_NCHW44_DIRECT");
 }
@@ -388,11 +393,15 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_NCHW_NCHW44_F32) {
                                             false, true),
                     handle(), "F32_CONV_NCHW_NCHW44");
 }
-TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_CHANNEL_WISE_STRIDE1_FP32_NCHW44) {
+TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_CHANNEL_WISE_STRIDE1_FP32_NCHW44_1) {
     check_conv_bias(
-            get_nchw44_channel_wise_args({2, 3, 5}, 1, false, false, false),
+            get_nchw44_channel_wise_args({2, 3}, 1, false, false, false),
             handle(), "F32_CHANNEL_WISE_NCHW44");
 }
+TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_CHANNEL_WISE_STRIDE1_FP32_NCHW44_2) {
+    check_conv_bias(get_nchw44_channel_wise_args({5}, 1, false, false, false),
+                    handle(), "F32_CHANNEL_WISE_NCHW44");
+}
 
 TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_CHANNEL_WISE_STRIDE2_FP32_NCHW44) {
     check_conv_bias(
diff --git a/src/gopt/impl/tensor_reformat.cpp b/src/gopt/impl/tensor_reformat.cpp
index 0b27dbcc2..1b82f409b 100644
--- a/src/gopt/impl/tensor_reformat.cpp
+++ b/src/gopt/impl/tensor_reformat.cpp
@@ -2050,23 +2050,27 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
     //! First is whether the conv can trans to nchwxx, second is the filter
     //! trans mode
     using RelayoutMode = RelayoutPlaceholder::LayoutType;
-    using TestTransResult = std::pair<TransType, RelayoutMode>;
-    megdnn::param::ConvolutionV0::Format conv_dot_format =
-            megdnn::param::ConvBias::Format::NCHW44_DOT;
+    struct TestTransResult {
+        TransType trans_type;
+        RelayoutMode relayout_mod;
+        megdnn::param::ConvolutionV0::Format conv_format;
+    };
     constexpr size_t pack_c_size = 4_z;
     auto test_trans_nchw44_dot =
             [](const megdnn::param::Convolution::Sparse conv_mode,
               const VarNode* filter) -> TestTransResult {
-        TestTransResult ret{TransType::TRANS_NONE, {}};
+        TestTransResult ret{TransType::TRANS_NONE, {}, {}};
         if (conv_mode == megdnn::param::Convolution::Sparse::DENSE) {
             size_t IC = filter->shape()[1];
             size_t OC = filter->shape()[0];
             if ((IC % pack_c_size == 0) && (OC % pack_c_size == 0)) {
-                ret.first = TransType::TRANS_PURE_NCHWXX;
-                ret.second = RelayoutMode::WEIGHT_NCHW_TO_NCHW44_DOT_DENSE;
+                ret.trans_type = TransType::TRANS_PURE_NCHWXX;
+                ret.relayout_mod = RelayoutMode::WEIGHT_NCHW_TO_NCHW44_DOT_DENSE;
+                ret.conv_format = megdnn::param::ConvBias::Format::NCHW44_DOT;
             } else if (IC < pack_c_size && OC % pack_c_size == 0) {
-                ret.first = TransType::TRANS_HYBIRD_NCHWXX;
-                ret.second = RelayoutMode::WEIGHT_HYBIRD_NCHW_NCHW44;
+                ret.trans_type = TransType::TRANS_HYBIRD_NCHWXX;
+                ret.relayout_mod = RelayoutMode::WEIGHT_HYBIRD_NCHW_NCHW44;
+                ret.conv_format = megdnn::param::ConvBias::Format::NCHW44_DOT;
             }
         } else {
             mgb_assert(conv_mode == megdnn::param::Convolution::Sparse::GROUP);
@@ -2074,15 +2078,18 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
             size_t ocpg = filter->shape()[1];
             size_t icpg = filter->shape()[2];
             if (icpg == 1 && ocpg == 1 && (group % pack_c_size == 0)) {
-                ret.first = TransType::TRANS_NONE;
+                ret.trans_type = TransType::TRANS_PURE_NCHWXX;
+                ret.relayout_mod = RelayoutMode::WEIGHT_NCHW_TO_NCHW44_CHAN;
+                ret.conv_format = megdnn::param::ConvBias::Format::NCHW44;
             } else if ((icpg % pack_c_size == 0) && (ocpg % pack_c_size == 0)) {
-                ret.first = TransType::TRANS_PURE_NCHWXX;
-                ret.second = RelayoutMode::WEIGHT_NCHW_TO_NCHW44_DOT_GROUP;
+                ret.trans_type = TransType::TRANS_PURE_NCHWXX;
+                ret.relayout_mod = RelayoutMode::WEIGHT_NCHW_TO_NCHW44_DOT_GROUP;
+                ret.conv_format = megdnn::param::ConvBias::Format::NCHW44_DOT;
             }
         }
         return ret;
     };
-    auto replace_conv_opr = [test_trans_nchw44_dot, conv_dot_format](
+    auto replace_conv_opr = [test_trans_nchw44_dot](
                                     OperatorNodeBase* opr,
                                     const VarNodeArray& new_inp) {
         mgb_assert(opr->input().size() == new_inp.size());
@@ -2094,7 +2101,7 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
         auto is_trans =
                 test_trans_nchw44_dot(conv_opr.param().sparse, new_inp[1]);
         //! can not trans to nchwxx
-        if (is_trans.first == TransType::TRANS_NONE) {
+        if (is_trans.trans_type == TransType::TRANS_NONE) {
             mgb_assert(new_inp[1]->shape().ndim == 4 ||
                                new_inp[1]->shape().ndim == 5,
                        "The origin filter is not NCHW mode");
@@ -2108,14 +2115,14 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
             auto new_opr = serialization::copy_opr_shallow(*opr, temp_inp,
                                                            opr->config());
             return new_opr;
-        } else if (is_trans.first == TransType::TRANS_PURE_NCHWXX) {
+        } else if (is_trans.trans_type == TransType::TRANS_PURE_NCHWXX) {
             //! filter trans to nchwxx mode
             mgb_assert(new_inp[1]->shape().ndim == 4 ||
                                new_inp[1]->shape().ndim == 5,
                        "The origin filter is not NCHW mode");
             VarNode *conv_src = new_inp[0], *conv_filter = new_inp[1];
-            auto new_filter =
-                    RelayoutPlaceholder::make(new_inp[1], is_trans.second);
+            auto new_filter = RelayoutPlaceholder::make(new_inp[1],
+                                                        is_trans.relayout_mod);
             conv_filter = new_filter.node();
             //! src trans to nchwxx mode
             if (new_inp[0]->shape().ndim != 5) {
@@ -2125,7 +2132,7 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
                 conv_src = new_src.node();
             }
             auto new_param = conv_opr.param();
-            new_param.format = conv_dot_format;
+            new_param.format = is_trans.conv_format;
             mgb_assert(conv_src->shape().ndim == 5 &&
                                conv_filter->shape().ndim >= 6,
                        "The conv src dim is not trans to nchwxx");
@@ -2137,16 +2144,16 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
                        "The conv dst dim is not trans to nchwxx");
             return new_opr;
         } else {
-            mgb_assert(is_trans.first == TransType::TRANS_HYBIRD_NCHWXX);
+            mgb_assert(is_trans.trans_type == TransType::TRANS_HYBIRD_NCHWXX);
             VarNode *conv_src = new_inp[0], *conv_filter = new_inp[1];
-            auto new_filter =
-                    RelayoutPlaceholder::make(new_inp[1], is_trans.second);
+            auto new_filter = RelayoutPlaceholder::make(new_inp[1],
+                                                        is_trans.relayout_mod);
             conv_filter = new_filter.node();
             mgb_assert(conv_src->shape().ndim == 4 &&
                                conv_filter->shape().ndim == 5,
                        "The src and filter is OK");
             auto new_param = conv_opr.param();
-            new_param.format = conv_dot_format;
+            new_param.format = is_trans.conv_format;
             auto new_conv_opr = opr::Convolution::make(
                     conv_src, conv_filter, new_param,
                     conv_opr.execution_policy(), conv_opr.config());
@@ -2157,7 +2164,7 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
         }
     };
 
-    auto replace_conv_bias_opr = [test_trans_nchw44_dot, conv_dot_format](
+    auto replace_conv_bias_opr = [test_trans_nchw44_dot](
                                          OperatorNodeBase* opr,
                                          const VarNodeArray& new_inp) {
         mgb_assert(opr->input().size() == new_inp.size());
@@ -2168,7 +2175,7 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
         auto is_trans = test_trans_nchw44_dot(conv_bias_opr.param().sparse,
                                               new_inp[1]);
         //! can not trans to nchwxx
-        if (is_trans.first == TransType::TRANS_NONE) {
+        if (is_trans.trans_type == TransType::TRANS_NONE) {
             mgb_assert(new_inp[1]->shape().ndim == 4 ||
                                new_inp[1]->shape().ndim == 5,
                        "The origin filter is not NCHW mode");
@@ -2188,15 +2195,15 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
             auto new_opr = serialization::copy_opr_shallow(*opr, temp_inp,
                                                            opr->config());
             return new_opr;
-        } else if (is_trans.first == TransType::TRANS_PURE_NCHWXX) {
+        } else if (is_trans.trans_type == TransType::TRANS_PURE_NCHWXX) {
             VarNode *conv_bias_src = new_inp[0],
                     *conv_bias_filter = new_inp[1], *conv_bias_bias = new_inp[2];
             //! filter trans to nchwxx mode
             mgb_assert(new_inp[1]->shape().ndim == 4 ||
                                new_inp[1]->shape().ndim == 5,
                        "The origin filter is not NCHW mode");
-            auto new_filter =
-                    RelayoutPlaceholder::make(new_inp[1], is_trans.second);
+            auto new_filter = RelayoutPlaceholder::make(new_inp[1],
+                                                        is_trans.relayout_mod);
             conv_bias_filter = new_filter.node();
             //! src trans to nchwxx mode
             if (new_inp[0]->shape().ndim != 5) {
@@ -2213,7 +2220,7 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
             }
 
             auto new_param = conv_bias_opr.param();
-            new_param.format = conv_dot_format;
+            new_param.format = is_trans.conv_format;
             mgb_assert(conv_bias_src->shape().ndim == 5 &&
                                conv_bias_filter->shape().ndim >= 6,
                        "The conv_bias src dim is not trans to nchwxx");
@@ -2225,11 +2232,11 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
                        "The conv_bias dst dim is not trans to nchwxx");
             return new_opr;
         } else {
-            mgb_assert(is_trans.first == TransType::TRANS_HYBIRD_NCHWXX);
+            mgb_assert(is_trans.trans_type == TransType::TRANS_HYBIRD_NCHWXX);
             VarNode *conv_bias_src = new_inp[0],
                     *conv_bias_filter = new_inp[1], *conv_bias_bias = new_inp[2];
-            auto new_filter =
-                    RelayoutPlaceholder::make(new_inp[1], is_trans.second);
+            auto new_filter = RelayoutPlaceholder::make(new_inp[1],
+                                                        is_trans.relayout_mod);
             conv_bias_filter = new_filter.node();
             //! bias trans to nchwxx mode, bias may be scale
             if (new_inp[2]->shape().ndim == 4) {
@@ -2242,7 +2249,7 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
             mgb_assert((conv_bias_bias->shape().ndim == 5) ||
                        conv_bias_bias->shape().is_scalar());
             auto new_param = conv_bias_opr.param();
-            new_param.format = conv_dot_format;
+            new_param.format = is_trans.conv_format;
             auto new_conv_bias_opr = opr::ConvBias::make(
                     conv_bias_src, conv_bias_filter, conv_bias_bias, new_param,
                     conv_bias_opr.execution_policy(), conv_bias_opr.config());
diff --git a/src/gopt/test/inference.cpp b/src/gopt/test/inference.cpp
index ce58a0ed0..03bf1bc1c 100644
--- a/src/gopt/test/inference.cpp
+++ b/src/gopt/test/inference.cpp
@@ -2694,7 +2694,7 @@ TEST(TestGoptInference, ConvertFormatNCHW44_DOT) {
     ASSERT_EQ(opr::ConvBias::Param::Format::NCHW44_DOT,
               find_opr<opr::ConvBias>(y_opt).param().format);
-    ASSERT_EQ(opr::Convolution::Param::Format::NCHW,
+    ASSERT_EQ(opr::Convolution::Param::Format::NCHW44,
              find_opr<opr::Convolution>(y_opt).param().format);
 
     graph->compile({{y_opt, {}}})
-- 
GitLab
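
Note on the fix: before this patch, test_trans_nchw44_dot returned a std::pair and every convertible conv received the one captured format NCHW44_DOT, so the channel-wise group case (icpg == ocpg == 1) could only be skipped with TRANS_NONE, because the dot-product kernels have no channel-wise variant. Promoting the result to a struct with its own conv_format field lets that branch convert the operator to plain NCHW44 (with a WEIGHT_NCHW_TO_NCHW44_CHAN weight relayout) while the dense, hybrid, and ordinary group branches keep NCHW44_DOT. The sketch below is a minimal, self-contained model of that dispatch (compiles standalone with -std=c++14); the enums, the vector-of-dims filter shape, and the test_trans helper are simplified stand-ins for illustration, not the MegEngine API.

// Standalone illustration of the format dispatch after this patch.
// All type names are simplified stand-ins; only the branch structure
// mirrors test_trans_nchw44_dot above.
#include <cassert>
#include <cstddef>
#include <vector>

enum class TransType { TRANS_NONE, TRANS_PURE_NCHWXX, TRANS_HYBIRD_NCHWXX };
enum class RelayoutMode {
    NONE,
    WEIGHT_NCHW_TO_NCHW44_DOT_DENSE,
    WEIGHT_HYBIRD_NCHW_NCHW44,
    WEIGHT_NCHW_TO_NCHW44_CHAN,
    WEIGHT_NCHW_TO_NCHW44_DOT_GROUP,
};
enum class Format { NCHW, NCHW44, NCHW44_DOT };
enum class Sparse { DENSE, GROUP };

// A pair<TransType, RelayoutMode> cannot carry a per-case format, hence
// the struct with an explicit conv_format field.
struct TestTransResult {
    TransType trans_type = TransType::TRANS_NONE;
    RelayoutMode relayout_mod = RelayoutMode::NONE;
    Format conv_format = Format::NCHW;
};

constexpr std::size_t pack_c_size = 4;

TestTransResult test_trans(Sparse mode, const std::vector<std::size_t>& filter) {
    TestTransResult ret;
    if (mode == Sparse::DENSE) {
        // Dense filter layout: OC, IC, FH, FW.
        std::size_t OC = filter[0], IC = filter[1];
        if (IC % pack_c_size == 0 && OC % pack_c_size == 0) {
            ret = {TransType::TRANS_PURE_NCHWXX,
                   RelayoutMode::WEIGHT_NCHW_TO_NCHW44_DOT_DENSE,
                   Format::NCHW44_DOT};
        } else if (IC < pack_c_size && OC % pack_c_size == 0) {
            ret = {TransType::TRANS_HYBIRD_NCHWXX,
                   RelayoutMode::WEIGHT_HYBIRD_NCHW_NCHW44,
                   Format::NCHW44_DOT};
        }
    } else {
        // Group filter layout: group, OCPG, ICPG, FH, FW.
        std::size_t group = filter[0], ocpg = filter[1], icpg = filter[2];
        if (icpg == 1 && ocpg == 1 && group % pack_c_size == 0) {
            // The fixed branch: channel-wise conv is now converted to plain
            // NCHW44 instead of being skipped with TRANS_NONE.
            ret = {TransType::TRANS_PURE_NCHWXX,
                   RelayoutMode::WEIGHT_NCHW_TO_NCHW44_CHAN, Format::NCHW44};
        } else if (icpg % pack_c_size == 0 && ocpg % pack_c_size == 0) {
            ret = {TransType::TRANS_PURE_NCHWXX,
                   RelayoutMode::WEIGHT_NCHW_TO_NCHW44_DOT_GROUP,
                   Format::NCHW44_DOT};
        }
    }
    return ret;
}

int main() {
    // Channel-wise: group=32, ocpg=icpg=1 -> plain NCHW44 after the fix.
    assert(test_trans(Sparse::GROUP, {32, 1, 1, 3, 3}).conv_format ==
           Format::NCHW44);
    // Regular group conv with packable channels -> NCHW44_DOT, as before.
    assert(test_trans(Sparse::GROUP, {2, 8, 8, 3, 3}).conv_format ==
           Format::NCHW44_DOT);
    // Dense conv with IC and OC divisible by 4 -> NCHW44_DOT, as before.
    assert(test_trans(Sparse::DENSE, {32, 16, 3, 3}).conv_format ==
           Format::NCHW44_DOT);
    return 0;
}

Because each branch now owns its format decision, the captured conv_dot_format constant becomes unnecessary, which is why the patch also drops it from both lambda capture lists; this is also what the updated inference.cpp test asserts, since the channel-wise Convolution must end up in NCHW44 rather than untouched NCHW.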