diff --git a/lite/core/mir/fusion/quant_dequant_fuse_pass.cc b/lite/core/mir/fusion/quant_dequant_fuse_pass.cc index ff5a7a1f25239d9dbfc79491bd137804b16b6cfa..ab81f3d809507dd340056c97a39998c908a75dc7 100644 --- a/lite/core/mir/fusion/quant_dequant_fuse_pass.cc +++ b/lite/core/mir/fusion/quant_dequant_fuse_pass.cc @@ -45,7 +45,7 @@ void QuantDequantFusePass::Apply(const std::unique_ptr& graph) { } // delete quant_dequant_node - for (auto op_type : {"pool2d", "elementwise_add"}) { + for (auto op_type : {"pool2d", "softmax", "elementwise_add"}) { fusion::DeleteQuantDequantOpFuser fuser(op_type); fuser(graph.get()); } diff --git a/lite/core/mir/fusion/quant_dequant_op_fuser.cc b/lite/core/mir/fusion/quant_dequant_op_fuser.cc index da611e4490f4ba7268d9011b3dbb391a63a88305..7797864a2e4b75f52fd7da93ea81613a2175f423 100644 --- a/lite/core/mir/fusion/quant_dequant_op_fuser.cc +++ b/lite/core/mir/fusion/quant_dequant_op_fuser.cc @@ -297,7 +297,7 @@ cpp::OpDesc ChannelWiseDequantOpFuser::GenOpDesc(const key2nodes_t& matched) { void DeleteQuantDequantOpFuser::BuildPattern() { std::string quant_dequant_op_type = "fake_quantize_dequantize_moving_average_abs_max"; - if (quantized_op_type_ == "pool2d") { + if (quantized_op_type_ == "pool2d" || quantized_op_type_ == "softmax") { auto* input_scale_node = VarNode("input_scale_node") ->assert_is_op_input(quant_dequant_op_type, "InScale"); @@ -374,7 +374,7 @@ void DeleteQuantDequantOpFuser::BuildPattern() { void DeleteQuantDequantOpFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) { - if (quantized_op_type_ == "pool2d") { + if (quantized_op_type_ == "pool2d" || quantized_op_type_ == "softmax") { auto* input_scale_node = matched.at("input_scale_node"); auto* input_act_node = matched.at("input_act_node"); auto* quant_dequant_node = matched.at("quant_dequant_node");