Unverified commit ddc244d3, authored by zhoutianzi666, committed by GitHub

[Paddle Inference] fix bugs in quant_conv2d_dequant_fuse_pass when weight is shared between ops (#45719)

* fix_old_format

* fix bug in quant_conv2d_dequant

* fix bug in quant_conv2d_dequant
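
The gist of the fix: the fuse handler runs once per matched subgraph, so a weight var consumed by several quantized ops used to be rescaled once per consumer, corrupting the shared tensor. Below is a minimal standalone sketch of the failure mode and of the visited-set guard this commit adds; Node and Dequantize here are illustrative stand-ins, not Paddle's real types.

#include <cstdio>
#include <unordered_set>
#include <vector>

// Stand-in for ir::Node; in the pass, a weight var shared by several
// quantized ops is one node holding one tensor.
struct Node {
  std::vector<float> weight;
};

// In-place dequantization, called once per matched subgraph.
void Dequantize(Node* weight_node, float scale,
                std::unordered_set<Node*>* visited) {
  // The fix: skip a weight node that has already been rewritten.
  if (visited->count(weight_node)) return;
  visited->insert(weight_node);
  for (float& w : weight_node->weight) w *= scale;
}

int main() {
  Node shared;                   // one weight consumed by two quantized ops
  shared.weight = {2.0f, 4.0f};
  std::unordered_set<Node*> visited;

  Dequantize(&shared, 0.5f, &visited);  // first matched subgraph
  Dequantize(&shared, 0.5f, &visited);  // second one hits the same weight

  // Prints "1 2": the tensor is scaled exactly once. Without the guard the
  // second call would scale it again and print "0.5 1".
  std::printf("%g %g\n", shared.weight[0], shared.weight[1]);
  return 0;
}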
Parent 4acf1ef7
@@ -440,7 +440,8 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph,
// Create pattern
patterns::DequantOpFuse pattern(gpd.mutable_pattern(), pattern_name);
pattern(quantized_op_input, quantized_op_type, dequant_type, weight_name);
// Record whether quantized_op_weight_node has been dealt with
std::unordered_set<Node*> quantized_op_weight_node_set;
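  // A weight shared between several quantized ops appears as a single graph
  // node, so pointer identity in this set is enough to detect reuse.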
// Create new op desc
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
@@ -507,32 +508,108 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph,
const auto& w_dims = weight_tensor->dims();
float* quantized_weight_data =
weight_tensor->mutable_data<float>(platform::CPUPlace());
    // Determine whether this weight tensor has already been rewritten, to
    // avoid rewriting it again when it is shared among multiple ops.
if (!quantized_op_weight_node_set.count(quantized_op_weight_node)) {
quantized_op_weight_node_set.insert(quantized_op_weight_node);
// If quantized op is fc, weight scale size = 1;
// If quantized op is conv2d, weight scale size = weight dims[0]
// If quantized op is conv2d_transpose, weight scale size = weight dims[1]
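      // e.g. (illustrative): a conv2d weight of dims [64, 3, 3, 3]
      // dequantized channel-wise carries 64 scales, one per output channel.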
if (quantized_op_type == "mul" || quantized_op_type == "matmul" ||
quantized_op_type == "matmul_v2" || quantized_op_type == "fc") {
if (dequant_type == "fake_dequantize_max_abs") {
PADDLE_ENFORCE_EQ(weight_scale.size(),
1,
platform::errors::InvalidArgument(
"mul/matmul/matmul_v2 op weight dequantized by "
"[fake_dequantize_max_abs] "
"requires weight scale size = 1, but got %d.",
weight_scale.size()));
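          // [fake_dequantize_max_abs] stores a single per-tensor scale, so
          // every element is multiplied by weight_scale[0].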
for (int j = 0; j < weight_tensor->numel(); j++) {
quantized_weight_data[j] *= weight_scale[0];
}
}
if (dequant_type == "fake_channel_wise_dequantize_max_abs") {
if (quant_axis == 0) {
} else {
PADDLE_ENFORCE_EQ(
quant_axis == 1,
true,
platform::errors::InvalidArgument(
"'quant_axis' of mul/matmul/fc/matmul_v2 op weight "
"dequantized by "
"[fake_channel_wise_dequantize_max_abs]should be 1, but "
"the received is %d",
quant_axis));
}
PADDLE_ENFORCE_EQ(weight_scale.size(),
static_cast<size_t>(w_dims[1]),
platform::errors::InvalidArgument(
"mul/matmul/matmul_v2 op weight dequantized by "
"[fake_channel_wise_dequantize_max_abs] "
"requires weight scale "
"size = 2nd dim of mul/matmul/matmul_v2's "
"weight, which is %d, "
"but got "
"%d.",
static_cast<size_t>(w_dims[1]),
weight_scale.size()));
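          // The fc/mul weight is [K, N] with channels along dim 1, so
          // element j belongs to channel j % w_dims[1].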
for (int j = 0; j < weight_tensor->numel(); j++) {
quantized_weight_data[j] *= weight_scale[j % w_dims[1]];
}
}
} else if (quantized_op_type == "conv2d" ||
quantized_op_type == "depthwise_conv2d") {
PADDLE_ENFORCE_EQ(
dequant_type,
"fake_channel_wise_dequantize_max_abs",
platform::errors::InvalidArgument(
"conv2d op must be dequantized by "
"[fake_channel_wise_dequantize_max_abs], but got %s. "
"If you uses PaddleSlim to generate the quantized "
"model, please set the 'weight_quantize_type' params as "
"'channel_wise_abs_max' and generate the quantized model "
"again.",
dequant_type));
if (quant_axis == 0) {
} else {
PADDLE_ENFORCE_EQ(
quant_axis == 0,
true,
platform::errors::InvalidArgument(
"'quant_axis' of conv2d/depthwise_conv2d op weight "
"dequantized "
"by [fake_channel_wise_dequantize_max_abs]should be 0, but "
"the received is %d",
quant_axis));
}
PADDLE_ENFORCE_EQ(
weight_scale.size(),
static_cast<size_t>(w_dims[0]),
platform::errors::InvalidArgument(
"conv2d op requires weight scale size = channel size of the "
"weight, which is %d, but got %d.",
static_cast<size_t>(w_dims[0]),
weight_scale.size()));
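        // conv2d weight is [out_c, in_c, kh, kw]; all in_c*kh*kw elements of
        // an output channel share one scale, hence index j / inner_size.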
        for (int j = 0; j < weight_tensor->numel(); j++) {
          int inner_size = w_dims[1] * w_dims[2] * w_dims[3];
          quantized_weight_data[j] *= weight_scale[j / inner_size];
        }
      } else if (quantized_op_type == "conv2d_transpose") {
        PADDLE_ENFORCE_EQ(
            dequant_type,
            "fake_channel_wise_dequantize_max_abs",
            platform::errors::InvalidArgument(
                "conv2d_transpose must be dequantized by "
                "[fake_channel_wise_dequantize_max_abs], but got %s",
                dequant_type));
        if (quant_axis == 0) {
        } else {
          PADDLE_ENFORCE_EQ(
              quant_axis == 1,
              true,
              platform::errors::InvalidArgument(
                  "'quant_axis' of conv2d_transpose op weight dequantized by "
                  "[fake_channel_wise_dequantize_max_abs] should be 1, but "
                  "the received is %d",
                  quant_axis));
        }
        PADDLE_ENFORCE_EQ(
            weight_scale.size(),
            static_cast<size_t>(w_dims[1]),
            platform::errors::InvalidArgument(
                "conv2d_transpose op requires weight scale size = channel "
                "size of the weight, which is %d, but got %d.",
                static_cast<size_t>(w_dims[1]),
                weight_scale.size()));
        // conv2d_transpose weight is [in_c, out_c, kh, kw], so element j
        // maps to output channel (j / (kh*kw)) % out_c.
        for (int j = 0; j < weight_tensor->numel(); j++) {
          int inner_size = w_dims[2] * w_dims[3];
          quantized_weight_data[j] *=
              weight_scale[(j / inner_size) % w_dims[1]];
        }
      } else {
        PADDLE_THROW(platform::errors::InvalidArgument(
            "Unsupported quantized op type: %s", quantized_op_type));
      }
    }
// create new op_desc
auto base_op_desc = *quantized_op_node->Op()->Proto();
std::string new_input = quantized_op_input_node->Name();
......