From af828ca9ec29a71090d396d6b156ced905bdd499 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 7 Sep 2021 19:26:20 +0800 Subject: [PATCH] feat(mgb/gopt): fix folding conv dimshuffle pass GitOrigin-RevId: 756878c173787481761a48c8d0550aa409702b3b --- src/gopt/impl/folding_conv_dimshuffle.cpp | 43 ++++++++++++----------- 1 file changed, 22 insertions(+), 21 deletions(-) diff --git a/src/gopt/impl/folding_conv_dimshuffle.cpp b/src/gopt/impl/folding_conv_dimshuffle.cpp index dac9f0757..cda3d5d46 100644 --- a/src/gopt/impl/folding_conv_dimshuffle.cpp +++ b/src/gopt/impl/folding_conv_dimshuffle.cpp @@ -240,11 +240,30 @@ void FoldingConvBiasDimshufflePass::apply(OptState& opt) const { &readers](OperatorNodeBase* opr) { ThinHashSet opr_set; ThinHashSet reader_set; + // check typecvt + auto typecvt = try_cast_as_op(opr); + if (typecvt == nullptr) + return false; + auto in_dtype = typecvt->input(0)->dtype(), + out_dtype = typecvt->output(0)->dtype(); + bool is_s82s4 = in_dtype.enumv() == DTypeEnum::QuantizedS8 && + (out_dtype.enumv() == DTypeEnum::QuantizedS4 || + out_dtype.enumv() == DTypeEnum::Quantized4Asymm); + if (!is_s82s4) + return false; + opr_set.insert(typecvt); + // check reshape - auto reshape = try_cast_as_op(opr); + auto reshape = + try_cast_as_op(typecvt->input(0)->owner_opr()); if (reshape == nullptr) return false; - opr_set.insert(opr); + opr_set.insert(reshape); + for (auto&& i : readers[reshape]) { + if (i.second & DepType::DEV_VALUE) { + reader_set.insert(i.first); + } + } // check dimshuffle auto shuffle = @@ -267,27 +286,9 @@ void FoldingConvBiasDimshufflePass::apply(OptState& opt) const { } } - auto typecvt = - try_cast_as_op(shuffle->input(0)->owner_opr()); - if (typecvt == nullptr) - return false; - auto in_dtype = typecvt->input(0)->dtype(), - out_dtype = typecvt->output(0)->dtype(); - bool is_s82s4 = in_dtype.enumv() == DTypeEnum::QuantizedS8 && - (out_dtype.enumv() == DTypeEnum::QuantizedS4 || - out_dtype.enumv() == DTypeEnum::Quantized4Asymm); - if (!is_s82s4) - return false; - opr_set.insert(typecvt); - for (auto&& i : readers[typecvt]) { - if (i.second & DepType::DEV_VALUE) { - reader_set.insert(i.first); - } - } - // check conv bias auto conv_bias = - try_cast_as_op(typecvt->input(0)->owner_opr()); + try_cast_as_op(shuffle->input(0)->owner_opr()); if (conv_bias == nullptr) return false; auto inp_dtype = conv_bias->input(0)->dtype(); -- GitLab