diff --git a/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc b/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc
index 1864899b07e0180d5d9f8d32d32a17245e992a81..619fe7ab4f738f44f09a18d8807caf685ce7d1f2 100644
--- a/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc
@@ -210,6 +210,22 @@ QuantDequantFusePass::QuantDequantFusePass() {
       .AddAttr("y_num_col_dims")
       .IsNumEQ(1)
       .End();
+  AddOpCompat(OpCompat("matmul_v2"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddInput("Y")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddAttr("trans_x")
+      .IsBoolEQ(false)
+      .End()
+      .AddAttr("trans_y")
+      .IsBoolEQ(false)
+      .End();
   AddOpCompat(OpCompat("matmul"))
       .AddInput("X")
       .IsTensor()
       .End()
@@ -355,7 +371,8 @@ void QuantDequantFusePass::DeleteQuant(ir::Graph* graph, Scope* scope,
              quantized_op_type == "fc" ||
              quantized_op_type == "conv2d_transpose") {
     op_desc->SetAttr("Input_scale", scale_value);
-  } else if (quantized_op_type == "mul" || quantized_op_type == "matmul") {
+  } else if (quantized_op_type == "mul" || quantized_op_type == "matmul" ||
+             quantized_op_type == "matmul_v2") {
     op_desc->SetAttr("X_scale", scale_value);
   } else {
     PADDLE_THROW(platform::errors::Unimplemented(
@@ -387,7 +404,8 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope,
              quantized_op_type == "conv2d_transpose") {
     weight_name = "Filter";
     input_name = "Input";
-  } else if (quantized_op_type == "mul" || quantized_op_type == "matmul") {
+  } else if (quantized_op_type == "mul" || quantized_op_type == "matmul" ||
+             quantized_op_type == "matmul_v2") {
     weight_name = "Y";
     input_name = "X";
   } else if (quantized_op_type == "fc") {
@@ -396,7 +414,7 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope,
   } else {
     PADDLE_THROW(platform::errors::Unimplemented(
         "QuantDequantFuse: We only support conv2d, conv2d_fusion, "
-        "conv2d_transpose, fc, mul, matmul for "
+        "conv2d_transpose, fc, mul, matmul, matmul_v2 for "
         "now."));
   }
   const std::string pattern_name = "dequant_fuse";
@@ -437,7 +455,11 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope,
       BOOST_GET_CONST(int, quantized_op_node->Op()->GetAttr("bit_length"));
   int range = ((1 << (bit_length - 1)) - 1);
   std::vector<float> weight_scale;
-
+  int quant_axis = 0;
+  if (dequant_op_node->Op()->HasAttr("quant_axis")) {
+    quant_axis =
+        BOOST_GET_CONST(int, dequant_op_node->Op()->GetAttr("quant_axis"));
+  }
   // Get weight scale
   if (dequant_type == "fake_channel_wise_dequantize_max_abs") {
     Node* dequant_channel_scale_node =
@@ -475,25 +497,37 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope,
   // If quantized op is conv2d, weight scale size = weight dims[0]
   // If quantized op is conv2d_transpose, weight scale size = weight dims[1]
   if (quantized_op_type == "mul" || quantized_op_type == "matmul" ||
-      quantized_op_type == "fc") {
+      quantized_op_type == "matmul_v2" || quantized_op_type == "fc") {
     if (dequant_type == "fake_dequantize_max_abs") {
-      PADDLE_ENFORCE_EQ(
-          weight_scale.size(), 1,
-          platform::errors::InvalidArgument(
-              "mul/matmul op weight dequantized by [fake_dequantize_max_abs] "
-              "requires weight scale size = 1, but got %d.",
-              weight_scale.size()));
+      PADDLE_ENFORCE_EQ(weight_scale.size(), 1,
+                        platform::errors::InvalidArgument(
+                            "mul/matmul/matmul_v2 op weight dequantized by "
+                            "[fake_dequantize_max_abs] "
+                            "requires weight scale size = 1, but got %d.",
+                            weight_scale.size()));
       for (int j = 0; j < weight_tensor->numel(); j++) {
         quantized_weight_data[j] *= weight_scale[0];
       }
     }
     if (dequant_type == "fake_channel_wise_dequantize_max_abs") {
+      if (quant_axis == 0) {
+      } else {
+        PADDLE_ENFORCE_EQ(
+            quant_axis == 1, true,
+            platform::errors::InvalidArgument(
+                "'quant_axis' of mul/matmul/fc/matmul_v2 op weight "
+                "dequantized by "
+                "[fake_channel_wise_dequantize_max_abs] should be 1, but "
+                "the received is %d",
+                quant_axis));
+      }
       PADDLE_ENFORCE_EQ(
           weight_scale.size(), static_cast<size_t>(w_dims[1]),
           platform::errors::InvalidArgument(
-              "mul/matmul op weight dequantized by "
+              "mul/matmul/matmul_v2 op weight dequantized by "
               "[fake_channel_wise_dequantize_max_abs] requires weight scale "
-              "size = 2nd dim of mul/matmul's weight, which is %d, but got "
+              "size = 2nd dim of mul/matmul/matmul_v2's weight, which is %d, "
+              "but got "
              "%d.",
              static_cast<size_t>(w_dims[1]), weight_scale.size()));
       for (int j = 0; j < weight_tensor->numel(); j++) {
@@ -511,6 +545,16 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope,
             "model, please set the 'weight_quantize_type' params as "
             "'channel_wise_abs_max' and generate the quantized model again.",
             dequant_type));
+    if (quant_axis == 0) {
+    } else {
+      PADDLE_ENFORCE_EQ(
+          quant_axis == 0, true,
+          platform::errors::InvalidArgument(
+              "'quant_axis' of conv2d/depthwise_conv2d op weight dequantized "
+              "by [fake_channel_wise_dequantize_max_abs] should be 0, but "
+              "the received is %d",
+              quant_axis));
+    }
     PADDLE_ENFORCE_EQ(
         weight_scale.size(), static_cast<size_t>(w_dims[0]),
         platform::errors::InvalidArgument(
@@ -528,6 +572,16 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope,
             "conv2d_transpose must be dequantized by "
             "[fake_channel_wise_dequantize_max_abs], but got %s",
             dequant_type));
+    if (quant_axis == 0) {
+    } else {
+      PADDLE_ENFORCE_EQ(
+          quant_axis == 1, true,
+          platform::errors::InvalidArgument(
+              "'quant_axis' of conv2d_transpose op weight dequantized by "
+              "[fake_channel_wise_dequantize_max_abs] should be 1, but "
+              "the received is %d",
+              quant_axis));
+    }
     PADDLE_ENFORCE_EQ(
         weight_scale.size(), static_cast<size_t>(w_dims[1]),
         platform::errors::InvalidArgument(
@@ -560,7 +614,8 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope,
   } else if (quantized_op_type == "fc") {
     new_op_desc.SetInput("Input", {new_input});
     new_op_desc.SetOutput("Out", {new_output});
-  } else if (quantized_op_type == "mul" || quantized_op_type == "matmul") {
+  } else if (quantized_op_type == "mul" || quantized_op_type == "matmul" ||
+             quantized_op_type == "matmul_v2") {
     new_op_desc.SetInput("X", {new_input});
     new_op_desc.SetOutput("Out", {new_output});
   }
@@ -587,7 +642,9 @@ void QuantDequantFusePass::ApplyImpl(ir::Graph* graph) const {
   std::unordered_set<std::string> quant_types = {
       "fake_quantize_range_abs_max", "fake_quantize_moving_average_abs_max"};
   std::unordered_set<std::string> quantized_op_types = {
-      "conv2d", "mul", "matmul", "depthwise_conv2d", "fc", "conv2d_transpose"};
+      "conv2d",           "mul", "matmul", "depthwise_conv2d",
+      "conv2d_transpose", "fc",  "matmul_v2",
+  };
   auto* scope = param_scope();
   for (auto& quant_type : quant_types) {
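
Note on the rescaling this patch extends to matmul_v2: for a mul/matmul/matmul_v2 weight of shape [K, N] dequantized channel-wise with quant_axis == 1, each column j of the weight carries its own scale, which is why the PADDLE_ENFORCE_EQ above requires weight_scale.size() to equal w_dims[1]. Below is a minimal standalone C++ sketch of that multiply pattern; DequantMatmulWeight and its plain-vector signature are hypothetical illustrations, not Paddle APIs, and the sketch assumes the scales have already been divided by `range` upstream so they are the final multipliers.

    #include <cstddef>
    #include <vector>

    // Hypothetical helper mirroring the quant_axis == 1 branch above:
    // rescale a row-major [K, N] weight by a per-column scale vector.
    void DequantMatmulWeight(std::vector<float>* weight, std::size_t K,
                             std::size_t N,
                             const std::vector<float>& scales) {
      // Invariant the pass enforces: scales.size() == N (w_dims[1]).
      for (std::size_t j = 0; j < K * N; ++j) {
        // With quant_axis == 1, the channel of flat element j is j % N.
        (*weight)[j] *= scales[j % N];
      }
    }

The conv2d/depthwise_conv2d branch applies the same idea along quant_axis == 0 (one scale per output channel, hence weight_scale.size() == w_dims[0]), and conv2d_transpose along quant_axis == 1, which is exactly what the three quant_axis checks added above enforce.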