未验证 提交 ddc244d3 编写于 作者: Z zhoutianzi666 提交者: GitHub

[Paddle Inference] fix bugs in quant_conv2d_dequant_fuse_pass when weight is...

[Paddle Inference] fix bugs in quant_conv2d_dequant_fuse_pass when weight is shared  between ops (#45719)

* fix_old_format

* fix bug in quant_conv2d_dequant

* fix bug in quant_conv2d_dequant
上级 4acf1ef7
...@@ -440,7 +440,8 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, ...@@ -440,7 +440,8 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph,
// Create pattern // Create pattern
patterns::DequantOpFuse pattern(gpd.mutable_pattern(), pattern_name); patterns::DequantOpFuse pattern(gpd.mutable_pattern(), pattern_name);
pattern(quantized_op_input, quantized_op_type, dequant_type, weight_name); pattern(quantized_op_input, quantized_op_type, dequant_type, weight_name);
// Record whether quantized_op_weight_node has been dealt with
std::unordered_set<Node*> quantized_op_weight_node_set;
// Create new op desc // Create new op desc
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) { Graph* g) {
...@@ -507,6 +508,10 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, ...@@ -507,6 +508,10 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph,
const auto& w_dims = weight_tensor->dims(); const auto& w_dims = weight_tensor->dims();
float* quantized_weight_data = float* quantized_weight_data =
weight_tensor->mutable_data<float>(platform::CPUPlace()); weight_tensor->mutable_data<float>(platform::CPUPlace());
// Determine whether this weight tensor has been re-writed, avoiding
// re-write it again when this weight tensor is shared among many ops.
if (!quantized_op_weight_node_set.count(quantized_op_weight_node)) {
quantized_op_weight_node_set.insert(quantized_op_weight_node);
// If quantized op is fc, weight scale size = 1; // If quantized op is fc, weight scale size = 1;
// If quantized op is conv2d, weight scale size = weight dims[0] // If quantized op is conv2d, weight scale size = weight dims[0]
// If quantized op is conv2d_transpose, weight scale size = weight dims[1] // If quantized op is conv2d_transpose, weight scale size = weight dims[1]
...@@ -537,13 +542,14 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, ...@@ -537,13 +542,14 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph,
"the received is %d", "the received is %d",
quant_axis)); quant_axis));
} }
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(weight_scale.size(),
weight_scale.size(),
static_cast<size_t>(w_dims[1]), static_cast<size_t>(w_dims[1]),
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"mul/matmul/matmul_v2 op weight dequantized by " "mul/matmul/matmul_v2 op weight dequantized by "
"[fake_channel_wise_dequantize_max_abs] requires weight scale " "[fake_channel_wise_dequantize_max_abs] "
"size = 2nd dim of mul/matmul/matmul_v2's weight, which is %d, " "requires weight scale "
"size = 2nd dim of mul/matmul/matmul_v2's "
"weight, which is %d, "
"but got " "but got "
"%d.", "%d.",
static_cast<size_t>(w_dims[1]), static_cast<size_t>(w_dims[1]),
...@@ -562,7 +568,8 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, ...@@ -562,7 +568,8 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph,
"[fake_channel_wise_dequantize_max_abs], but got %s. " "[fake_channel_wise_dequantize_max_abs], but got %s. "
"If you uses PaddleSlim to generate the quantized " "If you uses PaddleSlim to generate the quantized "
"model, please set the 'weight_quantize_type' params as " "model, please set the 'weight_quantize_type' params as "
"'channel_wise_abs_max' and generate the quantized model again.", "'channel_wise_abs_max' and generate the quantized model "
"again.",
dequant_type)); dequant_type));
if (quant_axis == 0) { if (quant_axis == 0) {
} else { } else {
...@@ -570,7 +577,8 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, ...@@ -570,7 +577,8 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph,
quant_axis == 0, quant_axis == 0,
true, true,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"'quant_axis' of conv2d/depthwise_conv2d op weight dequantized " "'quant_axis' of conv2d/depthwise_conv2d op weight "
"dequantized "
"by [fake_channel_wise_dequantize_max_abs]should be 0, but " "by [fake_channel_wise_dequantize_max_abs]should be 0, but "
"the received is %d", "the received is %d",
quant_axis)); quant_axis));
...@@ -616,13 +624,14 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, ...@@ -616,13 +624,14 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph,
weight_scale.size())); weight_scale.size()));
for (int j = 0; j < weight_tensor->numel(); j++) { for (int j = 0; j < weight_tensor->numel(); j++) {
int inner_size = w_dims[2] * w_dims[3]; int inner_size = w_dims[2] * w_dims[3];
quantized_weight_data[j] *= weight_scale[(j / inner_size) % w_dims[1]]; quantized_weight_data[j] *=
weight_scale[(j / inner_size) % w_dims[1]];
} }
} else { } else {
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(platform::errors::InvalidArgument(
"Unsupported quantized op type: %s", quantized_op_type)); "Unsupported quantized op type: %s", quantized_op_type));
} }
}
// create new op_desc // create new op_desc
auto base_op_desc = *quantized_op_node->Op()->Proto(); auto base_op_desc = *quantized_op_node->Op()->Proto();
std::string new_input = quantized_op_input_node->Name(); std::string new_input = quantized_op_input_node->Name();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册