diff --git a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py index 90caee6c7a947023317f76874e8eac83c8b249f2..9b2954b13f222246bd8773e516e12a69381998af 100644 --- a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py +++ b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py @@ -1292,10 +1292,11 @@ class QuantizationFreezePass(object): var_type=output_var_node.type(), shape=output_var_node.shape(), var_dtype=output_var_node.dtype()) + x_num_col_dims = 1 + if op_node.name() in ['matmul', 'matmul_v2', 'mul']: + x_num_col_dims = len(op_node.outputs[0].shape()) - 1 if op_node.op().has_attr("x_num_col_dims"): x_num_col_dims = op_node.op().attr("x_num_col_dims") - else: - x_num_col_dims = 1 dequant_op_node = graph.create_op_node( op_type='fake_channel_wise_dequantize_max_abs', attrs={