未验证 提交 b7f76647 编写于 作者: X xiaoxiaohehe001 提交者: GitHub

Add quant axis (#36467)

* add_quant_axis

* add_quant_axis

* --amend

* Update quant_conv2d_dequant_fuse_pass.cc
上级 cbd15f7d
......@@ -437,7 +437,11 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope,
BOOST_GET_CONST(int, quantized_op_node->Op()->GetAttr("bit_length"));
int range = ((1 << (bit_length - 1)) - 1);
std::vector<float> weight_scale;
int quant_axis = 0;
if (dequant_op_node->Op()->HasAttr("quant_axis")) {
quant_axis =
BOOST_GET_CONST(int, dequant_op_node->Op()->GetAttr("quant_axis"));
}
// Get weight scale
if (dequant_type == "fake_channel_wise_dequantize_max_abs") {
Node* dequant_channel_scale_node =
......@@ -488,6 +492,16 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope,
}
}
if (dequant_type == "fake_channel_wise_dequantize_max_abs") {
if (quant_axis == 0) {
} else {
PADDLE_ENFORCE_EQ(
quant_axis == 1, true,
platform::errors::InvalidArgument(
"'quant_axis' of mul/matmul/fc op weight dequantized by "
"[fake_channel_wise_dequantize_max_abs]should be 1, but "
"the received is %d",
quant_axis));
}
PADDLE_ENFORCE_EQ(
weight_scale.size(), static_cast<size_t>(w_dims[1]),
platform::errors::InvalidArgument(
......@@ -511,6 +525,16 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope,
"model, please set the 'weight_quantize_type' params as "
"'channel_wise_abs_max' and generate the quantized model again.",
dequant_type));
if (quant_axis == 0) {
} else {
PADDLE_ENFORCE_EQ(
quant_axis == 0, true,
platform::errors::InvalidArgument(
"'quant_axis' of conv2d/depthwise_conv2d op weight dequantized "
"by [fake_channel_wise_dequantize_max_abs]should be 0, but "
"the received is %d",
quant_axis));
}
PADDLE_ENFORCE_EQ(
weight_scale.size(), static_cast<size_t>(w_dims[0]),
platform::errors::InvalidArgument(
......@@ -528,6 +552,16 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope,
"conv2d_transpose must be dequantized by "
"[fake_channel_wise_dequantize_max_abs], but got %s",
dequant_type));
if (quant_axis == 0) {
} else {
PADDLE_ENFORCE_EQ(
quant_axis == 1, true,
platform::errors::InvalidArgument(
"'quant_axis' of conv2d_transpose op weight dequantized by "
"[fake_channel_wise_dequantize_max_abs]should be 1, but "
"the received is %d",
quant_axis));
}
PADDLE_ENFORCE_EQ(
weight_scale.size(), static_cast<size_t>(w_dims[1]),
platform::errors::InvalidArgument(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册