diff --git a/src/gopt/impl/folding_conv_dimshuffle.cpp b/src/gopt/impl/folding_conv_dimshuffle.cpp index dac9f07576e8c1da872df8a8c4e6ec94b589e2db..cda3d5d46f22bfb920978a55612351de4f4cb0b5 100644 --- a/src/gopt/impl/folding_conv_dimshuffle.cpp +++ b/src/gopt/impl/folding_conv_dimshuffle.cpp @@ -240,11 +240,30 @@ void FoldingConvBiasDimshufflePass::apply(OptState& opt) const { &readers](OperatorNodeBase* opr) { ThinHashSet opr_set; ThinHashSet reader_set; + // check typecvt + auto typecvt = try_cast_as_op(opr); + if (typecvt == nullptr) + return false; + auto in_dtype = typecvt->input(0)->dtype(), + out_dtype = typecvt->output(0)->dtype(); + bool is_s82s4 = in_dtype.enumv() == DTypeEnum::QuantizedS8 && + (out_dtype.enumv() == DTypeEnum::QuantizedS4 || + out_dtype.enumv() == DTypeEnum::Quantized4Asymm); + if (!is_s82s4) + return false; + opr_set.insert(typecvt); + // check reshape - auto reshape = try_cast_as_op(opr); + auto reshape = + try_cast_as_op(typecvt->input(0)->owner_opr()); if (reshape == nullptr) return false; - opr_set.insert(opr); + opr_set.insert(reshape); + for (auto&& i : readers[reshape]) { + if (i.second & DepType::DEV_VALUE) { + reader_set.insert(i.first); + } + } // check dimshuffle auto shuffle = @@ -267,27 +286,9 @@ void FoldingConvBiasDimshufflePass::apply(OptState& opt) const { } } - auto typecvt = - try_cast_as_op(shuffle->input(0)->owner_opr()); - if (typecvt == nullptr) - return false; - auto in_dtype = typecvt->input(0)->dtype(), - out_dtype = typecvt->output(0)->dtype(); - bool is_s82s4 = in_dtype.enumv() == DTypeEnum::QuantizedS8 && - (out_dtype.enumv() == DTypeEnum::QuantizedS4 || - out_dtype.enumv() == DTypeEnum::Quantized4Asymm); - if (!is_s82s4) - return false; - opr_set.insert(typecvt); - for (auto&& i : readers[typecvt]) { - if (i.second & DepType::DEV_VALUE) { - reader_set.insert(i.first); - } - } - // check conv bias auto conv_bias = - try_cast_as_op(typecvt->input(0)->owner_opr()); + try_cast_as_op(shuffle->input(0)->owner_opr()); if (conv_bias == nullptr) return false; auto inp_dtype = conv_bias->input(0)->dtype();