From ddc244d3a2e36ab045f7f8143907bcd2c9957027 Mon Sep 17 00:00:00 2001 From: zhoutianzi666 <39978853+zhoutianzi666@users.noreply.github.com> Date: Tue, 6 Sep 2022 14:50:58 +0800 Subject: [PATCH] [Paddle Inference] fix bugs in quant_conv2d_dequant_fuse_pass when weight is shared between ops (#45719) * fix_old_format * fix bug in quant_conv2d_dequant * fix bug in quant_conv2d_dequant --- .../ir/quant_conv2d_dequant_fuse_pass.cc | 197 +++++++++--------- 1 file changed, 103 insertions(+), 94 deletions(-) diff --git a/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc b/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc index 1ff738aeedd..48722ba941a 100644 --- a/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc +++ b/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc @@ -440,7 +440,8 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, // Create pattern patterns::DequantOpFuse pattern(gpd.mutable_pattern(), pattern_name); pattern(quantized_op_input, quantized_op_type, dequant_type, weight_name); - + // Record whether quantized_op_weight_node has been dealt with + std::unordered_set quantized_op_weight_node_set; // Create new op desc auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) { @@ -507,32 +508,108 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, const auto& w_dims = weight_tensor->dims(); float* quantized_weight_data = weight_tensor->mutable_data(platform::CPUPlace()); - // If quantized op is fc, weight scale size = 1; - // If quantized op is conv2d, weight scale size = weight dims[0] - // If quantized op is conv2d_transpose, weight scale size = weight dims[1] - if (quantized_op_type == "mul" || quantized_op_type == "matmul" || - quantized_op_type == "matmul_v2" || quantized_op_type == "fc") { - if (dequant_type == "fake_dequantize_max_abs") { - PADDLE_ENFORCE_EQ(weight_scale.size(), - 1, - platform::errors::InvalidArgument( - "mul/matmul/matmul_v2 op weight dequantized by " - "[fake_dequantize_max_abs] " - "requires weight scale size = 1, but got %d.", - weight_scale.size())); + // Determine whether this weight tensor has been re-writed, avoiding + // re-write it again when this weight tensor is shared among many ops. + if (!quantized_op_weight_node_set.count(quantized_op_weight_node)) { + quantized_op_weight_node_set.insert(quantized_op_weight_node); + // If quantized op is fc, weight scale size = 1; + // If quantized op is conv2d, weight scale size = weight dims[0] + // If quantized op is conv2d_transpose, weight scale size = weight dims[1] + if (quantized_op_type == "mul" || quantized_op_type == "matmul" || + quantized_op_type == "matmul_v2" || quantized_op_type == "fc") { + if (dequant_type == "fake_dequantize_max_abs") { + PADDLE_ENFORCE_EQ(weight_scale.size(), + 1, + platform::errors::InvalidArgument( + "mul/matmul/matmul_v2 op weight dequantized by " + "[fake_dequantize_max_abs] " + "requires weight scale size = 1, but got %d.", + weight_scale.size())); + for (int j = 0; j < weight_tensor->numel(); j++) { + quantized_weight_data[j] *= weight_scale[0]; + } + } + if (dequant_type == "fake_channel_wise_dequantize_max_abs") { + if (quant_axis == 0) { + } else { + PADDLE_ENFORCE_EQ( + quant_axis == 1, + true, + platform::errors::InvalidArgument( + "'quant_axis' of mul/matmul/fc/matmul_v2 op weight " + "dequantized by " + "[fake_channel_wise_dequantize_max_abs]should be 1, but " + "the received is %d", + quant_axis)); + } + PADDLE_ENFORCE_EQ(weight_scale.size(), + static_cast(w_dims[1]), + platform::errors::InvalidArgument( + "mul/matmul/matmul_v2 op weight dequantized by " + "[fake_channel_wise_dequantize_max_abs] " + "requires weight scale " + "size = 2nd dim of mul/matmul/matmul_v2's " + "weight, which is %d, " + "but got " + "%d.", + static_cast(w_dims[1]), + weight_scale.size())); + for (int j = 0; j < weight_tensor->numel(); j++) { + quantized_weight_data[j] *= weight_scale[j % w_dims[1]]; + } + } + } else if (quantized_op_type == "conv2d" || + quantized_op_type == "depthwise_conv2d") { + PADDLE_ENFORCE_EQ( + dequant_type, + "fake_channel_wise_dequantize_max_abs", + platform::errors::InvalidArgument( + "conv2d op must be dequantized by " + "[fake_channel_wise_dequantize_max_abs], but got %s. " + "If you uses PaddleSlim to generate the quantized " + "model, please set the 'weight_quantize_type' params as " + "'channel_wise_abs_max' and generate the quantized model " + "again.", + dequant_type)); + if (quant_axis == 0) { + } else { + PADDLE_ENFORCE_EQ( + quant_axis == 0, + true, + platform::errors::InvalidArgument( + "'quant_axis' of conv2d/depthwise_conv2d op weight " + "dequantized " + "by [fake_channel_wise_dequantize_max_abs]should be 0, but " + "the received is %d", + quant_axis)); + } + PADDLE_ENFORCE_EQ( + weight_scale.size(), + static_cast(w_dims[0]), + platform::errors::InvalidArgument( + "conv2d op requires weight scale size = channel size of the " + "weight, which is %d, but got %d.", + static_cast(w_dims[0]), + weight_scale.size())); for (int j = 0; j < weight_tensor->numel(); j++) { - quantized_weight_data[j] *= weight_scale[0]; + int inner_size = w_dims[1] * w_dims[2] * w_dims[3]; + quantized_weight_data[j] *= weight_scale[j / inner_size]; } - } - if (dequant_type == "fake_channel_wise_dequantize_max_abs") { + } else if (quantized_op_type == "conv2d_transpose") { + PADDLE_ENFORCE_EQ( + dequant_type, + "fake_channel_wise_dequantize_max_abs", + platform::errors::InvalidArgument( + "conv2d_transpose must be dequantized by " + "[fake_channel_wise_dequantize_max_abs], but got %s", + dequant_type)); if (quant_axis == 0) { } else { PADDLE_ENFORCE_EQ( quant_axis == 1, true, platform::errors::InvalidArgument( - "'quant_axis' of mul/matmul/fc/matmul_v2 op weight " - "dequantized by " + "'quant_axis' of conv2d_transpose op weight dequantized by " "[fake_channel_wise_dequantize_max_abs]should be 1, but " "the received is %d", quant_axis)); @@ -541,88 +618,20 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, weight_scale.size(), static_cast(w_dims[1]), platform::errors::InvalidArgument( - "mul/matmul/matmul_v2 op weight dequantized by " - "[fake_channel_wise_dequantize_max_abs] requires weight scale " - "size = 2nd dim of mul/matmul/matmul_v2's weight, which is %d, " - "but got " - "%d.", + "conv2d_transpose op requires weight scale size = channel size " + "of the weight, which is %d, but got %d.", static_cast(w_dims[1]), weight_scale.size())); for (int j = 0; j < weight_tensor->numel(); j++) { - quantized_weight_data[j] *= weight_scale[j % w_dims[1]]; + int inner_size = w_dims[2] * w_dims[3]; + quantized_weight_data[j] *= + weight_scale[(j / inner_size) % w_dims[1]]; } - } - } else if (quantized_op_type == "conv2d" || - quantized_op_type == "depthwise_conv2d") { - PADDLE_ENFORCE_EQ( - dequant_type, - "fake_channel_wise_dequantize_max_abs", - platform::errors::InvalidArgument( - "conv2d op must be dequantized by " - "[fake_channel_wise_dequantize_max_abs], but got %s. " - "If you uses PaddleSlim to generate the quantized " - "model, please set the 'weight_quantize_type' params as " - "'channel_wise_abs_max' and generate the quantized model again.", - dequant_type)); - if (quant_axis == 0) { } else { - PADDLE_ENFORCE_EQ( - quant_axis == 0, - true, - platform::errors::InvalidArgument( - "'quant_axis' of conv2d/depthwise_conv2d op weight dequantized " - "by [fake_channel_wise_dequantize_max_abs]should be 0, but " - "the received is %d", - quant_axis)); + PADDLE_THROW(platform::errors::InvalidArgument( + "Unsupported quantized op type: %s", quantized_op_type)); } - PADDLE_ENFORCE_EQ( - weight_scale.size(), - static_cast(w_dims[0]), - platform::errors::InvalidArgument( - "conv2d op requires weight scale size = channel size of the " - "weight, which is %d, but got %d.", - static_cast(w_dims[0]), - weight_scale.size())); - for (int j = 0; j < weight_tensor->numel(); j++) { - int inner_size = w_dims[1] * w_dims[2] * w_dims[3]; - quantized_weight_data[j] *= weight_scale[j / inner_size]; - } - } else if (quantized_op_type == "conv2d_transpose") { - PADDLE_ENFORCE_EQ( - dequant_type, - "fake_channel_wise_dequantize_max_abs", - platform::errors::InvalidArgument( - "conv2d_transpose must be dequantized by " - "[fake_channel_wise_dequantize_max_abs], but got %s", - dequant_type)); - if (quant_axis == 0) { - } else { - PADDLE_ENFORCE_EQ( - quant_axis == 1, - true, - platform::errors::InvalidArgument( - "'quant_axis' of conv2d_transpose op weight dequantized by " - "[fake_channel_wise_dequantize_max_abs]should be 1, but " - "the received is %d", - quant_axis)); - } - PADDLE_ENFORCE_EQ( - weight_scale.size(), - static_cast(w_dims[1]), - platform::errors::InvalidArgument( - "conv2d_transpose op requires weight scale size = channel size " - "of the weight, which is %d, but got %d.", - static_cast(w_dims[1]), - weight_scale.size())); - for (int j = 0; j < weight_tensor->numel(); j++) { - int inner_size = w_dims[2] * w_dims[3]; - quantized_weight_data[j] *= weight_scale[(j / inner_size) % w_dims[1]]; - } - } else { - PADDLE_THROW(platform::errors::InvalidArgument( - "Unsupported quantized op type: %s", quantized_op_type)); } - // create new op_desc auto base_op_desc = *quantized_op_node->Op()->Proto(); std::string new_input = quantized_op_input_node->Name(); -- GitLab