From a6bd69570477a3e0d1973ed39d2fee2fa41f37c7 Mon Sep 17 00:00:00 2001 From: Guanghua Yu <742925032@qq.com> Date: Tue, 10 Jan 2023 16:03:55 +0800 Subject: [PATCH] Fix the problem that the quantization model cannot find the weight (#49664) --- .../static/quantization/quantization_pass.py | 47 +++++++++++++++---- 1 file changed, 38 insertions(+), 9 deletions(-) diff --git a/python/paddle/static/quantization/quantization_pass.py b/python/paddle/static/quantization/quantization_pass.py index fc7ab7689e..83587563c4 100644 --- a/python/paddle/static/quantization/quantization_pass.py +++ b/python/paddle/static/quantization/quantization_pass.py @@ -13,6 +13,7 @@ # limitations under the License. import collections +import logging import numpy as np @@ -27,6 +28,7 @@ from ...fluid.framework import IrGraph, IrNode from ...framework import _get_paddle_place, core from ...static import Program, data, program_guard, scope_guard from ...utils import unique_name +from ..log_helper import get_logger from . import utils from .quant_config import ( SUPPORT_ACT_QUANTIZATION_OP_DICT, @@ -34,6 +36,10 @@ from .quant_config import ( SUPPORT_WEIGHT_QUANTIZATION_OP_DICT, ) +_logger = get_logger( + __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s' +) + _fake_quant_op_list = [ 'fake_quantize_abs_max', 'fake_quantize_range_abs_max', @@ -3193,11 +3199,24 @@ class QuantWeightPass: quantized_param_v = quantized_param_v.astype( save_weight_dtype ) - self._restore_var(x_node.name(), quantized_param_v) + quant_weight_node = graph.create_persistable_node( + name=self._quantized_var_name(x_node.name()), + var_type=core.VarDesc.VarType.LOD_TENSOR, + shape=x_node.shape(), + var_dtype=core.VarDesc.VarType.INT8, + ) + _init_var_node( + quant_weight_node, + quantized_param_v, + self._scope, + self._place, + ) for next_op_node in out_node.outputs: - graph.update_input_link(out_node, x_node, next_op_node) - graph.safe_remove_nodes(out_node) + graph.update_input_link( + out_node, quant_weight_node, next_op_node + ) + graph.safe_remove_nodes(_op) self._remove_unused_var_nodes(graph) def _remove_unused_var_nodes(self, graph): @@ -3222,9 +3241,11 @@ class QuantWeightPass: def _load_var(self, name): return np.array(self._scope.find_var(name).get_tensor()) - def _restore_var(self, name, array): - tensor = self._scope.find_var(name).get_tensor() - tensor.set(array, self._place) + def _quantized_var_name(self, var_name): + """ + Return quantized variable name for the input `var_name`. + """ + return "%s.quantized" % (var_name) class AddQuantDequantForInferencePass: @@ -3325,9 +3346,17 @@ class AddQuantDequantForInferencePass: var_dtype=var_node.dtype(), ) if not self._calibration_range_dict: - scale_var_node = graph._find_node_by_name( - graph.all_persistable_nodes(), self._scale_name(var_name) - ) + try: + scale_var_node = graph._find_node_by_name( + graph.all_persistable_nodes(), self._scale_name(var_name) + ) + except: + _logger.warning( + "Cannot find the target node {} in scope, so skip adding quant node.".format( + var_name + ) + ) + return None elif var_name in self._calibration_range_dict: scale_value = self._calibration_range_dict[var_name] scale_var_node = graph.create_persistable_node( -- GitLab