From 339c2e53b3cb33212752ca8fcdcc357e6ee1a4e4 Mon Sep 17 00:00:00 2001 From: cc <52520497+juncaipeng@users.noreply.github.com> Date: Fri, 11 Sep 2020 09:51:24 +0800 Subject: [PATCH] Weight quantization skip conv_conv_fuse_pass, test=develop (#4292) --- lite/core/mir/fusion/conv_conv_fuse_pass.cc | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/lite/core/mir/fusion/conv_conv_fuse_pass.cc b/lite/core/mir/fusion/conv_conv_fuse_pass.cc index d277da8768..b2c5d8d15a 100644 --- a/lite/core/mir/fusion/conv_conv_fuse_pass.cc +++ b/lite/core/mir/fusion/conv_conv_fuse_pass.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "lite/core/mir/fusion/conv_conv_fuse_pass.h" +#include #include #include #include "lite/core/mir/fusion/conv_conv_fuser.h" @@ -27,13 +28,10 @@ void ConvConvFusePass::Apply(const std::unique_ptr& graph) { // initialze fuser params std::vector conv_has_bias_cases{true, false}; std::vector conv_type_cases{"conv2d", "depthwise_conv2d"}; - bool has_fp32 = false; bool has_int8 = false; + bool has_weight_quant = false; for (auto& place : graph->valid_places()) { if (place.target == TARGET(kARM) || place.target == TARGET(kHost)) { - if (place.precision == PRECISION(kFloat)) { - has_fp32 = true; - } if (place.precision == PRECISION(kInt8)) { has_int8 = true; } @@ -42,8 +40,18 @@ void ConvConvFusePass::Apply(const std::unique_ptr& graph) { return; } } + const std::list& nodes = graph->nodes(); + for (auto& node : nodes) { + if (node.IsStmt()) { + auto* op_info = (node.stmt())->op_info(); + if (op_info->HasAttr("quantization_type")) { + has_weight_quant = true; + break; + } + } + } // only support arm-fp32 - if (has_int8 || (has_fp32 && has_int8)) { + if (has_int8 || has_weight_quant) { return; } // only support fp32 fusion -- GitLab