From 66a1df3cd2e2eb516aec079bcebc09fad1f152cf Mon Sep 17 00:00:00 2001
From: Chang Xu
Date: Thu, 3 Nov 2022 10:23:05 +0800
Subject: [PATCH] Avoid Quant Weight Repeatedly (#47587)

---
 .../slim/quantization/quantization_pass.py    | 75 ++++++++++---------
 1 file changed, 41 insertions(+), 34 deletions(-)

diff --git a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py
index 7a9b89866e..020bdcec48 100644
--- a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py
+++ b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py
@@ -1119,6 +1119,7 @@ class QuantizationFreezePass(object):
         self._op_input_rename_map = collections.OrderedDict()
         self._op_output_rename_map = collections.OrderedDict()
         self._quant_var_scale_map = collections.OrderedDict()
+        self._quantized_ops = set()

     def apply(self, graph):
         """
@@ -1173,24 +1174,27 @@ class QuantizationFreezePass(object):
                             quant_axis = 1
                         else:
                             quant_axis = 0
-                        quantized_param_v = utils.quant_tensor(
-                            param_v.copy(),
-                            scale_v,
-                            quant_axis,
-                            self._weight_bits,
-                        )
-                        quantized_param_v = np.round(quantized_param_v)
-                        # Weight bias correction
-                        if self._bias_correction == True:
-                            quantized_param_v = utils.bias_correction_w(
-                                param_v,
-                                quantized_param_v,
+                        if input_arg_name not in self._quantized_ops:
+                            self._quantized_ops.add(input_arg_name)
+                            quantized_param_v = utils.quant_tensor(
+                                param_v.copy(),
                                 scale_v,
                                 quant_axis,
-                                weight_bits=self._weight_bits,
+                                self._weight_bits,
                             )
                             quantized_param_v = np.round(quantized_param_v)
-                        self._restore_var(input_arg_name, quantized_param_v)
+                            # Weight bias correction
+                            if self._bias_correction == True:
+                                quantized_param_v = utils.bias_correction_w(
+                                    param_v,
+                                    quantized_param_v,
+                                    scale_v,
+                                    quant_axis,
+                                    weight_bits=self._weight_bits,
+                                )
+                                quantized_param_v = np.round(quantized_param_v)
+                            self._restore_var(input_arg_name, quantized_param_v)
+
                 self._remove_fake_quant_and_dequant_op(graph, op_node)

         # Remove all fake dequant op
@@ -3029,6 +3033,7 @@ class QuantWeightPass(object):
         self._save_int_weight = save_int_weight
         assert self._scope is not None, "scope must not be None."
         assert self._place is not None, "place must not be None."
+        self._quantized_ops = set()

     def apply(self, graph):
         assert isinstance(
@@ -3066,29 +3071,31 @@ class QuantWeightPass(object):
                 param_v = self._load_var(x_node.name())
                 quant_axis = _op.op().attr("quant_axis")
                 bits_length = _op.op().attr("bit_length")
-                quantized_param_v = utils.quant_tensor(
-                    param_v.copy(),
-                    scale_v,
-                    quant_axis,
-                    bits_length,
-                    onnx_format=True,
-                )
-                if self._bias_correction == True:
-                    quantized_param_v = utils.bias_correction_w(
-                        param_v,
-                        quantized_param_v,
+                if x_node.name() not in self._quantized_ops:
+                    self._quantized_ops.add(x_node.name())
+                    quantized_param_v = utils.quant_tensor(
+                        param_v.copy(),
                         scale_v,
                         quant_axis,
-                        weight_bits=bits_length,
+                        bits_length,
+                        onnx_format=True,
                     )
-                if self._save_int_weight:
-                    # cast weight type to int
-                    if self._quant_bits == 8:
-                        save_weight_dtype = np.int8
-                    quantized_param_v = quantized_param_v.astype(
-                        save_weight_dtype
-                    )
-                self._restore_var(x_node.name(), quantized_param_v)
+                    if self._bias_correction == True:
+                        quantized_param_v = utils.bias_correction_w(
+                            param_v,
+                            quantized_param_v,
+                            scale_v,
+                            quant_axis,
+                            weight_bits=bits_length,
+                        )
+                    if self._save_int_weight:
+                        # cast weight type to int
+                        if self._quant_bits == 8:
+                            save_weight_dtype = np.int8
+                        quantized_param_v = quantized_param_v.astype(
+                            save_weight_dtype
+                        )
+                    self._restore_var(x_node.name(), quantized_param_v)

                 for next_op_node in out_node.outputs:
                     graph.update_input_link(out_node, x_node, next_op_node)
-- 
GitLab
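
The sketch below is not part of the patch; it is a minimal, standalone illustration of the guard both hunks introduce: record each weight name in a set the first time it is quantized and skip the rewrite when the name is seen again, so a weight shared by several fake-quant ops is only quantized once (re-quantizing an already-quantized tensor would corrupt it). The names WeightQuantizer, fake_quant_weight, and quantize_once are illustrative stand-ins, not Paddle APIs; only the _quantized_ops set mirrors the patch.

import numpy as np


def fake_quant_weight(weight, scale, bits=8):
    # Stand-in for utils.quant_tensor: symmetric quantization onto the
    # signed integer grid for the given bit width, then rounding.
    bnt = (1 << (bits - 1)) - 1
    return np.round(weight / scale * bnt)


class WeightQuantizer:
    def __init__(self):
        # Mirrors self._quantized_ops in the patch: weight names already
        # rewritten in a previous pass over the fake-quant ops.
        self._quantized_ops = set()

    def quantize_once(self, name, weight, scale, bits=8):
        # A weight referenced by several fake-quant ops must be rewritten
        # only once; later requests for the same name are skipped.
        if name in self._quantized_ops:
            return None  # already quantized, keep the stored result
        self._quantized_ops.add(name)
        return fake_quant_weight(weight, scale, bits)


if __name__ == "__main__":
    q = WeightQuantizer()
    w = np.array([0.5, -0.25, 1.0], dtype=np.float32)
    first = q.quantize_once("conv1.w_0", w, scale=1.0)
    second = q.quantize_once("conv1.w_0", w, scale=1.0)  # skipped
    print(first, second)  # [ 64. -32. 127.] None

Running the example prints the quantized array for the first call and None for the repeated one, which is the behavior the patch enforces inside both QuantizationFreezePass and QuantWeightPass.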