From 05b5ef29021a1488b9143c22ee255710cbbb8e08 Mon Sep 17 00:00:00 2001 From: cc <52520497+juncaipeng@users.noreply.github.com> Date: Sun, 1 Mar 2020 15:11:58 +0800 Subject: [PATCH] Support quantizing softmax op, test=develop (#3051) --- lite/core/mir/fusion/quant_dequant_fuse_pass.cc | 2 +- lite/core/mir/fusion/quant_dequant_op_fuser.cc | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/lite/core/mir/fusion/quant_dequant_fuse_pass.cc b/lite/core/mir/fusion/quant_dequant_fuse_pass.cc index ff5a7a1f25..ab81f3d809 100644 --- a/lite/core/mir/fusion/quant_dequant_fuse_pass.cc +++ b/lite/core/mir/fusion/quant_dequant_fuse_pass.cc @@ -45,7 +45,7 @@ void QuantDequantFusePass::Apply(const std::unique_ptr& graph) { } // delete quant_dequant_node - for (auto op_type : {"pool2d", "elementwise_add"}) { + for (auto op_type : {"pool2d", "softmax", "elementwise_add"}) { fusion::DeleteQuantDequantOpFuser fuser(op_type); fuser(graph.get()); } diff --git a/lite/core/mir/fusion/quant_dequant_op_fuser.cc b/lite/core/mir/fusion/quant_dequant_op_fuser.cc index da611e4490..7797864a2e 100644 --- a/lite/core/mir/fusion/quant_dequant_op_fuser.cc +++ b/lite/core/mir/fusion/quant_dequant_op_fuser.cc @@ -297,7 +297,7 @@ cpp::OpDesc ChannelWiseDequantOpFuser::GenOpDesc(const key2nodes_t& matched) { void DeleteQuantDequantOpFuser::BuildPattern() { std::string quant_dequant_op_type = "fake_quantize_dequantize_moving_average_abs_max"; - if (quantized_op_type_ == "pool2d") { + if (quantized_op_type_ == "pool2d" || quantized_op_type_ == "softmax") { auto* input_scale_node = VarNode("input_scale_node") ->assert_is_op_input(quant_dequant_op_type, "InScale"); @@ -374,7 +374,7 @@ void DeleteQuantDequantOpFuser::BuildPattern() { void DeleteQuantDequantOpFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) { - if (quantized_op_type_ == "pool2d") { + if (quantized_op_type_ == "pool2d" || quantized_op_type_ == "softmax") { auto* input_scale_node = matched.at("input_scale_node"); auto* input_act_node = matched.at("input_act_node"); auto* quant_dequant_node = matched.at("quant_dequant_node"); -- GitLab