From ee2e2b3c7b0186663cb27b9a5d9cfd5330bef76b Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Tue, 22 Sep 2020 20:31:48 +0800
Subject: [PATCH] fix(mgb/gopt): fix nchwxx optpass not handling conv_bias opr with no bias

GitOrigin-RevId: b2b053add464540c22a61e72f078950c18bf92b0
---
 src/gopt/impl/tensor_reformat.cpp | 148 +++++++++++++++++++++---------
 src/gopt/test/inference.cpp       |   6 +-
 2 files changed, 106 insertions(+), 48 deletions(-)

diff --git a/src/gopt/impl/tensor_reformat.cpp b/src/gopt/impl/tensor_reformat.cpp
index eebe13d8..debac210 100644
--- a/src/gopt/impl/tensor_reformat.cpp
+++ b/src/gopt/impl/tensor_reformat.cpp
@@ -1862,7 +1862,8 @@ static inline bool nchw_nchwxx_valid(
     auto& src_node = new_inp[0];
     auto& filter_node = new_inp[1];
     auto dst_node = opr.output(0);
-    if (filter_node->shape().ndim != 4) {
+    //! already transformed or have fuse Z
+    if (filter_node->shape().ndim != 4 || new_inp.size() == 4) {
         return false;
     }
     megdnn::ConvolutionBase::CanonizedFilterMeta fm;
@@ -1884,7 +1885,8 @@
 
     megdnn::ConvBiasForward::BiasMode bias_mode =
             megdnn::ConvBiasForward::BiasMode::NO_BIAS;
-    if (std::is_same::value) {
+    if (std::is_same::value &&
+        new_inp.size() > 2) {
         TensorShape bias_shape = new_inp[2]->shape();
         if (bias_shape.ndim == 5) {
             bias_shape = nchwxx_shape_2_nchw_shape(bias_shape);
@@ -2067,6 +2069,8 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) {
                              pack_c_size](OperatorNodeBase* opr,
                                           const VarNodeArray& new_inp) {
         mgb_assert(opr->input().size() == new_inp.size());
+        mgb_assert(opr->input().size() <= 3,
+                   "nchwxx does not support conv_bias fuse Z right now");
         auto& conv_bias_opr = opr->cast_final_safe();
         mgb_assert(conv_bias_opr.param().format ==
                            megdnn::param::ConvBias::Format::NCHW,
@@ -2092,7 +2096,7 @@
                 temp_inp[0] = new_src.node();
             }
             //! the bias is nchwxx
-            if (temp_inp[2]->shape().ndim == 5) {
+            if (new_inp.size() > 2 && temp_inp[2]->shape().ndim == 5) {
                 auto new_bias = RelayoutPlaceholder::make(new_inp[2],
                                                           src_to_nchw_mode);
                 temp_inp[2] = new_bias.node();
@@ -2102,7 +2106,7 @@
             return new_opr;
         } else if (is_trans.first == TransType::TRANS_PURE_NCHWXX) {
             VarNode *conv_bias_src = new_inp[0], *conv_bias_filter = new_inp[1],
-                    *conv_bias_bias = new_inp[2];
+                    *conv_bias_bias = nullptr;
             //! filter trans to nchwxx mode
             mgb_assert(new_inp[1]->shape().ndim == 4 ||
                                new_inp[1]->shape().ndim == 5,
@@ -2117,21 +2121,34 @@
                         src_to_nchwxx_mode);
                 conv_bias_src = new_src.node();
             }
-            //! bias trans to nchwxx mode, bias may be scale
-            if (new_inp[2]->shape().ndim == 4) {
-                auto new_bias = RelayoutPlaceholder::make(new_inp[2],
-                                                          src_to_nchwxx_mode);
-                conv_bias_bias = new_bias.node();
+            //! bias trans to nchwxx mode
+            if (new_inp.size() > 2) {
+                if (new_inp[2]->shape().ndim == 4) {
+                    auto new_bias = RelayoutPlaceholder::make(
+                            new_inp[2], src_to_nchwxx_mode);
+                    conv_bias_bias = new_bias.node();
+                } else {
+                    mgb_assert(new_inp[2]->shape().ndim == 5);
+                    conv_bias_bias = new_inp[2];
+                }
             }
-
             auto new_param = conv_bias_opr.param();
             new_param.format = conv_bias_format;
             mgb_assert(conv_bias_src->shape().ndim == 5 &&
                                conv_bias_filter->shape().ndim >= 6,
                        "The conv_bias src dim is not trans to nchwxx");
-            auto new_conv_bias_opr = opr::ConvBias::make(
-                    conv_bias_src, conv_bias_filter, conv_bias_bias, new_param,
-                    conv_bias_opr.execution_policy(), conv_bias_opr.config());
+            SymbolVar new_conv_bias_opr;
+            if (conv_bias_bias) {
+                new_conv_bias_opr = opr::ConvBias::make(
+                        conv_bias_src, conv_bias_filter, conv_bias_bias,
+                        new_param, conv_bias_opr.execution_policy(),
+                        conv_bias_opr.config());
+            } else {
+                new_conv_bias_opr = opr::ConvBias::make(
+                        conv_bias_src, conv_bias_filter, new_param,
+                        conv_bias_opr.execution_policy(),
+                        conv_bias_opr.config());
+            }
             OperatorNodeBase* new_opr = new_conv_bias_opr.node()->owner_opr();
             mgb_assert(new_conv_bias_opr.shape().ndim == 5,
                        "The conv_bias dst dim is not trans to nchwxx");
@@ -2139,25 +2156,37 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) {
         } else {
             mgb_assert(is_trans.first == TransType::TRANS_HYBIRD_NCHWXX);
             VarNode *conv_bias_src = new_inp[0], *conv_bias_filter = new_inp[1],
-                    *conv_bias_bias = new_inp[2];
+                    *conv_bias_bias = nullptr;
             auto new_filter = RelayoutPlaceholder::make(new_inp[1],
                                                         is_trans.second);
             conv_bias_filter = new_filter.node();
             //! bias trans to nchwxx mode, bias may be scale
-            if (new_inp[2]->shape().ndim == 4) {
-                auto new_bias = RelayoutPlaceholder::make(new_inp[2],
-                                                          src_to_nchwxx_mode);
-                conv_bias_bias = new_bias.node();
+            if (new_inp.size() > 2) {
+                if (new_inp[2]->shape().ndim == 4) {
+                    auto new_bias = RelayoutPlaceholder::make(
+                            new_inp[2], src_to_nchwxx_mode);
+                    conv_bias_bias = new_bias.node();
+                } else {
+                    mgb_assert(new_inp[2]->shape().ndim == 5);
+                    conv_bias_bias = new_inp[2];
+                }
             }
             mgb_assert(conv_bias_src->shape().ndim == 4 &&
                        conv_bias_filter->shape().ndim == 5);
-            mgb_assert((conv_bias_bias->shape().ndim == 5) ||
-                       conv_bias_bias->shape().is_scalar());
             auto new_param = conv_bias_opr.param();
             new_param.format = conv_bias_format;
-            auto new_conv_bias_opr = opr::ConvBias::make(
-                    conv_bias_src, conv_bias_filter, conv_bias_bias, new_param,
-                    conv_bias_opr.execution_policy(), conv_bias_opr.config());
+            SymbolVar new_conv_bias_opr;
+            if (conv_bias_bias) {
+                new_conv_bias_opr = opr::ConvBias::make(
+                        conv_bias_src, conv_bias_filter, conv_bias_bias,
+                        new_param, conv_bias_opr.execution_policy(),
+                        conv_bias_opr.config());
+            } else {
+                new_conv_bias_opr = opr::ConvBias::make(
+                        conv_bias_src, conv_bias_filter, new_param,
+                        conv_bias_opr.execution_policy(),
+                        conv_bias_opr.config());
+            }
             OperatorNodeBase* new_opr = new_conv_bias_opr.node()->owner_opr();
             mgb_assert(new_conv_bias_opr.shape().ndim == 5,
                        "The conv dst dim is not trans to nchwxx");
@@ -2275,6 +2304,10 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) {
             relayout_inp_to_nchw;
     replace_func[opr::WarpAffineForward::typeinfo()] = relayout_inp_to_nchw;
     replace_func[opr::Reshape::typeinfo()] = relayout_inp_to_nchw;
+    replace_func[opr::AxisAddRemove::typeinfo()] = relayout_inp_to_nchw;
+    replace_func[opr::Argmax::typeinfo()] = relayout_inp_to_nchw;
+    replace_func[opr::Broadcast::typeinfo()] = relayout_inp_to_nchw;
+    replace_func[opr::ImmutableTensor::typeinfo()] = relayout_inp_to_nchw;
 }
 
 std::unique_ptr EnableNchwxxPass::make_nchwxx_converter(
@@ -2459,6 +2492,8 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
                                               OperatorNodeBase* opr,
                                               const VarNodeArray& new_inp) {
         mgb_assert(opr->input().size() == new_inp.size());
+        mgb_assert(opr->input().size() <= 3,
+                   "nchwxx-dot does not support conv_bias fuse Z right now");
         auto& conv_bias_opr = opr->cast_final_safe();
         mgb_assert(conv_bias_opr.param().format ==
                            megdnn::param::ConvBias::Format::NCHW,
@@ -2489,7 +2524,7 @@
             }
 
             //! the bias is nchwxx
-            if (temp_inp[2]->shape().ndim == 5) {
+            if (new_inp.size() > 2 && temp_inp[2]->shape().ndim == 5) {
                 auto new_bias = RelayoutPlaceholder::make(
                         new_inp[2], RelayoutMode::NCHW4_TO_NCHW);
                 temp_inp[2] = new_bias.node();
@@ -2499,7 +2534,7 @@
             return new_opr;
         } else if (is_trans.trans_type == TransType::TRANS_PURE_NCHWXX) {
             VarNode *conv_bias_src = new_inp[0], *conv_bias_filter = new_inp[1],
-                    *conv_bias_bias = new_inp[2];
+                    *conv_bias_bias = nullptr;
             //! filter trans to nchwxx mode
             mgb_assert(new_inp[1]->shape().ndim == 4 ||
                                new_inp[1]->shape().ndim == 5,
@@ -2514,21 +2549,34 @@
                         new_inp[0], RelayoutMode::NCHW_TO_NCHW4);
                 conv_bias_src = new_src.node();
             }
-            //! bias trans to nchwxx mode, bias may be scale
-            if (new_inp[2]->shape().ndim == 4) {
-                auto new_bias = RelayoutPlaceholder::make(
-                        new_inp[2], RelayoutMode::NCHW_TO_NCHW4);
-                conv_bias_bias = new_bias.node();
+            //! bias trans to nchwxx mode
+            if (new_inp.size() > 2) {
+                if (new_inp[2]->shape().ndim == 4) {
+                    auto new_bias = RelayoutPlaceholder::make(
+                            new_inp[2], RelayoutMode::NCHW_TO_NCHW4);
+                    conv_bias_bias = new_bias.node();
+                } else {
+                    mgb_assert(new_inp[2]->shape().ndim == 5);
+                    conv_bias_bias = new_inp[2];
+                }
             }
-
             auto new_param = conv_bias_opr.param();
             new_param.format = is_trans.conv_format;
             mgb_assert(conv_bias_src->shape().ndim == 5 &&
                                conv_bias_filter->shape().ndim >= 6,
                        "The conv_bias src dim is not trans to nchwxx");
-            auto new_conv_bias_opr = opr::ConvBias::make(
-                    conv_bias_src, conv_bias_filter, conv_bias_bias, new_param,
-                    conv_bias_opr.execution_policy(), conv_bias_opr.config());
+            SymbolVar new_conv_bias_opr;
+            if (conv_bias_bias) {
+                new_conv_bias_opr = opr::ConvBias::make(
+                        conv_bias_src, conv_bias_filter, conv_bias_bias,
+                        new_param, conv_bias_opr.execution_policy(),
+                        conv_bias_opr.config());
+            } else {
+                new_conv_bias_opr = opr::ConvBias::make(
+                        conv_bias_src, conv_bias_filter, new_param,
+                        conv_bias_opr.execution_policy(),
+                        conv_bias_opr.config());
+            }
             OperatorNodeBase* new_opr = new_conv_bias_opr.node()->owner_opr();
             mgb_assert(new_conv_bias_opr.shape().ndim == 5,
                        "The conv_bias dst dim is not trans to nchwxx");
@@ -2536,25 +2584,37 @@
         } else {
             mgb_assert(is_trans.trans_type == TransType::TRANS_HYBIRD_NCHWXX);
             VarNode *conv_bias_src = new_inp[0], *conv_bias_filter = new_inp[1],
-                    *conv_bias_bias = new_inp[2];
+                    *conv_bias_bias = nullptr;
             auto new_filter = RelayoutPlaceholder::make(new_inp[1],
                                                         is_trans.relayout_mod);
             conv_bias_filter = new_filter.node();
             //! bias trans to nchwxx mode, bias may be scale
-            if (new_inp[2]->shape().ndim == 4) {
-                auto new_bias = RelayoutPlaceholder::make(
-                        new_inp[2], RelayoutMode::NCHW_TO_NCHW4);
-                conv_bias_bias = new_bias.node();
+            if (new_inp.size() > 2) {
+                if (new_inp[2]->shape().ndim == 4) {
+                    auto new_bias = RelayoutPlaceholder::make(
+                            new_inp[2], RelayoutMode::NCHW_TO_NCHW4);
+                    conv_bias_bias = new_bias.node();
+                } else {
+                    mgb_assert(new_inp[2]->shape().ndim == 5);
+                    conv_bias_bias = new_inp[2];
+                }
             }
             mgb_assert(conv_bias_src->shape().ndim == 4 &&
                        conv_bias_filter->shape().ndim == 5);
-            mgb_assert((conv_bias_bias->shape().ndim == 5) ||
-                       conv_bias_bias->shape().is_scalar());
             auto new_param = conv_bias_opr.param();
             new_param.format = is_trans.conv_format;
-            auto new_conv_bias_opr = opr::ConvBias::make(
-                    conv_bias_src, conv_bias_filter, conv_bias_bias, new_param,
-                    conv_bias_opr.execution_policy(), conv_bias_opr.config());
+            SymbolVar new_conv_bias_opr;
+            if (conv_bias_bias) {
+                new_conv_bias_opr = opr::ConvBias::make(
+                        conv_bias_src, conv_bias_filter, conv_bias_bias,
+                        new_param, conv_bias_opr.execution_policy(),
+                        conv_bias_opr.config());
+            } else {
+                new_conv_bias_opr = opr::ConvBias::make(
+                        conv_bias_src, conv_bias_filter, new_param,
+                        conv_bias_opr.execution_policy(),
+                        conv_bias_opr.config());
+            }
             OperatorNodeBase* new_opr = new_conv_bias_opr.node()->owner_opr();
             mgb_assert(new_conv_bias_opr.shape().ndim == 5,
                        "The conv dst dim is not trans to nchwxx");
diff --git a/src/gopt/test/inference.cpp b/src/gopt/test/inference.cpp
index 05892bd8..6faa9ac6 100644
--- a/src/gopt/test/inference.cpp
+++ b/src/gopt/test/inference.cpp
@@ -3009,9 +3009,8 @@ TEST(TestGoptInference, ConvertFormatNCHW44) {
     //! no supported hybrid nchw44
     opr::ConvBias::Param param_conv_bias_pad0;
     param_conv_bias_pad0.pad_h = param_conv_bias_pad0.pad_w = 0;
-    auto b1 = mkcvar("b1", {1, 8, 1, 1});
     auto w1_f1 = mkcvar("w1_1", {8, 3, 1, 1});
-    auto conv1_f1 = opr::ConvBias::make(x, w1_f1, b1, param_conv_bias_pad0, {},
+    auto conv1_f1 = opr::ConvBias::make(x, w1_f1, param_conv_bias_pad0, {},
                                         OperatorNodeConfig("conv1_f1"));
     auto conv1_add = conv1_f1 * conv1;
 
@@ -3263,9 +3262,8 @@ TEST(TestGoptInference, ConvertFormatNCHW44_DOT) {
     opr::ConvBias::Param param_conv_bias;
     param_conv_bias.pad_h = param_conv_bias.pad_w = 1;
     auto w1_2 = mkcvar_dtype("w1_2", {8, 8, 3, 3}, dtype::QuantizedS8(2.5f));
-    auto b1_2 = mkcvar_dtype("b1_2", {1, 8, 1, 1}, dtype::QuantizedS32(6.25f));
     auto conv_1_2 = opr::ConvBias::make(
-            conv_1_q8, w1_2, b1_2, param_conv_bias, {},
+            conv_1_q8, w1_2, param_conv_bias, {},
            OperatorNodeConfig{"conv_1_2", cn, dtype::QuantizedS8{6.25f}});
     auto conv_1_2_fp32 = opr::TypeCvt::make(conv_1_2, dtype::Float32());
 
--
GitLab
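
Below is a minimal, self-contained C++ sketch of the guard pattern this patch applies to every conv_bias replace path: treat the bias input as optional, check new_inp.size() > 2 before touching new_inp[2], reject the fuse-Z form up front, and rebuild the operator with either the bias or the bias-less overload. All types and helper names here (Var, make_conv_bias, replace_conv_bias) are hypothetical stand-ins, not MegEngine's real API.

#include <cstddef>
#include <iostream>
#include <vector>

// Hypothetical stand-in for a graph variable; only the shape rank matters here.
struct Var {
    size_t ndim;
};

// Stand-ins for the two conv_bias construction overloads (with / without bias);
// they only report which overload was selected.
static void make_conv_bias(const Var& src, const Var& filter) {
    std::cout << "conv_bias built without bias, src ndim = " << src.ndim
              << ", filter ndim = " << filter.ndim << "\n";
}

static void make_conv_bias(const Var& src, const Var& filter, const Var& bias) {
    std::cout << "conv_bias built with bias, bias ndim = " << bias.ndim << "\n";
}

// Mirrors the control flow of the patched replace function: inputs are
// {src, filter[, bias[, z]]}; fuse-Z (4 inputs) is rejected up front, and the
// bias is only relaid out / forwarded when it actually exists.
static void replace_conv_bias(const std::vector<Var>& new_inp) {
    if (new_inp.size() > 3) {
        std::cout << "skip: conv_bias with fuse Z is not handled by this pass\n";
        return;
    }
    const Var relaid_bias{5};  // pretend result of an NCHW -> NCHWxx relayout
    const Var* bias = nullptr;
    if (new_inp.size() > 2) {
        // A 4-d bias still needs the relayout; a 5-d one is already NCHWxx.
        bias = (new_inp[2].ndim == 4) ? &relaid_bias : &new_inp[2];
    }
    if (bias) {
        make_conv_bias(new_inp[0], new_inp[1], *bias);
    } else {
        make_conv_bias(new_inp[0], new_inp[1]);
    }
}

int main() {
    replace_conv_bias({{4}, {4}});            // no bias: bias-less overload
    replace_conv_bias({{4}, {4}, {4}});       // 4-d bias: relayout, then 3-input overload
    replace_conv_bias({{4}, {4}, {5}, {4}});  // fuse Z: rejected, as in the pass
    return 0;
}

This optional-bias handling is also why the patch drops the old assert on conv_bias_bias in the hybrid branches: once the bias input becomes optional, that pointer may legitimately be null.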