未验证 提交 03619037 编写于 作者: G Guanghua Yu 提交者: GitHub

Skip the int input operator when inserting a quant node & fix some bugs (#49926)

上级 3a73d348
...@@ -2890,6 +2890,19 @@ class AddQuantDequantPassV2: ...@@ -2890,6 +2890,19 @@ class AddQuantDequantPassV2:
) )
if in_node.persistable(): if in_node.persistable():
continue continue
if in_node.dtype() not in [
paddle.float64,
paddle.float32,
paddle.float16,
]:
_logger.warning(
"Since the {} contains an input of type INT, the quantization of this layer is skipped.".format(
op_node.name()
)
)
break
if arg_name in dequantized_vars_map: if arg_name in dequantized_vars_map:
dequant_var_node = dequantized_vars_map[arg_name] dequant_var_node = dequantized_vars_map[arg_name]
else: else:
...@@ -3137,7 +3150,7 @@ class QuantWeightPass: ...@@ -3137,7 +3150,7 @@ class QuantWeightPass:
self._save_int_weight = save_int_weight self._save_int_weight = save_int_weight
assert self._scope is not None, "scope must not be None." assert self._scope is not None, "scope must not be None."
assert self._place is not None, "place must not be None." assert self._place is not None, "place must not be None."
self._quantized_ops = set() self._quantized_ops = {}
def apply(self, graph): def apply(self, graph):
assert isinstance( assert isinstance(
...@@ -3176,7 +3189,6 @@ class QuantWeightPass: ...@@ -3176,7 +3189,6 @@ class QuantWeightPass:
quant_axis = _op.op().attr("quant_axis") quant_axis = _op.op().attr("quant_axis")
bits_length = _op.op().attr("bit_length") bits_length = _op.op().attr("bit_length")
if x_node.name() not in self._quantized_ops: if x_node.name() not in self._quantized_ops:
self._quantized_ops.add(x_node.name())
quantized_param_v = utils.quant_tensor( quantized_param_v = utils.quant_tensor(
param_v.copy(), param_v.copy(),
scale_v, scale_v,
...@@ -3211,10 +3223,13 @@ class QuantWeightPass: ...@@ -3211,10 +3223,13 @@ class QuantWeightPass:
self._scope, self._scope,
self._place, self._place,
) )
self._quantized_ops[x_node.name()] = quant_weight_node
for next_op_node in out_node.outputs: for next_op_node in out_node.outputs:
graph.update_input_link( graph.update_input_link(
out_node, quant_weight_node, next_op_node out_node,
self._quantized_ops[x_node.name()],
next_op_node,
) )
graph.safe_remove_nodes(_op) graph.safe_remove_nodes(_op)
self._remove_unused_var_nodes(graph) self._remove_unused_var_nodes(graph)
...@@ -3298,9 +3313,9 @@ class AddQuantDequantForInferencePass: ...@@ -3298,9 +3313,9 @@ class AddQuantDequantForInferencePass:
op_node.outputs, var_name op_node.outputs, var_name
) )
if out_node.dtype() not in [ if out_node.dtype() not in [
core.VarDesc.VarType.FP64, paddle.float64,
core.VarDesc.VarType.FP32, paddle.float32,
core.VarDesc.VarType.FP16, paddle.float16,
]: ]:
continue continue
if var_name in dequantized_vars_map: if var_name in dequantized_vars_map:
...@@ -3319,7 +3334,10 @@ class AddQuantDequantForInferencePass: ...@@ -3319,7 +3334,10 @@ class AddQuantDequantForInferencePass:
else: else:
var_names = utils._get_op_input_var_names(op_node) var_names = utils._get_op_input_var_names(op_node)
for var_name in var_names: for var_name in var_names:
if var_name in dequant_node_map: if (
var_name in dequant_node_map
and dequant_node_map[var_name]
):
in_node = graph._find_node_by_name( in_node = graph._find_node_by_name(
op_node.inputs, var_name op_node.inputs, var_name
) )
...@@ -3345,39 +3363,41 @@ class AddQuantDequantForInferencePass: ...@@ -3345,39 +3363,41 @@ class AddQuantDequantForInferencePass:
shape=var_node.shape(), shape=var_node.shape(),
var_dtype=var_node.dtype(), var_dtype=var_node.dtype(),
) )
if not self._calibration_range_dict:
try: try:
scale_var_node = graph._find_node_by_name( scale_var_node = graph._find_node_by_name(
graph.all_persistable_nodes(), self._scale_name(var_name) graph.all_persistable_nodes(), self._scale_name(var_name)
)
except:
if (
self._calibration_range_dict
and var_name in self._calibration_range_dict
):
scale_value = self._calibration_range_dict[var_name]
scale_var_node = graph.create_persistable_node(
name=self._scale_name(var_name),
var_type=var_node.type(),
shape=[1],
var_dtype=var_node.dtype(),
) )
except: data_type = (
'float64'
if var_node.dtype() == core.VarDesc.VarType.FP64
else 'float32'
)
_init_var_node(
scale_var_node,
np.array(scale_value, dtype=data_type),
self._scope,
self._place,
)
else:
_logger.warning( _logger.warning(
"Cannot find the target node {} in scope, so skip adding quant node.".format( "Cannot find the target node {} in scope, so skip adding quant node.".format(
var_name var_name
) )
) )
return None return None
elif var_name in self._calibration_range_dict:
scale_value = self._calibration_range_dict[var_name]
scale_var_node = graph.create_persistable_node(
name=self._scale_name(var_name),
var_type=var_node.type(),
shape=[1],
var_dtype=var_node.dtype(),
)
data_type = (
'float64'
if var_node.dtype() == core.VarDesc.VarType.FP64
else 'float32'
)
_init_var_node(
scale_var_node,
np.array(scale_value, dtype=data_type),
self._scope,
self._place,
)
else:
return None
try: try:
zero_point_node = graph._find_node_by_name( zero_point_node = graph._find_node_by_name(
graph.all_persistable_nodes(), graph.all_persistable_nodes(),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册