Unverified commit a6bd6957, authored by Guanghua Yu, committed via GitHub

Fix the problem that the quantization model cannot find the weight (#49664)

Parent commit: c70fe47c
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
import collections import collections
import logging
import numpy as np import numpy as np
...@@ -27,6 +28,7 @@ from ...fluid.framework import IrGraph, IrNode ...@@ -27,6 +28,7 @@ from ...fluid.framework import IrGraph, IrNode
from ...framework import _get_paddle_place, core from ...framework import _get_paddle_place, core
from ...static import Program, data, program_guard, scope_guard from ...static import Program, data, program_guard, scope_guard
from ...utils import unique_name from ...utils import unique_name
from ..log_helper import get_logger
from . import utils from . import utils
from .quant_config import ( from .quant_config import (
SUPPORT_ACT_QUANTIZATION_OP_DICT, SUPPORT_ACT_QUANTIZATION_OP_DICT,
...@@ -34,6 +36,10 @@ from .quant_config import ( ...@@ -34,6 +36,10 @@ from .quant_config import (
SUPPORT_WEIGHT_QUANTIZATION_OP_DICT, SUPPORT_WEIGHT_QUANTIZATION_OP_DICT,
) )
# Module-level logger shared by the quantization passes in this file;
# format prefixes each message with a timestamp and severity level.
_logger = get_logger(
    __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s'
)
_fake_quant_op_list = [ _fake_quant_op_list = [
'fake_quantize_abs_max', 'fake_quantize_abs_max',
'fake_quantize_range_abs_max', 'fake_quantize_range_abs_max',
...@@ -3193,11 +3199,24 @@ class QuantWeightPass: ...@@ -3193,11 +3199,24 @@ class QuantWeightPass:
quantized_param_v = quantized_param_v.astype( quantized_param_v = quantized_param_v.astype(
save_weight_dtype save_weight_dtype
) )
self._restore_var(x_node.name(), quantized_param_v) quant_weight_node = graph.create_persistable_node(
name=self._quantized_var_name(x_node.name()),
var_type=core.VarDesc.VarType.LOD_TENSOR,
shape=x_node.shape(),
var_dtype=core.VarDesc.VarType.INT8,
)
_init_var_node(
quant_weight_node,
quantized_param_v,
self._scope,
self._place,
)
for next_op_node in out_node.outputs: for next_op_node in out_node.outputs:
graph.update_input_link(out_node, x_node, next_op_node) graph.update_input_link(
graph.safe_remove_nodes(out_node) out_node, quant_weight_node, next_op_node
)
graph.safe_remove_nodes(_op)
self._remove_unused_var_nodes(graph) self._remove_unused_var_nodes(graph)
def _remove_unused_var_nodes(self, graph): def _remove_unused_var_nodes(self, graph):
...@@ -3222,9 +3241,11 @@ class QuantWeightPass: ...@@ -3222,9 +3241,11 @@ class QuantWeightPass:
def _load_var(self, name): def _load_var(self, name):
return np.array(self._scope.find_var(name).get_tensor()) return np.array(self._scope.find_var(name).get_tensor())
def _restore_var(self, name, array): def _quantized_var_name(self, var_name):
tensor = self._scope.find_var(name).get_tensor() """
tensor.set(array, self._place) Return quantized variable name for the input `var_name`.
"""
return "%s.quantized" % (var_name)
class AddQuantDequantForInferencePass: class AddQuantDequantForInferencePass:
...@@ -3325,9 +3346,17 @@ class AddQuantDequantForInferencePass: ...@@ -3325,9 +3346,17 @@ class AddQuantDequantForInferencePass:
var_dtype=var_node.dtype(), var_dtype=var_node.dtype(),
) )
if not self._calibration_range_dict: if not self._calibration_range_dict:
scale_var_node = graph._find_node_by_name( try:
graph.all_persistable_nodes(), self._scale_name(var_name) scale_var_node = graph._find_node_by_name(
) graph.all_persistable_nodes(), self._scale_name(var_name)
)
except:
_logger.warning(
"Cannot find the target node {} in scope, so skip adding quant node.".format(
var_name
)
)
return None
elif var_name in self._calibration_range_dict: elif var_name in self._calibration_range_dict:
scale_value = self._calibration_range_dict[var_name] scale_value = self._calibration_range_dict[var_name]
scale_var_node = graph.create_persistable_node( scale_var_node = graph.create_persistable_node(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register