Unverified · Commit a6bd6957 authored by Guanghua Yu, committed by GitHub

Fix the problem that the quantization model cannot find the weight (#49664)

Parent c70fe47c
@@ -13,6 +13,7 @@
# limitations under the License.
import collections
import logging
import numpy as np
@@ -27,6 +28,7 @@ from ...fluid.framework import IrGraph, IrNode
from ...framework import _get_paddle_place, core
from ...static import Program, data, program_guard, scope_guard
from ...utils import unique_name
from ..log_helper import get_logger
from . import utils
from .quant_config import (
SUPPORT_ACT_QUANTIZATION_OP_DICT,
@@ -34,6 +36,10 @@ from .quant_config import (
SUPPORT_WEIGHT_QUANTIZATION_OP_DICT,
)
_logger = get_logger(
__name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s'
)
_fake_quant_op_list = [
'fake_quantize_abs_max',
'fake_quantize_range_abs_max',
@@ -3193,11 +3199,24 @@ class QuantWeightPass:
                         quantized_param_v = quantized_param_v.astype(
                             save_weight_dtype
                         )
-                    self._restore_var(x_node.name(), quantized_param_v)
+                    quant_weight_node = graph.create_persistable_node(
+                        name=self._quantized_var_name(x_node.name()),
+                        var_type=core.VarDesc.VarType.LOD_TENSOR,
+                        shape=x_node.shape(),
+                        var_dtype=core.VarDesc.VarType.INT8,
+                    )
+                    _init_var_node(
+                        quant_weight_node,
+                        quantized_param_v,
+                        self._scope,
+                        self._place,
+                    )

                 for next_op_node in out_node.outputs:
-                    graph.update_input_link(out_node, x_node, next_op_node)
-                graph.safe_remove_nodes(out_node)
+                    graph.update_input_link(
+                        out_node, quant_weight_node, next_op_node
+                    )
+                graph.safe_remove_nodes(_op)
         self._remove_unused_var_nodes(graph)

     def _remove_unused_var_nodes(self, graph):
@@ -3222,9 +3241,11 @@ class QuantWeightPass:
     def _load_var(self, name):
         return np.array(self._scope.find_var(name).get_tensor())

-    def _restore_var(self, name, array):
-        tensor = self._scope.find_var(name).get_tensor()
-        tensor.set(array, self._place)
+    def _quantized_var_name(self, var_name):
+        """
+        Return quantized variable name for the input `var_name`.
+        """
+        return "%s.quantized" % (var_name)
class AddQuantDequantForInferencePass:
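The two hunks above change QuantWeightPass so that, instead of overwriting the original weight tensor in place (the removed `_restore_var` path), it creates a separate persistable INT8 variable named `<weight>.quantized`, initializes it with the quantized values, re-links the consumers of the fake-quant output to that new node, and removes the fake-quant op. A minimal usage sketch follows; the import paths and the `QuantWeightPass(scope, place)` constructor are assumptions about the surrounding API, not something stated in this diff.

    # Hypothetical usage sketch (not part of this commit): apply QuantWeightPass
    # to a fake-quantized inference graph so each weight gains an INT8
    # "<name>.quantized" twin that downstream ops read from.
    import paddle
    from paddle.fluid import core
    from paddle.fluid.framework import IrGraph
    from paddle.static.quantization import QuantWeightPass  # import path assumed

    paddle.enable_static()
    place = paddle.CPUPlace()
    scope = paddle.static.global_scope()

    # `test_program` is assumed to be a quantized inference Program whose
    # parameters are already loaded into `scope`.
    graph = IrGraph(core.Graph(test_program.desc), for_test=True)
    QuantWeightPass(scope, place).apply(graph)
    int8_program = graph.to_program()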
@@ -3325,9 +3346,17 @@ class AddQuantDequantForInferencePass:
var_dtype=var_node.dtype(),
)
if not self._calibration_range_dict:
try:
scale_var_node = graph._find_node_by_name(
graph.all_persistable_nodes(), self._scale_name(var_name)
)
except:
_logger.warning(
"Cannot find the target node {} in scope, so skip adding quant node.".format(
var_name
)
)
return None
elif var_name in self._calibration_range_dict:
scale_value = self._calibration_range_dict[var_name]
scale_var_node = graph.create_persistable_node(
......
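The last hunk wraps the scale lookup in AddQuantDequantForInferencePass in a try/except: when the scale node for a variable cannot be found among the graph's persistable nodes, the pass now logs a warning and skips adding the quant node for that variable instead of raising. Below is a sketch of how this export pass is typically driven, reusing `scope`, `place`, and `graph` from the sketch above; the import path and the `(scope, place)` constructor are assumptions.

    # Hypothetical sketch: run the export-time pass that inserts quant/dequant
    # ops for inference. With the change above, variables whose scale node is
    # missing are skipped with a warning rather than aborting the export.
    from paddle.static.quantization import (  # import path assumed
        AddQuantDequantForInferencePass,
    )

    export_pass = AddQuantDequantForInferencePass(scope, place)  # signature assumed
    export_pass.apply(graph)
    inference_program = graph.to_program()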