diff --git a/lite/core/mir/fusion/conv_conv_fuse_pass.cc b/lite/core/mir/fusion/conv_conv_fuse_pass.cc index d277da87689d7aa1f21ef260013b6e81f2146a09..b2c5d8d15ab95fbcc43adc01c4189ae83b1316ed 100644 --- a/lite/core/mir/fusion/conv_conv_fuse_pass.cc +++ b/lite/core/mir/fusion/conv_conv_fuse_pass.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "lite/core/mir/fusion/conv_conv_fuse_pass.h" +#include #include #include #include "lite/core/mir/fusion/conv_conv_fuser.h" @@ -27,13 +28,10 @@ void ConvConvFusePass::Apply(const std::unique_ptr& graph) { // initialze fuser params std::vector conv_has_bias_cases{true, false}; std::vector conv_type_cases{"conv2d", "depthwise_conv2d"}; - bool has_fp32 = false; bool has_int8 = false; + bool has_weight_quant = false; for (auto& place : graph->valid_places()) { if (place.target == TARGET(kARM) || place.target == TARGET(kHost)) { - if (place.precision == PRECISION(kFloat)) { - has_fp32 = true; - } if (place.precision == PRECISION(kInt8)) { has_int8 = true; } @@ -42,8 +40,18 @@ void ConvConvFusePass::Apply(const std::unique_ptr& graph) { return; } } + const std::list& nodes = graph->nodes(); + for (auto& node : nodes) { + if (node.IsStmt()) { + auto* op_info = (node.stmt())->op_info(); + if (op_info->HasAttr("quantization_type")) { + has_weight_quant = true; + break; + } + } + } // only support arm-fp32 - if (has_int8 || (has_fp32 && has_int8)) { + if (has_int8 || has_weight_quant) { return; } // only support fp32 fusion