Unverified commit 7647d402, authored by XGZhang, committed by GitHub

Update quant_conv2d_dequant_fuse_pass.cc (#36821)
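Support fused int8 matmul_v2 in the quant/dequant fuse pass: register an OpCompat entry for matmul_v2, handle it wherever mul/matmul are handled, and validate the dequant op's quant_axis attribute before channel-wise weight scales are applied.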

Parent f20c5c9c
@@ -210,6 +210,22 @@ QuantDequantFusePass::QuantDequantFusePass() {
.AddAttr("y_num_col_dims")
.IsNumEQ(1)
.End();
AddOpCompat(OpCompat("matmul_v2"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("trans_x")
.IsBoolEQ(false)
.End()
.AddAttr("trans_y")
.IsBoolEQ(false)
.End();
AddOpCompat(OpCompat("matmul"))
.AddInput("X")
.IsTensor()
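The new matmul_v2 compatibility entry admits only the plain form: trans_x and trans_y must both be false, so transposed matmuls are left unfused.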
@@ -355,7 +371,8 @@ void QuantDequantFusePass::DeleteQuant(ir::Graph* graph, Scope* scope,
quantized_op_type == "fc" ||
quantized_op_type == "conv2d_transpose") {
op_desc->SetAttr("Input_scale", scale_value);
- } else if (quantized_op_type == "mul" || quantized_op_type == "matmul") {
+ } else if (quantized_op_type == "mul" || quantized_op_type == "matmul" ||
+            quantized_op_type == "matmul_v2") {
op_desc->SetAttr("X_scale", scale_value);
} else {
PADDLE_THROW(platform::errors::Unimplemented(
@@ -387,7 +404,8 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope,
quantized_op_type == "conv2d_transpose") {
weight_name = "Filter";
input_name = "Input";
- } else if (quantized_op_type == "mul" || quantized_op_type == "matmul") {
+ } else if (quantized_op_type == "mul" || quantized_op_type == "matmul" ||
+            quantized_op_type == "matmul_v2") {
weight_name = "Y";
input_name = "X";
} else if (quantized_op_type == "fc") {
@@ -396,7 +414,7 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope,
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"QuantDequantFuse: We only support conv2d, conv2d_fusion, "
"conv2d_transpose, fc, mul, matmul for "
"conv2d_transpose, fc, mul, matmul, matmul_v2 for "
"now."));
}
const std::string pattern_name = "dequant_fuse";
@@ -437,7 +455,11 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope,
BOOST_GET_CONST(int, quantized_op_node->Op()->GetAttr("bit_length"));
int range = ((1 << (bit_length - 1)) - 1);
std::vector<float> weight_scale;
+ int quant_axis = 0;
+ if (dequant_op_node->Op()->HasAttr("quant_axis")) {
+   quant_axis =
+       BOOST_GET_CONST(int, dequant_op_node->Op()->GetAttr("quant_axis"));
+ }
// Get weight scale
if (dequant_type == "fake_channel_wise_dequantize_max_abs") {
Node* dequant_channel_scale_node =
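The hunk above makes quant_axis optional: models exported before the attribute existed omit it, so the pass defaults to 0 instead of failing on GetAttr.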
@@ -475,25 +497,37 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope,
// If quantized op is conv2d, weight scale size = weight dims[0]
// If quantized op is conv2d_transpose, weight scale size = weight dims[1]
if (quantized_op_type == "mul" || quantized_op_type == "matmul" ||
quantized_op_type == "fc") {
quantized_op_type == "matmul_v2" || quantized_op_type == "fc") {
if (dequant_type == "fake_dequantize_max_abs") {
- PADDLE_ENFORCE_EQ(
-     weight_scale.size(), 1,
-     platform::errors::InvalidArgument(
-         "mul/matmul op weight dequantized by [fake_dequantize_max_abs] "
-         "requires weight scale size = 1, but got %d.",
-         weight_scale.size()));
+ PADDLE_ENFORCE_EQ(weight_scale.size(), 1,
+                   platform::errors::InvalidArgument(
+                       "mul/matmul/matmul_v2 op weight dequantized by "
+                       "[fake_dequantize_max_abs] "
+                       "requires weight scale size = 1, but got %d.",
+                       weight_scale.size()));
for (int j = 0; j < weight_tensor->numel(); j++) {
quantized_weight_data[j] *= weight_scale[0];
}
}
if (dequant_type == "fake_channel_wise_dequantize_max_abs") {
+ if (quant_axis != 0) {
+   PADDLE_ENFORCE_EQ(
+       quant_axis, 1,
+       platform::errors::InvalidArgument(
+           "'quant_axis' of mul/matmul/matmul_v2/fc op weight dequantized "
+           "by [fake_channel_wise_dequantize_max_abs] should be 1, but "
+           "received %d.",
+           quant_axis));
+ }
PADDLE_ENFORCE_EQ(
weight_scale.size(), static_cast<size_t>(w_dims[1]),
platform::errors::InvalidArgument(
"mul/matmul op weight dequantized by "
"mul/matmul/matmul_v2 op weight dequantized by "
"[fake_channel_wise_dequantize_max_abs] requires weight scale "
"size = 2nd dim of mul/matmul's weight, which is %d, but got "
"size = 2nd dim of mul/matmul/matmul_v2's weight, which is %d, "
"but got "
"%d.",
static_cast<size_t>(w_dims[1]), weight_scale.size()));
for (int j = 0; j < weight_tensor->numel(); j++) {
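For reference, a minimal standalone sketch (not part of the pass; the function name, int8 storage, and row-major layout are illustrative assumptions) of what the mul/matmul/matmul_v2 branch above computes: channel-wise dequantization of a [k, n] weight along axis 1, with one scale per output column and the division by range assumed already folded into the scales when they were loaded.

// Illustration only: dequantize a row-major [k, n] weight along axis 1.
#include <cstdint>
#include <vector>

std::vector<float> DequantizeAlongAxis1(const std::vector<int8_t>& w, int k,
                                        int n,
                                        const std::vector<float>& scale) {
  // One scale per output column, mirroring weight_scale.size() == w_dims[1].
  std::vector<float> out(w.size());
  for (int i = 0; i < k; ++i) {
    for (int j = 0; j < n; ++j) {
      out[i * n + j] = static_cast<float>(w[i * n + j]) * scale[j];
    }
  }
  return out;
}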
@@ -511,6 +545,16 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope,
"model, please set the 'weight_quantize_type' params as "
"'channel_wise_abs_max' and generate the quantized model again.",
dequant_type));
+ PADDLE_ENFORCE_EQ(
+     quant_axis, 0,
+     platform::errors::InvalidArgument(
+         "'quant_axis' of conv2d/depthwise_conv2d op weight dequantized "
+         "by [fake_channel_wise_dequantize_max_abs] should be 0, but "
+         "received %d.",
+         quant_axis));
PADDLE_ENFORCE_EQ(
weight_scale.size(), static_cast<size_t>(w_dims[0]),
platform::errors::InvalidArgument(
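For conv2d and depthwise_conv2d the filter layout is [output_channels, input_channels, H, W], so per-channel scales attach to axis 0 and the scale count is checked against w_dims[0].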
@@ -528,6 +572,16 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope,
"conv2d_transpose must be dequantized by "
"[fake_channel_wise_dequantize_max_abs], but got %s",
dequant_type));
+ if (quant_axis != 0) {
+   PADDLE_ENFORCE_EQ(
+       quant_axis, 1,
+       platform::errors::InvalidArgument(
+           "'quant_axis' of conv2d_transpose op weight dequantized by "
+           "[fake_channel_wise_dequantize_max_abs] should be 1, but "
+           "received %d.",
+           quant_axis));
+ }
PADDLE_ENFORCE_EQ(
weight_scale.size(), static_cast<size_t>(w_dims[1]),
platform::errors::InvalidArgument(
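conv2d_transpose filters store output channels second, as [input_channels, output_channels, H, W], so the expected quant_axis flips to 1 and the scale count is checked against w_dims[1].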
@@ -560,7 +614,8 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope,
} else if (quantized_op_type == "fc") {
new_op_desc.SetInput("Input", {new_input});
new_op_desc.SetOutput("Out", {new_output});
- } else if (quantized_op_type == "mul" || quantized_op_type == "matmul") {
+ } else if (quantized_op_type == "mul" || quantized_op_type == "matmul" ||
+            quantized_op_type == "matmul_v2") {
new_op_desc.SetInput("X", {new_input});
new_op_desc.SetOutput("Out", {new_output});
}
@@ -587,7 +642,9 @@ void QuantDequantFusePass::ApplyImpl(ir::Graph* graph) const {
std::unordered_set<std::string> quant_types = {
"fake_quantize_range_abs_max", "fake_quantize_moving_average_abs_max"};
std::unordered_set<std::string> quantized_op_types = {
"conv2d", "mul", "matmul", "depthwise_conv2d", "fc", "conv2d_transpose"};
"conv2d", "mul", "matmul", "depthwise_conv2d",
"conv2d_transpose", "fc", "matmul_v2",
};
auto* scope = param_scope();
for (auto& quant_type : quant_types) {