diff --git a/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc b/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc index 60675bf84886398fb2b56d3e7e10b4dc69517a54..068a50a1dc0e9ad1c15e0e98f5878693eeeb9f55 100644 --- a/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc +++ b/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc @@ -31,6 +31,7 @@ QuantDequantFusePass::QuantDequantFusePass() { .End() .AddInput("Iter") .IsTensor() + .IsOptional() .End() .AddOutput("Out") .IsTensor() @@ -40,6 +41,7 @@ QuantDequantFusePass::QuantDequantFusePass() { .End() .AddOutput("OutScales") .IsTensor() + .IsOptional() .End() .AddAttr("window_size") .IsType() @@ -167,6 +169,26 @@ QuantDequantFusePass::QuantDequantFusePass() { .AddAttr("y_num_col_dims") .IsNumEQ(1) .End(); + AddOpCompat(OpCompat("matmul")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("alpha") + .IsNumGE(0.99f) + .IsNumLE(1.01f) + .End() + .AddAttr("transpose_X") + .IsBoolEQ(false) + .End() + .AddAttr("transpose_Y") + .IsBoolEQ(false) + .End(); AddOpCompat(OpCompat("fc")) .AddInput("Input") .IsTensor() @@ -291,7 +313,7 @@ void QuantDequantFusePass::DeleteQuant(ir::Graph* graph, Scope* scope, quantized_op_type == "fc" || quantized_op_type == "conv2d_transpose") { op_desc->SetAttr("Input_scale", scale_value); - } else if (quantized_op_type == "mul") { + } else if (quantized_op_type == "mul" || quantized_op_type == "matmul") { op_desc->SetAttr("X_scale", scale_value); } else { PADDLE_THROW(platform::errors::Unimplemented( @@ -323,7 +345,7 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope, quantized_op_type == "conv2d_transpose") { weight_name = "Filter"; input_name = "Input"; - } else if (quantized_op_type == "mul") { + } else if (quantized_op_type == "mul" || quantized_op_type == "matmul") { weight_name = "Y"; input_name = "X"; } else if (quantized_op_type == "fc") { @@ -332,7 +354,7 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope, } else { PADDLE_THROW(platform::errors::Unimplemented( "QuantDequantFuse: We only support conv2d, conv2d_fusion, " - "conv2d_transpose, fc, mul for " + "conv2d_transpose, fc, mul, matmul for " "now.")); } const std::string pattern_name = "dequant_fuse"; @@ -410,12 +432,13 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope, // If quantized op is fc, weight scale size = 1; // If quantized op is conv2d, weight scale size = weight dims[0] // If quantized op is conv2d_transpose, weight scale size = weight dims[1] - if (quantized_op_type == "mul" || quantized_op_type == "fc") { + if (quantized_op_type == "mul" || quantized_op_type == "matmul" || + quantized_op_type == "fc") { if (dequant_type == "fake_dequantize_max_abs") { PADDLE_ENFORCE_EQ( weight_scale.size(), 1, platform::errors::InvalidArgument( - "mul op weight dequantized by [fake_dequantize_max_abs] " + "mul/matmul op weight dequantized by [fake_dequantize_max_abs] " "requires weight scale size = 1, but got %d.", weight_scale.size())); for (int j = 0; j < weight_tensor->numel(); j++) { @@ -426,9 +449,10 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope, PADDLE_ENFORCE_EQ( weight_scale.size(), static_cast(w_dims[1]), platform::errors::InvalidArgument( - "mul op weight dequantized by " + "mul/matmul op weight dequantized by " "[fake_channel_wise_dequantize_max_abs] requires weight scale " - "size = 2nd dim of mul's weight, which is %d, but got %d.", + "size = 2nd dim of mul/matmul's weight, which is %d, but got " + "%d.", static_cast(w_dims[1]), weight_scale.size())); for (int j = 0; j < weight_tensor->numel(); j++) { quantized_weight_data[j] *= weight_scale[j % w_dims[1]]; @@ -493,7 +517,7 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope, } else if (quantized_op_type == "fc") { new_op_desc.SetInput("Input", {new_input}); new_op_desc.SetOutput("Out", {new_output}); - } else if (quantized_op_type == "mul") { + } else if (quantized_op_type == "mul" || quantized_op_type == "matmul") { new_op_desc.SetInput("X", {new_input}); new_op_desc.SetOutput("Out", {new_output}); } @@ -520,7 +544,7 @@ void QuantDequantFusePass::ApplyImpl(ir::Graph* graph) const { std::unordered_set quant_types = { "fake_quantize_range_abs_max", "fake_quantize_moving_average_abs_max"}; std::unordered_set quantized_op_types = { - "conv2d", "mul", "depthwise_conv2d", "fc", "conv2d_transpose"}; + "conv2d", "mul", "matmul", "depthwise_conv2d", "fc", "conv2d_transpose"}; auto* scope = param_scope(); for (auto& quant_type : quant_types) {