From 79edd7420481d27e0c16cd8fecc208051ae492db Mon Sep 17 00:00:00 2001 From: cc <52520497+juncaipeng@users.noreply.github.com> Date: Mon, 1 Jun 2020 12:49:53 +0800 Subject: [PATCH] Support dygraph quantized model, test=develop (#3725) --- lite/core/mir/fusion/quant_dequant_op_fuser.cc | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/lite/core/mir/fusion/quant_dequant_op_fuser.cc b/lite/core/mir/fusion/quant_dequant_op_fuser.cc index 6256767fa4..f6d03cc23d 100644 --- a/lite/core/mir/fusion/quant_dequant_op_fuser.cc +++ b/lite/core/mir/fusion/quant_dequant_op_fuser.cc @@ -351,14 +351,23 @@ void DeleteQuantDequantOpFuser::InsertNewNode(SSAGraph* graph, for (auto* quantized_node : quantized_nodes) { // Save quantization info in op_info attr auto op_info = *quantized_node->stmt()->op_info(); + op_info.SetAttr("bit_length", bit_length); + std::string argname; int index; op_info.GetInputArgname(output_act_name, &argname); op_info.GetInputIndex(output_act_name, &index); op_info.SetAttr(argname + std::to_string(index) + "_input_scale", scale_value); - op_info.SetAttr("input_scale", scale_value); // Save it for now - op_info.SetAttr("bit_length", bit_length); + std::string op_type = op_info.Type(); + // Analyse the weight scale or input scale. + if (((op_type == "conv2d" || op_type == "depthwise_conv2d") && + argname == "Input") || + ((op_type == "mul" || op_type == "matmul") && argname == "Y")) { + op_info.SetAttr("weight_scale", scale_value); + } else { + op_info.SetAttr("input_scale", scale_value); + } op_info.UpdateAllInputs(output_act_name, input_act_name); quantized_node->stmt()->ResetOp(op_info, graph->valid_places()); -- GitLab