Unverified · Commit 5993dde3 · authored by cc · committed by GitHub

[Cherry-pick] Fix some bugs for quantization (#24852)

* Update sigmoid output from Y to out, test=develop (#24765)

* Collecting concat output threshold, test=develop (#24742)

* Add output threshold for ops that have several output activations, test=develop (#24726)

* [Fix bug] Init scale node in OutScaleForTrainingPass and enable test_quantization_scale_pass UT (#24393)

* Init scale node in OutScaleForTrainingPass, test=develop
* Enable test_quantization_scale, test=develop
Parent 8c40ebd1
@@ -43,7 +43,7 @@ _fake_quant_dequant_op_list = [
 _out_scale_op_list = [
     "conv2d", "depthwise_conv2d", "mul", "matmul", "relu", "leaky_relu",
     "relu6", "sigmoid", "tanh", "prelu", "swish", "softmax", "batch_norm",
-    "elementwise_add", "pool2d", "reshape2", "transpose2"
+    "elementwise_add", "pool2d", "reshape2", "transpose2", "concat"
 ]
 
 # list op real input and output names, to avoid processing input such as AxisTensor.
@@ -83,7 +83,7 @@ _op_real_in_out_name = {
     "swish": [["X"], ["Out"]],
     "dropout": [["X"], ["Out"]],
     "batch_norm": [["X"], ["Y"]],
-    "sigmoid": [["X"], ["Y"]],
+    "sigmoid": [["X"], ["Out"]],
 }
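The passes below resolve each op's real output variables through `_get_op_output_var_names`, which is why the `sigmoid` entry had to change from `"Y"` to `"Out"`. That helper is not part of this diff; the following is only a minimal sketch of how such a lookup against `_op_real_in_out_name` could work, not the actual Paddle implementation:

    # Hypothetical sketch of the lookup the passes rely on; the real helper
    # lives outside this diff and may differ in detail.
    def _get_op_output_var_names(op_node):
        """Return the real output variable names of an IrGraph op node."""
        op_name = op_node.name()
        if op_name not in _op_real_in_out_name:
            return []
        var_names = []
        # The second element of each table entry lists the output slots,
        # e.g. ["Out"] for sigmoid after this fix (previously ["Y"]).
        for slot in _op_real_in_out_name[op_name][1]:
            var_names.extend(op_node.op().output(slot))
        return var_names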
@@ -1156,20 +1156,27 @@ class OutScaleForTrainingPass(object):
         assert isinstance(graph,
                           IrGraph), 'graph must be the instance of IrGraph.'
         self._is_test = graph.is_test()
-        ops = graph.all_op_nodes()
-        for op_node in ops:
-            name = op_node.name()
-            if name in self._teller_set:
-                if len(op_node.output_arg_names()) != 1:
-                    continue
-                in_node = graph._find_node_by_name(
-                    op_node.outputs, op_node.output_arg_names()[0])
+        target_ops = []
+        for op in graph.all_op_nodes():
+            if op.name() in self._teller_set:
+                target_ops.append(op)
+        for op in target_ops:
+            for output_var_name in _get_op_output_var_names(op):
+                in_node = graph._find_node_by_name(op.outputs, output_var_name)
                 out_node = graph.create_var_node_from_desc(in_node.var())
                 scale_node = graph.create_persistable_node(
                     name=self._scale_name(in_node.name()),
                     var_type=core.VarDesc.VarType.LOD_TENSOR,
                     shape=[1],
                     var_dtype=in_node.dtype())
+                data_type = 'float64' if in_node.dtype() \
+                    == core.VarDesc.VarType.FP64 else 'float32'
+                _init_var_node(
+                    scale_node,
+                    np.ones(
+                        [1], dtype=data_type),
+                    self._scope,
+                    self._place)
                 ins = {'X': in_node}
                 outs = {'Out': out_node, 'OutScale': scale_node}
                 if not self._is_test:
@@ -1178,8 +1185,6 @@ class OutScaleForTrainingPass(object):
                     var_type=core.VarDesc.VarType.LOD_TENSOR,
                     var_dtype=in_node.dtype(),
                     shape=[1])
-                data_type = 'float64' if in_node.dtype(
-                ) == core.VarDesc.VarType.FP64 else 'float32'
                 _init_var_node(
                     state_in_node,
                     np.ones(
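The hunk above is the core of the "init scale node" fix: `create_persistable_node` only adds a variable description to the graph, so the new `_init_var_node` call is what actually gives each `OutScale` variable a tensor of ones in the scope, independent of the `is_test` branch. For context, a minimal usage sketch of applying the pass, assuming an already-built fluid `Program` named `train_program` (Paddle 1.x-era slim APIs):

    import paddle.fluid as fluid
    from paddle.fluid import core
    from paddle.fluid.framework import IrGraph
    from paddle.fluid.contrib.slim.quantization import OutScaleForTrainingPass

    place = fluid.CPUPlace()
    scope = fluid.global_scope()
    graph = IrGraph(core.Graph(train_program.desc), for_test=False)

    # The pass now initializes each OutScale variable immediately after
    # creating it, so the scale tensors exist in the scope in all modes.
    OutScaleForTrainingPass(scope=scope, place=place).apply(graph)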
@@ -1257,13 +1262,13 @@ class OutScaleForInferencePass(object):
         """
         assert isinstance(graph,
                           IrGraph), 'graph must be the instance of IrGraph.'
-        ops = graph.all_op_nodes()
-        for op_node in ops:
-            name = op_node.name()
-            if name in self._teller_set:
-                if len(op_node.output_arg_names()) != 1:
-                    continue
-                scale_name = self._scale_name(op_node.output_arg_names()[0])
+        op_nodes = graph.all_op_nodes()
+        for op_node in op_nodes:
+            if op_node.name() in self._teller_set:
+                output_var_name = _get_op_output_var_names(op_node)
+                assert len(output_var_name) == 1, "Only support collecting " \
+                    "output for op that only has an activation output for now."
+                scale_name = self._scale_name(output_var_name[0])
                 scale_v = np.array(
                     self._scope.find_var(scale_name).get_tensor())[0]
                 op_node.op()._set_attr("out_threshold", float(scale_v))
...
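Once both passes have run, each op in `_out_scale_op_list` carries an `out_threshold` attribute holding the recorded output scale. An illustrative way to inspect what was collected, assuming `inference_program` is the Program saved after the pass (`has_attr`/`attr` are standard fluid `Operator` accessors):

    # Illustrative check: print the thresholds the inference pass attached.
    for block in inference_program.blocks:
        for op in block.ops:
            if op.has_attr("out_threshold"):
                print(op.type, op.attr("out_threshold"))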
@@ -114,9 +114,6 @@ if(WIN32)
     list(REMOVE_ITEM TEST_OPS test_weight_quantization_mobilenetv1)
 endif()
 
-# Disable unittest for random error temporary
-list(REMOVE_ITEM TEST_OPS test_quantization_scale_pass)
-
 if(LINUX AND WITH_MKLDNN)
     #### Image classification dataset: ImageNet (small)
...