Unverified commit 837773c1, authored by Guanghua Yu, committed by GitHub

update quantization new format (#1425)

Parent commit dfcceac0
@@ -43,9 +43,10 @@ try:
     from paddle.fluid.contrib.slim.quantization import QuantWeightPass
     from paddle.fluid.contrib.slim.quantization import AddQuantDequantPassV2
     from paddle.fluid.contrib.slim.quantization import PostTrainingQuantizationProgram
+    from paddle.fluid.contrib.slim.quantization import OutScaleForInferencePassV2
 except:
     _logger.warning(
-        "Some functions fail to import, please update PaddlePaddle version to 2.3+"
+        "Some functions fail to import, please update PaddlePaddle version to 2.4+"
     )
 
 WEIGHT_QUANTIZATION_TYPES = [
@@ -109,56 +110,6 @@ _quant_config_default = {
 }
 
-
-class OutScaleForInferencePassV2(object):
-    def __init__(self, scope=None):
-        """
-        This pass is used for setting output scales of some operators.
-        These output scales may be used by tensorRT or some other inference engines.
-
-        Args:
-            scope(fluid.Scope): The scope is used to initialize these new parameters.
-        """
-        self._scope = scope
-        self._teller_set = utils._out_scale_op_list
-
-    def apply(self, graph):
-        """
-        Get output scales from the scope and set these scales in op_descs
-        of operators in the teller_set.
-
-        Args:
-            graph(IrGraph): the target graph.
-        """
-        assert isinstance(graph,
-                          IrGraph), 'graph must be the instance of IrGraph.'
-        collect_dict = collections.OrderedDict()
-        op_nodes = graph.all_op_nodes()
-        for op_node in op_nodes:
-            if op_node.name() in self._teller_set:
-                var_names = utils._get_op_output_var_names(op_node)
-                for var_name in var_names:
-                    in_node = graph._find_node_by_name(op_node.outputs,
-                                                       var_name)
-                    if in_node.dtype() not in \
-                        [core.VarDesc.VarType.FP64, core.VarDesc.VarType.FP32]:
-                        continue
-                    collect_dict[var_name] = {}
-                    scale_name = self._scale_name(var_name)
-                    scale_var = self._scope.find_var(scale_name)
-                    assert scale_var is not None, \
-                        "Can not find {} variable in the scope".format(scale_name)
-                    scale_value = np.array(scale_var.get_tensor())[0]
-                    collect_dict[var_name]['scale'] = float(scale_value)
-
-        return graph, collect_dict
-
-    def _scale_name(self, var_name):
-        """
-        Return the scale name for the var named `var_name`.
-        """
-        return "%s@scale" % (var_name)
-
-
 def load_dict():
     with open(VARS_MAPPING_TABLE, 'r') as file:
         data = file.read()
@@ -515,8 +466,7 @@ def quant_aware(program,
     return quant_program
 
 
-def quant_post_static(
-        executor,
+def quant_post_static(executor,
                       model_dir,
                       quantize_model_path,
                       batch_generator=None,
@@ -533,7 +483,10 @@ def quant_post_static(
                       round_type='round',
                       hist_percent=0.9999,
                       bias_correction=False,
-                      quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
+                      quantizable_op_type=[
+                          "conv2d", "depthwise_conv2d", "mul", "matmul",
+                          "matmul_v2"
+                      ],
                       is_full_quantize=False,
                       weight_bits=8,
                       activation_bits=8,
@@ -676,22 +629,6 @@ def quant_post_static(
         quantize_model_path,
         model_filename=save_model_filename,
         params_filename=save_params_filename)
-    if onnx_format:
-        try:
-            collect_dict = post_training_quantization._calibration_scales
-            save_quant_table_path = os.path.join(quantize_model_path,
-                                                 'calibration_table.txt')
-            with open(save_quant_table_path, 'w') as txt_file:
-                for tensor_name in collect_dict.keys():
-                    write_line = '{} {}'.format(
-                        tensor_name, collect_dict[tensor_name]['scale']) + '\n'
-                    txt_file.write(write_line)
-            _logger.info("Quantization clip ranges of tensors is save in: {}".
-                         format(save_quant_table_path))
-        except:
-            _logger.warning(
-                "Unable to generate `calibration_table.txt`, please update PaddlePaddle >= 2.3.3"
-            )
 
 
 # We have changed the quant_post to quant_post_static.
@@ -748,17 +685,14 @@ def convert(program,
     if config['onnx_format']:
         quant_weight_pass = QuantWeightPass(scope, place)
         quant_weight_pass.apply(test_graph)
-        out_scale_infer_pass = OutScaleForInferencePassV2(scope=scope)
-        _, collect_dict = out_scale_infer_pass.apply(test_graph)
-        save_quant_table_path = os.path.join(save_clip_ranges_path,
-                                             'calibration_table.txt')
-        with open(save_quant_table_path, 'w') as txt_file:
-            for tensor_name in collect_dict.keys():
-                write_line = '{} {}'.format(
-                    tensor_name, collect_dict[tensor_name]['scale']) + '\n'
-                txt_file.write(write_line)
-        _logger.info("Quantization clip ranges of tensors is save in: {}".
-                     format(save_quant_table_path))
+        try:
+            out_scale_infer_pass = OutScaleForInferencePassV2(
+                scope=scope, place=place, quant_bits=config['activation_bits'])
+            out_scale_infer_pass.apply(test_graph)
+        except:
+            _logger.warning(
+                "Unable to convert quant model with onnx_format=True, please update PaddlePaddle >= 2.4.0"
+            )
     else:
         out_scale_infer_pass = OutScaleForInferencePass(scope=scope)
         out_scale_infer_pass.apply(test_graph)
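For reference, the effect of this change is that the new-format (onnx_format) scale handling is delegated to PaddlePaddle itself (hence the version warnings raised to 2.4), so a caller only needs to pass onnx_format=True rather than relying on the local OutScaleForInferencePassV2 and calibration_table.txt logic removed above. Below is a minimal sketch of the post-training path under that assumption; the model directory, file names, input shape, and the calibration reader are placeholders for illustration and are not part of this commit.

import numpy as np
import paddle
from paddleslim.quant import quant_post_static

paddle.enable_static()
exe = paddle.static.Executor(paddle.CPUPlace())

# Placeholder calibration reader: yields one sample per iteration.
# The real input names/shapes depend on the model being quantized.
def sample_generator():
    for _ in range(32):
        yield [np.random.rand(3, 224, 224).astype("float32")]

# Post-training quantization writing the new (ONNX-style) quant format.
# Paths below are placeholders; onnx_format=True assumes PaddlePaddle >= 2.4,
# as the warnings added in this commit indicate.
quant_post_static(
    executor=exe,
    model_dir="./inference_model",
    quantize_model_path="./quant_model",
    sample_generator=sample_generator,
    batch_size=16,
    batch_nums=8,
    quantizable_op_type=[
        "conv2d", "depthwise_conv2d", "mul", "matmul", "matmul_v2"
    ],
    weight_bits=8,
    activation_bits=8,
    onnx_format=True)

The quant-aware path in the last hunk works the same way: with config['onnx_format'] set, paddleslim.quant.convert now applies the framework-side OutScaleForInferencePassV2 instead of the local copy this commit removes.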