Unverified commit ddc244d3 authored by zhoutianzi666, committed by GitHub

[Paddle Inference] fix bugs in quant_conv2d_dequant_fuse_pass when weight is shared between ops (#45719)

* fix_old_format

* fix bug in quant_conv2d_dequant

* fix bug in quant_conv2d_dequant
Parent 4acf1ef7
@@ -440,7 +440,8 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph,
   // Create pattern
   patterns::DequantOpFuse pattern(gpd.mutable_pattern(), pattern_name);
   pattern(quantized_op_input, quantized_op_type, dequant_type, weight_name);
+  // Record whether quantized_op_weight_node has been dealt with
+  std::unordered_set<Node*> quantized_op_weight_node_set;
   // Create new op desc
   auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                      Graph* g) {
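The core of the fix is visible in this hunk and the next: the pass now remembers which weight nodes it has already dequantized. Below is a minimal, self-contained sketch of that dedup pattern; the `Node` struct and `DequantWeights` function are hypothetical stand-ins, not Paddle's actual graph types:

```cpp
#include <string>
#include <unordered_set>
#include <vector>

// Hypothetical stand-in for Paddle's ir::Node; only pointer identity matters.
struct Node {
  std::string name;
};

// Rewrite each weight exactly once, even when several matched subgraphs
// share the same weight node (mirrors quantized_op_weight_node_set above).
void DequantWeights(const std::vector<Node*>& matched_weight_nodes) {
  std::unordered_set<Node*> seen;
  for (Node* w : matched_weight_nodes) {
    if (!seen.insert(w).second) continue;  // already rewritten: skip
    // ... scale the weight tensor in place, exactly once ...
  }
}
```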
@@ -507,6 +508,10 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph,
     const auto& w_dims = weight_tensor->dims();
     float* quantized_weight_data =
         weight_tensor->mutable_data<float>(platform::CPUPlace());
+    // Determine whether this weight tensor has already been rewritten, to
+    // avoid rewriting it again when it is shared among many ops.
+    if (!quantized_op_weight_node_set.count(quantized_op_weight_node)) {
+      quantized_op_weight_node_set.insert(quantized_op_weight_node);
     // If quantized op is fc, weight scale size = 1;
     // If quantized op is conv2d, weight scale size = weight dims[0]
     // If quantized op is conv2d_transpose, weight scale size = weight dims[1]
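The three comments above encode how many per-channel scales each fused op type expects. A hedged restatement in code; the function name and layout assumptions are inferred from the comments, not Paddle's API:

```cpp
#include <cstdint>
#include <stdexcept>
#include <string>
#include <vector>

// Expected number of weight scales per op type; w_dims is the weight shape.
int64_t ExpectedScaleCount(const std::string& op_type,
                           const std::vector<int64_t>& w_dims) {
  if (op_type == "fc") return 1;               // a single per-tensor scale
  if (op_type == "conv2d") return w_dims[0];   // one per output channel
  if (op_type == "conv2d_transpose")
    return w_dims[1];  // weights laid out [in, out, kh, kw]
  throw std::invalid_argument("unsupported op type: " + op_type);
}
```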
@@ -537,13 +542,14 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph,
                 "the received is %d",
                 quant_axis));
       }
-      PADDLE_ENFORCE_EQ(
-          weight_scale.size(),
+      PADDLE_ENFORCE_EQ(weight_scale.size(),
           static_cast<size_t>(w_dims[1]),
           platform::errors::InvalidArgument(
               "mul/matmul/matmul_v2 op weight dequantized by "
-              "[fake_channel_wise_dequantize_max_abs] requires weight scale "
-              "size = 2nd dim of mul/matmul/matmul_v2's weight, which is %d, "
+              "[fake_channel_wise_dequantize_max_abs] "
+              "requires weight scale "
+              "size = 2nd dim of mul/matmul/matmul_v2's "
+              "weight, which is %d, "
               "but got "
               "%d.",
               static_cast<size_t>(w_dims[1]),
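This hunk only reflows a PADDLE_ENFORCE_EQ message, but the check it wraps is worth spelling out: for mul/matmul/matmul_v2, channel-wise dequantization needs one scale per output column. A minimal sketch, assuming a flattened row-major [K, N] weight (the names and the `j % N` indexing are illustrative assumptions, not shown in this diff):

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Column j % N of a row-major [K, N] weight gets scale[j % N]; scale.size()
// must therefore equal N, which is what the enforce above verifies.
void DequantMatmulWeight(std::vector<float>& w,            // flattened [K, N]
                         const std::vector<float>& scale,  // size N
                         int64_t N) {
  for (std::size_t j = 0; j < w.size(); ++j) {
    w[j] *= scale[j % static_cast<std::size_t>(N)];
  }
}
```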
@@ -562,7 +568,8 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph,
             "[fake_channel_wise_dequantize_max_abs], but got %s. "
             "If you use PaddleSlim to generate the quantized "
             "model, please set the 'weight_quantize_type' params as "
-            "'channel_wise_abs_max' and generate the quantized model again.",
+            "'channel_wise_abs_max' and generate the quantized model "
+            "again.",
             dequant_type));
     if (quant_axis == 0) {
     } else {
@@ -570,7 +577,8 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph,
           quant_axis == 0,
           true,
           platform::errors::InvalidArgument(
-              "'quant_axis' of conv2d/depthwise_conv2d op weight dequantized "
+              "'quant_axis' of conv2d/depthwise_conv2d op weight "
+              "dequantized "
               "by [fake_channel_wise_dequantize_max_abs] should be 0, but "
               "the received is %d",
               quant_axis));
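The enforces around this hunk pin down the quant_axis convention: the axis indexes the weight dimension that carries the per-channel scales. A small sketch of that rule (the function name is mine, not Paddle's):

```cpp
#include <stdexcept>
#include <string>

// Axis that carries per-channel weight scales, as the checks above enforce.
int ExpectedQuantAxis(const std::string& op_type) {
  if (op_type == "conv2d" || op_type == "depthwise_conv2d") {
    return 0;  // weights are [out, in, kh, kw]; scales follow dim 0
  }
  if (op_type == "conv2d_transpose") {
    return 1;  // weights are [in, out, kh, kw]; scales follow dim 1
  }
  throw std::invalid_argument("unsupported op type: " + op_type);
}
```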
@@ -616,13 +624,14 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph,
               weight_scale.size()));
       for (int j = 0; j < weight_tensor->numel(); j++) {
         int inner_size = w_dims[2] * w_dims[3];
-        quantized_weight_data[j] *= weight_scale[(j / inner_size) % w_dims[1]];
+        quantized_weight_data[j] *=
+            weight_scale[(j / inner_size) % w_dims[1]];
       }
     } else {
       PADDLE_THROW(platform::errors::InvalidArgument(
           "Unsupported quantized op type: %s", quantized_op_type));
     }
+    }
     // create new op_desc
     auto base_op_desc = *quantized_op_node->Op()->Proto();
     std::string new_input = quantized_op_input_node->Name();
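The loop reflowed in this hunk does the conv2d_transpose dequantization: the index math `(j / inner_size) % w_dims[1]` recovers the output-channel index from a flat offset into an [in, out, kh, kw] tensor. A standalone sketch of the same arithmetic, under that assumed layout (function and parameter names are illustrative):

```cpp
#include <cstdint>
#include <vector>

// For a conv2d_transpose weight flattened from [in, out, kh, kw], element j
// belongs to output channel (j / (kh * kw)) % out, which picks its scale.
void DequantConvTransposeWeight(std::vector<float>& w,
                                const std::vector<float>& scale,  // size out
                                int64_t out, int64_t kh, int64_t kw) {
  const int64_t inner_size = kh * kw;  // elements per (in, out) pair
  for (int64_t j = 0; j < static_cast<int64_t>(w.size()); ++j) {
    w[j] *= scale[(j / inner_size) % out];
  }
}
```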