From 0361903789ea754079a95fe7df5876196fdf9ed7 Mon Sep 17 00:00:00 2001 From: Guanghua Yu <742925032@qq.com> Date: Wed, 1 Feb 2023 16:33:04 +0800 Subject: [PATCH] Skip the int input operator when inserting a quant node & fix some bug (#49926) --- .../static/quantization/quantization_pass.py | 86 ++++++++++++------- 1 file changed, 53 insertions(+), 33 deletions(-) diff --git a/python/paddle/static/quantization/quantization_pass.py b/python/paddle/static/quantization/quantization_pass.py index 83587563c4..c9094998df 100644 --- a/python/paddle/static/quantization/quantization_pass.py +++ b/python/paddle/static/quantization/quantization_pass.py @@ -2890,6 +2890,19 @@ class AddQuantDequantPassV2: ) if in_node.persistable(): continue + + if in_node.dtype() not in [ + paddle.float64, + paddle.float32, + paddle.float16, + ]: + _logger.warning( + "Since the {} contains an input of type INT, the quantization of this layer is skipped.".format( + op_node.name() + ) + ) + break + if arg_name in dequantized_vars_map: dequant_var_node = dequantized_vars_map[arg_name] else: @@ -3137,7 +3150,7 @@ class QuantWeightPass: self._save_int_weight = save_int_weight assert self._scope is not None, "scope must not be None." assert self._place is not None, "place must not be None." - self._quantized_ops = set() + self._quantized_ops = {} def apply(self, graph): assert isinstance( @@ -3176,7 +3189,6 @@ class QuantWeightPass: quant_axis = _op.op().attr("quant_axis") bits_length = _op.op().attr("bit_length") if x_node.name() not in self._quantized_ops: - self._quantized_ops.add(x_node.name()) quantized_param_v = utils.quant_tensor( param_v.copy(), scale_v, @@ -3211,10 +3223,13 @@ class QuantWeightPass: self._scope, self._place, ) + self._quantized_ops[x_node.name()] = quant_weight_node for next_op_node in out_node.outputs: graph.update_input_link( - out_node, quant_weight_node, next_op_node + out_node, + self._quantized_ops[x_node.name()], + next_op_node, ) graph.safe_remove_nodes(_op) self._remove_unused_var_nodes(graph) @@ -3298,9 +3313,9 @@ class AddQuantDequantForInferencePass: op_node.outputs, var_name ) if out_node.dtype() not in [ - core.VarDesc.VarType.FP64, - core.VarDesc.VarType.FP32, - core.VarDesc.VarType.FP16, + paddle.float64, + paddle.float32, + paddle.float16, ]: continue if var_name in dequantized_vars_map: @@ -3319,7 +3334,10 @@ class AddQuantDequantForInferencePass: else: var_names = utils._get_op_input_var_names(op_node) for var_name in var_names: - if var_name in dequant_node_map: + if ( + var_name in dequant_node_map + and dequant_node_map[var_name] + ): in_node = graph._find_node_by_name( op_node.inputs, var_name ) @@ -3345,39 +3363,41 @@ class AddQuantDequantForInferencePass: shape=var_node.shape(), var_dtype=var_node.dtype(), ) - if not self._calibration_range_dict: - try: - scale_var_node = graph._find_node_by_name( - graph.all_persistable_nodes(), self._scale_name(var_name) + + try: + scale_var_node = graph._find_node_by_name( + graph.all_persistable_nodes(), self._scale_name(var_name) + ) + except: + if ( + self._calibration_range_dict + and var_name in self._calibration_range_dict + ): + scale_value = self._calibration_range_dict[var_name] + scale_var_node = graph.create_persistable_node( + name=self._scale_name(var_name), + var_type=var_node.type(), + shape=[1], + var_dtype=var_node.dtype(), ) - except: + data_type = ( + 'float64' + if var_node.dtype() == core.VarDesc.VarType.FP64 + else 'float32' + ) + _init_var_node( + scale_var_node, + np.array(scale_value, dtype=data_type), + self._scope, + self._place, + ) + else: _logger.warning( "Cannot find the target node {} in scope, so skip adding quant node.".format( var_name ) ) return None - elif var_name in self._calibration_range_dict: - scale_value = self._calibration_range_dict[var_name] - scale_var_node = graph.create_persistable_node( - name=self._scale_name(var_name), - var_type=var_node.type(), - shape=[1], - var_dtype=var_node.dtype(), - ) - data_type = ( - 'float64' - if var_node.dtype() == core.VarDesc.VarType.FP64 - else 'float32' - ) - _init_var_node( - scale_var_node, - np.array(scale_value, dtype=data_type), - self._scope, - self._place, - ) - else: - return None try: zero_point_node = graph._find_node_by_name( graph.all_persistable_nodes(), -- GitLab