diff --git a/lite/core/mir/fusion/quant_dequant_op_fuser.cc b/lite/core/mir/fusion/quant_dequant_op_fuser.cc index 6256767fa43ac523b8e5c702328dc61d9b0ebf45..f6d03cc23d56f8ae25f22b5b2667ed451ef8afaa 100644 --- a/lite/core/mir/fusion/quant_dequant_op_fuser.cc +++ b/lite/core/mir/fusion/quant_dequant_op_fuser.cc @@ -351,14 +351,23 @@ void DeleteQuantDequantOpFuser::InsertNewNode(SSAGraph* graph, for (auto* quantized_node : quantized_nodes) { // Save quantization info in op_info attr auto op_info = *quantized_node->stmt()->op_info(); + op_info.SetAttr("bit_length", bit_length); + std::string argname; int index; op_info.GetInputArgname(output_act_name, &argname); op_info.GetInputIndex(output_act_name, &index); op_info.SetAttr(argname + std::to_string(index) + "_input_scale", scale_value); - op_info.SetAttr("input_scale", scale_value); // Save it for now - op_info.SetAttr("bit_length", bit_length); + std::string op_type = op_info.Type(); + // Analyse the weight scale or input scale. + if (((op_type == "conv2d" || op_type == "depthwise_conv2d") && + argname == "Input") || + ((op_type == "mul" || op_type == "matmul") && argname == "Y")) { + op_info.SetAttr("weight_scale", scale_value); + } else { + op_info.SetAttr("input_scale", scale_value); + } op_info.UpdateAllInputs(output_act_name, input_act_name); quantized_node->stmt()->ResetOp(op_info, graph->valid_places());