Unverified · Commit a59f215d authored by Wangzheee, committed by GitHub

add quant_dequant_matmul (#34359)

Parent 68b4a2c3
@@ -31,6 +31,7 @@ QuantDequantFusePass::QuantDequantFusePass() {
       .End()
       .AddInput("Iter")
       .IsTensor()
+      .IsOptional()
       .End()
       .AddOutput("Out")
       .IsTensor()
@@ -40,6 +41,7 @@ QuantDequantFusePass::QuantDequantFusePass() {
       .End()
       .AddOutput("OutScales")
       .IsTensor()
+      .IsOptional()
       .End()
       .AddAttr("window_size")
       .IsType<int>()
@@ -167,6 +169,26 @@ QuantDequantFusePass::QuantDequantFusePass() {
       .AddAttr("y_num_col_dims")
       .IsNumEQ(1)
       .End();
+  AddOpCompat(OpCompat("matmul"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddInput("Y")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddAttr("alpha")
+      .IsNumGE(0.99f)
+      .IsNumLE(1.01f)
+      .End()
+      .AddAttr("transpose_X")
+      .IsBoolEQ(false)
+      .End()
+      .AddAttr("transpose_Y")
+      .IsBoolEQ(false)
+      .End();
   AddOpCompat(OpCompat("fc"))
       .AddInput("Input")
       .IsTensor()
@@ -291,7 +313,7 @@ void QuantDequantFusePass::DeleteQuant(ir::Graph* graph, Scope* scope,
         quantized_op_type == "fc" ||
         quantized_op_type == "conv2d_transpose") {
       op_desc->SetAttr("Input_scale", scale_value);
-    } else if (quantized_op_type == "mul") {
+    } else if (quantized_op_type == "mul" || quantized_op_type == "matmul") {
       op_desc->SetAttr("X_scale", scale_value);
     } else {
       PADDLE_THROW(platform::errors::Unimplemented(
@@ -323,7 +345,7 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope,
              quantized_op_type == "conv2d_transpose") {
     weight_name = "Filter";
     input_name = "Input";
-  } else if (quantized_op_type == "mul") {
+  } else if (quantized_op_type == "mul" || quantized_op_type == "matmul") {
     weight_name = "Y";
     input_name = "X";
   } else if (quantized_op_type == "fc") {
@@ -332,7 +354,7 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope,
   } else {
     PADDLE_THROW(platform::errors::Unimplemented(
         "QuantDequantFuse: We only support conv2d, conv2d_fusion, "
-        "conv2d_transpose, fc, mul for "
+        "conv2d_transpose, fc, mul, matmul for "
         "now."));
   }
   const std::string pattern_name = "dequant_fuse";
@@ -410,12 +432,13 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope,
     // If quantized op is fc, weight scale size = 1;
     // If quantized op is conv2d, weight scale size = weight dims[0]
     // If quantized op is conv2d_transpose, weight scale size = weight dims[1]
-    if (quantized_op_type == "mul" || quantized_op_type == "fc") {
+    if (quantized_op_type == "mul" || quantized_op_type == "matmul" ||
+        quantized_op_type == "fc") {
       if (dequant_type == "fake_dequantize_max_abs") {
         PADDLE_ENFORCE_EQ(
             weight_scale.size(), 1,
             platform::errors::InvalidArgument(
-                "mul op weight dequantized by [fake_dequantize_max_abs] "
+                "mul/matmul op weight dequantized by [fake_dequantize_max_abs] "
                 "requires weight scale size = 1, but got %d.",
                 weight_scale.size()));
         for (int j = 0; j < weight_tensor->numel(); j++) {
@@ -426,9 +449,10 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope,
         PADDLE_ENFORCE_EQ(
             weight_scale.size(), static_cast<size_t>(w_dims[1]),
             platform::errors::InvalidArgument(
-                "mul op weight dequantized by "
+                "mul/matmul op weight dequantized by "
                 "[fake_channel_wise_dequantize_max_abs] requires weight scale "
-                "size = 2nd dim of mul's weight, which is %d, but got %d.",
+                "size = 2nd dim of mul/matmul's weight, which is %d, but got "
+                "%d.",
                 static_cast<size_t>(w_dims[1]), weight_scale.size()));
         for (int j = 0; j < weight_tensor->numel(); j++) {
           quantized_weight_data[j] *= weight_scale[j % w_dims[1]];
@@ -493,7 +517,7 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope,
   } else if (quantized_op_type == "fc") {
     new_op_desc.SetInput("Input", {new_input});
     new_op_desc.SetOutput("Out", {new_output});
-  } else if (quantized_op_type == "mul") {
+  } else if (quantized_op_type == "mul" || quantized_op_type == "matmul") {
     new_op_desc.SetInput("X", {new_input});
     new_op_desc.SetOutput("Out", {new_output});
   }
@@ -520,7 +544,7 @@ void QuantDequantFusePass::ApplyImpl(ir::Graph* graph) const {
   std::unordered_set<std::string> quant_types = {
       "fake_quantize_range_abs_max", "fake_quantize_moving_average_abs_max"};
   std::unordered_set<std::string> quantized_op_types = {
-      "conv2d", "mul", "depthwise_conv2d", "fc", "conv2d_transpose"};
+      "conv2d", "mul", "matmul", "depthwise_conv2d", "fc", "conv2d_transpose"};
   auto* scope = param_scope();
   for (auto& quant_type : quant_types) {
...
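A note on what the diff relies on: Paddle's legacy matmul op computes Out = alpha * op(X) * op(Y), so the new OpCompat("matmul") entry, by pinning alpha to [0.99f, 1.01f] and both transpose attributes to false, only admits matmuls that behave exactly like mul; that is what lets the pass reuse mul's quant/dequant handling for the X/Y inputs and the weight. In the channel-wise branch, the mul/matmul weight is a flattened row-major [K, N] matrix with one scale per output column, which is why the loop applies weight_scale[j % w_dims[1]]. Below is a minimal standalone sketch of that per-column rescaling, with hypothetical names; it mirrors the loop in the pass but is not the pass's actual helper:

#include <cstddef>
#include <vector>

// Sketch: channel-wise dequantization of a flattened row-major [K, N]
// weight, mirroring `quantized_weight_data[j] *= weight_scale[j % w_dims[1]]`.
// Element (i, k) lives at index i * N + k, so idx % N recovers its column,
// and every column k is rescaled by its own col_scale[k].
void DequantizeColumnWise(std::vector<float>* weight,           // [K, N], flattened
                          const std::vector<float>& col_scale)  // size N
{
  const size_t n = col_scale.size();
  for (size_t idx = 0; idx < weight->size(); ++idx) {
    (*weight)[idx] *= col_scale[idx % n];
  }
}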