未验证 提交 03619037 编写于 作者: G Guanghua Yu 提交者: GitHub

Skip the int input operator when inserting a quant node & fix some bugs (#49926)

上级 3a73d348
......@@ -2890,6 +2890,19 @@ class AddQuantDequantPassV2:
)
if in_node.persistable():
continue
if in_node.dtype() not in [
paddle.float64,
paddle.float32,
paddle.float16,
]:
_logger.warning(
"Since the {} contains an input of type INT, the quantization of this layer is skipped.".format(
op_node.name()
)
)
break
if arg_name in dequantized_vars_map:
dequant_var_node = dequantized_vars_map[arg_name]
else:
......@@ -3137,7 +3150,7 @@ class QuantWeightPass:
self._save_int_weight = save_int_weight
assert self._scope is not None, "scope must not be None."
assert self._place is not None, "place must not be None."
self._quantized_ops = set()
self._quantized_ops = {}
def apply(self, graph):
assert isinstance(
......@@ -3176,7 +3189,6 @@ class QuantWeightPass:
quant_axis = _op.op().attr("quant_axis")
bits_length = _op.op().attr("bit_length")
if x_node.name() not in self._quantized_ops:
self._quantized_ops.add(x_node.name())
quantized_param_v = utils.quant_tensor(
param_v.copy(),
scale_v,
......@@ -3211,10 +3223,13 @@ class QuantWeightPass:
self._scope,
self._place,
)
self._quantized_ops[x_node.name()] = quant_weight_node
for next_op_node in out_node.outputs:
graph.update_input_link(
out_node, quant_weight_node, next_op_node
out_node,
self._quantized_ops[x_node.name()],
next_op_node,
)
graph.safe_remove_nodes(_op)
self._remove_unused_var_nodes(graph)
......@@ -3298,9 +3313,9 @@ class AddQuantDequantForInferencePass:
op_node.outputs, var_name
)
if out_node.dtype() not in [
core.VarDesc.VarType.FP64,
core.VarDesc.VarType.FP32,
core.VarDesc.VarType.FP16,
paddle.float64,
paddle.float32,
paddle.float16,
]:
continue
if var_name in dequantized_vars_map:
......@@ -3319,7 +3334,10 @@ class AddQuantDequantForInferencePass:
else:
var_names = utils._get_op_input_var_names(op_node)
for var_name in var_names:
if var_name in dequant_node_map:
if (
var_name in dequant_node_map
and dequant_node_map[var_name]
):
in_node = graph._find_node_by_name(
op_node.inputs, var_name
)
......@@ -3345,39 +3363,41 @@ class AddQuantDequantForInferencePass:
shape=var_node.shape(),
var_dtype=var_node.dtype(),
)
if not self._calibration_range_dict:
try:
scale_var_node = graph._find_node_by_name(
graph.all_persistable_nodes(), self._scale_name(var_name)
try:
scale_var_node = graph._find_node_by_name(
graph.all_persistable_nodes(), self._scale_name(var_name)
)
except:
if (
self._calibration_range_dict
and var_name in self._calibration_range_dict
):
scale_value = self._calibration_range_dict[var_name]
scale_var_node = graph.create_persistable_node(
name=self._scale_name(var_name),
var_type=var_node.type(),
shape=[1],
var_dtype=var_node.dtype(),
)
except:
data_type = (
'float64'
if var_node.dtype() == core.VarDesc.VarType.FP64
else 'float32'
)
_init_var_node(
scale_var_node,
np.array(scale_value, dtype=data_type),
self._scope,
self._place,
)
else:
_logger.warning(
"Cannot find the target node {} in scope, so skip adding quant node.".format(
var_name
)
)
return None
elif var_name in self._calibration_range_dict:
scale_value = self._calibration_range_dict[var_name]
scale_var_node = graph.create_persistable_node(
name=self._scale_name(var_name),
var_type=var_node.type(),
shape=[1],
var_dtype=var_node.dtype(),
)
data_type = (
'float64'
if var_node.dtype() == core.VarDesc.VarType.FP64
else 'float32'
)
_init_var_node(
scale_var_node,
np.array(scale_value, dtype=data_type),
self._scope,
self._place,
)
else:
return None
try:
zero_point_node = graph._find_node_by_name(
graph.all_persistable_nodes(),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册