diff --git a/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc b/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc index 1864899b07e0180d5d9f8d32d32a17245e992a81..22babcc719aeb424e98c70c7f7aed1f242548e3a 100644 --- a/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc +++ b/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc @@ -437,7 +437,11 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope, BOOST_GET_CONST(int, quantized_op_node->Op()->GetAttr("bit_length")); int range = ((1 << (bit_length - 1)) - 1); std::vector weight_scale; - + int quant_axis = 0; + if (dequant_op_node->Op()->HasAttr("quant_axis")) { + quant_axis = + BOOST_GET_CONST(int, dequant_op_node->Op()->GetAttr("quant_axis")); + } // Get weight scale if (dequant_type == "fake_channel_wise_dequantize_max_abs") { Node* dequant_channel_scale_node = @@ -488,6 +492,16 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope, } } if (dequant_type == "fake_channel_wise_dequantize_max_abs") { + if (quant_axis == 0) { + } else { + PADDLE_ENFORCE_EQ( + quant_axis == 1, true, + platform::errors::InvalidArgument( + "'quant_axis' of mul/matmul/fc op weight dequantized by " + "[fake_channel_wise_dequantize_max_abs]should be 1, but " + "the received is %d", + quant_axis)); + } PADDLE_ENFORCE_EQ( weight_scale.size(), static_cast(w_dims[1]), platform::errors::InvalidArgument( @@ -511,6 +525,16 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope, "model, please set the 'weight_quantize_type' params as " "'channel_wise_abs_max' and generate the quantized model again.", dequant_type)); + if (quant_axis == 0) { + } else { + PADDLE_ENFORCE_EQ( + quant_axis == 0, true, + platform::errors::InvalidArgument( + "'quant_axis' of conv2d/depthwise_conv2d op weight dequantized " + "by [fake_channel_wise_dequantize_max_abs]should be 0, but " + "the received is %d", + quant_axis)); + } PADDLE_ENFORCE_EQ( weight_scale.size(), static_cast(w_dims[0]), platform::errors::InvalidArgument( @@ -528,6 +552,16 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope, "conv2d_transpose must be dequantized by " "[fake_channel_wise_dequantize_max_abs], but got %s", dequant_type)); + if (quant_axis == 0) { + } else { + PADDLE_ENFORCE_EQ( + quant_axis == 1, true, + platform::errors::InvalidArgument( + "'quant_axis' of conv2d_transpose op weight dequantized by " + "[fake_channel_wise_dequantize_max_abs]should be 1, but " + "the received is %d", + quant_axis)); + } PADDLE_ENFORCE_EQ( weight_scale.size(), static_cast(w_dims[1]), platform::errors::InvalidArgument(