Unverified commit d8f4714b, authored by cc, committed by GitHub

[Quantization] Save output threshold by argname_index (#25272)

* Save output threshold by argname_index, test=develop
Parent: 64b46122
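In effect, each sampled activation threshold is now written to the quantized op twice: under the legacy "out_threshold" attribute, and under a per-output key built from the op's output argument name plus the variable's index within that argument. A minimal sketch of the naming scheme (the op, argument name, and value are hypothetical):

    # Hypothetical: a conv2d op whose "Output" argument holds one variable (index 0).
    # After post-training quantization the op carries both attributes:
    #   op.attr("out_threshold")      -> 0.83  (legacy key, kept for compatibility)
    #   op.attr("Output0_threshold")  -> 0.83  (new key: argname + index + "_threshold")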
--- a/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py
+++ b/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py
@@ -28,6 +28,7 @@
 from .quantization_pass import AddQuantDequantPass
 from .quantization_pass import _out_scale_op_list
 from .quantization_pass import _get_op_input_var_names
 from .quantization_pass import _get_op_output_var_names
+from .quantization_pass import _get_output_name_index
 
 __all__ = ['PostTrainingQuantization', 'WeightQuantization']
@@ -405,6 +406,10 @@ class PostTrainingQuantization(object):
                 model_filename=self._model_filename,
                 params_filename=self._params_filename)
+        if self._program.num_blocks > 1:
+            _logger.error("The post training quantization requires that the "
+                          "program only has one block.")
+
         if self._optimize_model:
             self._optimize_fp32_model()
@@ -450,6 +455,9 @@ class PostTrainingQuantization(object):
         persistable_var_names = _all_persistable_var_names(self._program)
         for op in self._program.global_block().ops:
             op_type = op.type
+            if self._is_full_quantize and \
+                op_type not in self._quantizable_op_type:
+                _logger.warning(op_type + " is not supported for quantization.")
             # For quantized ops, sample inputs and outputs
             if op_type in self._quantizable_op_type:
                 collect_var_name(
@@ -685,13 +693,25 @@ class PostTrainingQuantization(object):
             op._set_attr("quantization_type", quantized_type)
 
         def analysis_and_save_info(op_node, out_var_name):
+            argname_index = _get_output_name_index(op_node, out_var_name)
+            assert argname_index is not None, \
+                out_var_name + " is not the output of the op"
             if self._algo == "KL":
+                # For compatibility, we save output threshold by two methods.
                 save_info(op_node, out_var_name,
                           self._quantized_var_kl_threshold, "out_threshold",
                           "post_kl")
+                save_info(
+                    op_node, out_var_name, self._quantized_var_kl_threshold,
+                    argname_index[0] + str(argname_index[1]) + "_threshold",
+                    "post_kl")
elif self._algo == "abs_max":
save_info(op_node, out_var_name, self._quantized_var_abs_max,
"out_threshold", "post_abs_max")
save_info(
op_node, out_var_name, self._quantized_var_abs_max,
argname_index[0] + str(argname_index[1]) + "_threshold",
"post_kl")
elif self._algo == "min_max":
save_info(op_node, out_var_name, self._quantized_var_min,
"out_min", "post_min_max")
--- a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py
+++ b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py
@@ -127,6 +127,22 @@ def _get_op_output_var_names(op):
     return var_names
 
 
+def _get_output_name_index(op, output_var_name):
+    """Return the output argument name and index of output_var_name in op, or None."""
+    assert isinstance(op, (IrNode, Operator)), \
+        "The input op should be IrNode or Operator."
+    op_name = op.name() if isinstance(op, IrNode) \
+        else op.type
+    name_list = _op_real_in_out_name[op_name][1]
+    res = None
+    for name in name_list:
+        var_names = op.output(name)
+        for index, val in enumerate(var_names):
+            if val == output_var_name:
+                res = (name, index)
+    return res
+
+
 def _init_var_node(var_node, value, scope, place):
     assert isinstance(value,
                       np.ndarray), 'The type of value should be numpy array.'
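For illustration, assuming an Operator whose "Output" argument holds the single variable "conv2d_0.tmp_0" (names hypothetical), the lookup behaves as follows:

    # _get_output_name_index(conv_op, "conv2d_0.tmp_0")  -> ("Output", 0)
    # _get_output_name_index(conv_op, "unrelated_var")   -> None
    # Callers assert the result is not None, then build the attribute key:
    #   "Output" + str(0) + "_threshold"  ->  "Output0_threshold"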
@@ -1528,13 +1544,19 @@ class OutScaleForInferencePass(object):
         op_nodes = graph.all_op_nodes()
         for op_node in op_nodes:
             if op_node.name() in self._teller_set:
-                output_var_name = _get_op_output_var_names(op_node)
-                assert len(output_var_name) == 1, "Only support collecting " \
-                    "output for op that only has an activation output for now."
-                scale_name = self._scale_name(output_var_name[0])
-                scale_v = np.array(
-                    self._scope.find_var(scale_name).get_tensor())[0]
-                op_node.op()._set_attr("out_threshold", float(scale_v))
+                var_names = _get_op_output_var_names(op_node)
+                for var_name in var_names:
+                    # For compatibility, we save output threshold by two methods.
+                    scale_name = self._scale_name(var_name)
+                    scale_v = np.array(
+                        self._scope.find_var(scale_name).get_tensor())[0]
+                    op_node.op()._set_attr("out_threshold", float(scale_v))
+
+                    argname_index = _get_output_name_index(op_node, var_name)
+                    assert argname_index is not None, \
+                        var_name + " is not the output of the op"
+                    op_node.op()._set_attr(argname_index[0] + str(argname_index[1])
+                                           + "_threshold", float(scale_v))
         graph.resolve_hazard()
         return graph
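A consumer that prefers the new per-output attribute but falls back to the legacy one might read the threshold back like this (a sketch; the helper name is hypothetical, while attr/has_attr are standard Operator methods):

    def _read_out_threshold(op, argname, index):
        # Prefer the per-output key written above, e.g. "Output0_threshold",
        # then fall back to the legacy single-output key.
        key = argname + str(index) + "_threshold"
        return op.attr(key) if op.has_attr(key) else op.attr("out_threshold")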