From 627091dfeec9f6e87c954bbe785764aaab50e861 Mon Sep 17 00:00:00 2001
From: cc <52520497+juncaipeng@users.noreply.github.com>
Date: Tue, 7 Jul 2020 17:47:28 +0800
Subject: [PATCH] Support quantizing conv2d_transpose, test=develop (#3893)

---
 .../mir/fusion/quant_dequant_fuse_pass.cc     |  5 ++-
 .../core/mir/fusion/quant_dequant_op_fuser.cc | 38 ++++++++++++-------
 2 files changed, 28 insertions(+), 15 deletions(-)

diff --git a/lite/core/mir/fusion/quant_dequant_fuse_pass.cc b/lite/core/mir/fusion/quant_dequant_fuse_pass.cc
index 80a033c75f..ea8400b0bb 100644
--- a/lite/core/mir/fusion/quant_dequant_fuse_pass.cc
+++ b/lite/core/mir/fusion/quant_dequant_fuse_pass.cc
@@ -34,12 +34,13 @@ void QuantDequantFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
   }
 
   // fuse quantized node and dequant node
-  for (auto& op_type : {"conv2d", "mul", "depthwise_conv2d"}) {
+  for (auto& op_type :
+       {"conv2d", "mul", "depthwise_conv2d", "conv2d_transpose"}) {
     fusion::DequantOpFuser fuser(op_type);
     fuser(graph.get());
   }
 
-  for (auto& op_type : {"conv2d", "depthwise_conv2d"}) {
+  for (auto& op_type : {"conv2d", "depthwise_conv2d", "conv2d_transpose"}) {
     fusion::ChannelWiseDequantOpFuser fuser(op_type);
     fuser(graph.get());
   }
diff --git a/lite/core/mir/fusion/quant_dequant_op_fuser.cc b/lite/core/mir/fusion/quant_dequant_op_fuser.cc
index 30a887c685..1335518b00 100644
--- a/lite/core/mir/fusion/quant_dequant_op_fuser.cc
+++ b/lite/core/mir/fusion/quant_dequant_op_fuser.cc
@@ -23,6 +23,20 @@ namespace lite {
 namespace mir {
 namespace fusion {
 
+static std::string GetWeightArgname(const std::string& op_type) {
+  std::string weight_argname{};
+  std::vector<std::string> conv_ops = {
+      "conv2d", "depthwise_conv2d", "conv2d_transpose"};
+  std::vector<std::string> mul_ops = {"mul", "matmul"};
+  if (std::find(conv_ops.begin(), conv_ops.end(), op_type) != conv_ops.end()) {
+    weight_argname = "Filter";
+  } else if (std::find(mul_ops.begin(), mul_ops.end(), op_type) !=
+             mul_ops.end()) {
+    weight_argname = "Y";
+  }
+  return weight_argname;
+}
+
 void DeleteQuantOpFuser::BuildPattern() {
   auto* input_scale_node = VarNode("input_scale_node")
                                ->assert_is_op_input(quant_op_type_, "InScale");
@@ -83,20 +97,13 @@ cpp::OpDesc DeleteQuantOpFuser::GenOpDesc(const key2nodes_t& matched) {
 }
 
 void DequantOpFuser::BuildPattern() {
-  std::string weight_name = "";
-  if (quantized_op_type_ == "conv2d" ||
-      quantized_op_type_ == "depthwise_conv2d") {
-    weight_name = "Filter";
-  } else {
-    weight_name = "Y";
-  }
-
+  std::string weight_argname = GetWeightArgname(quantized_op_type_);
   auto* quantized_op_input = VarNode("quantized_op_input")
                                  ->assert_is_op_input(quantized_op_type_)
                                  ->AsInput();
   auto* quantized_op_weight =
       VarNode("quantized_op_weight")
-          ->assert_is_op_input(quantized_op_type_, weight_name)
+          ->assert_is_op_input(quantized_op_type_, weight_argname)
           ->AsInput();
   auto* quantized_op = OpNode("quantized_op", quantized_op_type_)
                            ->assert_is_op(quantized_op_type_)
@@ -152,7 +159,8 @@ void DequantOpFuser::InsertNewNode(SSAGraph* graph,
   std::vector<float> weight_scale;
   int weight_scale_size = 0;
   if (quantized_op_type_ == "conv2d" ||
-      quantized_op_type_ == "depthwise_conv2d") {
+      quantized_op_type_ == "depthwise_conv2d" ||
+      quantized_op_type_ == "conv2d_transpose") {
     op_desc.SetInput("Input", {quantized_op_input->arg()->name});
     op_desc.SetOutput("Output", {dequant_op_out->arg()->name});
     // Conv weight shape: Cout * Cin * kh * hw, the weight_scale_size should
@@ -199,12 +207,13 @@ cpp::OpDesc DequantOpFuser::GenOpDesc(const key2nodes_t& matched) {
 
 void ChannelWiseDequantOpFuser::BuildPattern() {
   std::string dequant_op_type = "fake_channel_wise_dequantize_max_abs";
+  std::string weight_argname = GetWeightArgname(quantized_op_type_);
   auto* quantized_op_input = VarNode("quantized_op_input")
                                  ->assert_is_op_input(quantized_op_type_)
                                  ->AsInput();
   auto* quantized_op_weight =
       VarNode("quantized_op_weight")
-          ->assert_is_op_input(quantized_op_type_, "Filter")
+          ->assert_is_op_input(quantized_op_type_, weight_argname)
           ->AsInput();
   auto* quantized_op = OpNode("quantized_op", quantized_op_type_)
                            ->assert_is_op(quantized_op_type_)
@@ -263,14 +272,17 @@ void ChannelWiseDequantOpFuser::InsertNewNode(SSAGraph* graph,
   // set op desc
   auto op_desc = *quantized_op->stmt()->op_info();
   if (quantized_op_type_ == "conv2d" ||
-      quantized_op_type_ == "depthwise_conv2d") {
+      quantized_op_type_ == "depthwise_conv2d" ||
+      quantized_op_type_ == "conv2d_transpose") {
     op_desc.SetInput("Input", {quantized_op_input->arg()->name});
     op_desc.SetOutput("Output", {dequant_op_out->arg()->name});
   } else if (quantized_op_type_ == "mul" || quantized_op_type_ == "matmul") {
     op_desc.SetInput("X", {quantized_op_input->arg()->name});
     op_desc.SetOutput("Out", {dequant_op_out->arg()->name});
   }
-  op_desc.SetAttr("enable_int8", true);
+  if (quantized_op_type_ != "conv2d_transpose") {
+    op_desc.SetAttr("enable_int8", true);
+  }
   op_desc.SetInputScale(weight_name, weight_scale);
 
   // change the weight from the float type to int8 type.
-- 
GitLab
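
Note (editor's addendum, not part of the patch): the GetWeightArgname() helper
introduced above replaces the op-type checks that were previously duplicated in
each fuser, and it routes conv2d_transpose through the same "Filter" argument
as the other conv ops. The standalone sketch below mirrors that lookup so it
can be sanity-checked outside the graph pass; the main() driver is a
hypothetical harness for illustration only and does not exist in Paddle-Lite.

// Standalone sketch of the GetWeightArgname() lookup added by this patch.
// Build (assumption): g++ -std=c++11 weight_argname_demo.cc && ./a.out
#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

static std::string GetWeightArgname(const std::string& op_type) {
  std::string weight_argname{};
  std::vector<std::string> conv_ops = {
      "conv2d", "depthwise_conv2d", "conv2d_transpose"};
  std::vector<std::string> mul_ops = {"mul", "matmul"};
  if (std::find(conv_ops.begin(), conv_ops.end(), op_type) != conv_ops.end()) {
    weight_argname = "Filter";  // conv-like ops carry weights in "Filter"
  } else if (std::find(mul_ops.begin(), mul_ops.end(), op_type) !=
             mul_ops.end()) {
    weight_argname = "Y";  // mul/matmul carry weights in "Y"
  }
  return weight_argname;  // empty for op types the fusers do not handle
}

int main() {
  // Hypothetical driver: prints the weight argument name for a few op types.
  // Expected: conv2d -> "Filter", conv2d_transpose -> "Filter",
  //           mul -> "Y", relu -> "".
  for (const std::string& op_type :
       {"conv2d", "conv2d_transpose", "mul", "relu"}) {
    std::cout << op_type << " -> \"" << GetWeightArgname(op_type) << "\"\n";
  }
  return 0;
}

Also worth noting: ChannelWiseDequantOpFuser::InsertNewNode() only sets
enable_int8 for op types other than conv2d_transpose, presumably so a fused
conv2d_transpose keeps its float kernel while the pass still records the
recovered weight scales via SetInputScale().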