Unverified commit d6f9aafc authored by Guanghua Yu, committed by GitHub

support output calibration_table.txt in onnx_format (#1353)

Parent 0dd15555
......@@ -21,8 +21,7 @@ import functools
import paddle
sys.path[0] = os.path.join(
os.path.dirname("__file__"), os.path.pardir, os.path.pardir)
sys.path[1] = os.path.join(
os.path.dirname("__file__"), os.path.pardir)
sys.path[1] = os.path.join(os.path.dirname("__file__"), os.path.pardir)
import imagenet_reader as reader
from utility import add_arguments, print_arguments
......@@ -31,8 +30,8 @@ parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
add_arg('use_gpu', bool, True, "Whether to use GPU or not.")
add_arg('model_path', str, "./pruning/checkpoints/resnet50/2/eval_model/", "Whether to use pretrained model.")
add_arg('model_name', str, '__model__', "model filename for inference model")
add_arg('params_name', str, '__params__', "params filename for inference model")
add_arg('model_name', str, 'model.pdmodel', "model filename for inference model")
add_arg('params_name', str, 'model.pdiparams', "params filename for inference model")
add_arg('batch_size', int, 64, "Minibatch size.")
# yapf: enable
......
......@@ -118,7 +118,7 @@ quant_post_dynamic
quant_post_static
---------------
.. py:function:: paddleslim.quant.quant_post_static(executor,model_dir, quantize_model_path, batch_generator=None, sample_generator=None, model_filename=None, params_filename=None, save_model_filename='__model__', save_params_filename='__params__', batch_size=16, batch_nums=None, scope=None, algo='KL', round_type='round', quantizable_op_type=["conv2d","depthwise_conv2d","mul"], is_full_quantize=False, weight_bits=8, activation_bits=8, activation_quantize_type='range_abs_max', weight_quantize_type='channel_wise_abs_max', onnx_format=False, skip_tensor_list=None, optimize_model=False)
.. py:function:: paddleslim.quant.quant_post_static(executor,model_dir, quantize_model_path, batch_generator=None, sample_generator=None, model_filename=None, params_filename=None, save_model_filename='model.pdmodel', save_params_filename='model.pdiparams', batch_size=16, batch_nums=None, scope=None, algo='KL', round_type='round', quantizable_op_type=["conv2d","depthwise_conv2d","mul"], is_full_quantize=False, weight_bits=8, activation_bits=8, activation_quantize_type='range_abs_max', weight_quantize_type='channel_wise_abs_max', onnx_format=False, skip_tensor_list=None, optimize_model=False)
`源代码 <https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/quant/quanter.py>`_
......@@ -217,15 +217,15 @@ quant_post_static
target_vars=[out],
main_program=val_prog,
executor=exe,
model_filename='__model__',
params_filename='__params__')
model_filename='model.pdmodel',
params_filename='model.pdiparams')
quant_post_static(
executor=exe,
model_dir='./model_path',
quantize_model_path='./save_path',
sample_generator=val_reader,
model_filename='__model__',
params_filename='__params__',
model_filename='model.pdmodel',
params_filename='model.pdiparams',
batch_size=16,
batch_nums=10)
......
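Not part of the diff: a minimal sketch of the documented ``quant_post_static`` call with the ``onnx_format`` flag enabled, reusing ``exe``, ``val_reader`` and the placeholder paths from the example above. With PaddlePaddle 2.3.3 or later this also writes ``calibration_table.txt`` into ``quantize_model_path``:

    quant_post_static(
        executor=exe,
        model_dir='./model_path',
        quantize_model_path='./save_path',
        sample_generator=val_reader,
        model_filename='model.pdmodel',
        params_filename='model.pdiparams',
        onnx_format=True,   # also writes ./save_path/calibration_table.txt
        batch_size=16,
        batch_nums=10)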
......@@ -135,8 +135,8 @@ def save_cls_model(model, input_shape, save_dir, data_type):
weight_bits=8,
activation_bits=8)
model_file = os.path.join(quantize_model_path, '__model__')
param_file = os.path.join(quantize_model_path, '__params__')
model_file = os.path.join(quantize_model_path, 'model.pdmodel')
param_file = os.path.join(quantize_model_path, 'model.pdiparams')
return model_file, param_file
......
......@@ -804,7 +804,8 @@ class AutoCompression:
test_program,
self._places,
self._quant_config,
scope=paddle.static.global_scope())
scope=paddle.static.global_scope(),
save_clip_ranges_path=self.final_dir)
feed_vars = [
test_program.global_block().var(name)
......
......@@ -16,9 +16,14 @@ import os
import copy
import json
import logging
import collections
import numpy as np
import paddle
from paddle.fluid import core
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.framework import IrGraph
from paddle.fluid.contrib.slim.quantization import WeightQuantization
from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass
from paddle.fluid.contrib.slim.quantization import QuantizationFreezePass
from paddle.fluid.contrib.slim.quantization import ConvertToInt8Pass
......@@ -27,18 +32,17 @@ from paddle.fluid.contrib.slim.quantization import PostTrainingQuantization
from paddle.fluid.contrib.slim.quantization import AddQuantDequantPass
from paddle.fluid.contrib.slim.quantization import OutScaleForTrainingPass
from paddle.fluid.contrib.slim.quantization import OutScaleForInferencePass
from ..common import get_logger
_logger = get_logger(__name__, level=logging.INFO)
try:
from paddle.fluid.contrib.slim.quantization import QuantizationTransformPassV2
from paddle.fluid.contrib.slim.quantization import QuantWeightPass
from paddle.fluid.contrib.slim.quantization import AddQuantDequantPassV2
except:
pass
from paddle.fluid import core
from paddle.fluid.contrib.slim.quantization import WeightQuantization
from paddle.fluid.layer_helper import LayerHelper
from ..common import get_logger
_logger = get_logger(__name__, level=logging.INFO)
_logger.warning(
"Some functions fail to import, please update PaddlePaddle version to 2.3+"
)
WEIGHT_QUANTIZATION_TYPES = [
'abs_max', 'channel_wise_abs_max', 'range_abs_max', 'moving_average_abs_max'
......@@ -97,23 +101,48 @@ _quant_config_default = {
}
# TODO: Hard-code, remove it when Paddle 2.3.1
class OutScaleForTrainingPassV2(OutScaleForTrainingPass):
def __init__(self, scope=None, place=None, moving_rate=0.9):
OutScaleForTrainingPass.__init__(
self, scope=scope, place=place, moving_rate=moving_rate)
def _scale_name(self, var_name):
class OutScaleForInferencePassV2(object):
def __init__(self, scope=None):
"""
Return the scale name for the var named `var_name`.
This pass is used for setting output scales of some operators.
These output scales may be used by tensorRT or some other inference engines.
Args:
scope(fluid.Scope): The scope is used to initialize these new parameters.
"""
return "%s@scale" % (var_name)
self._scope = scope
self._teller_set = utils._out_scale_op_list
def apply(self, graph):
"""
Get output scales from the scope and set these scales in op_descs
of operators in the teller_set.
# TODO: Hard-code, remove it when Paddle 2.3.1
class OutScaleForInferencePassV2(OutScaleForInferencePass):
def __init__(self, scope=None):
OutScaleForInferencePass.__init__(self, scope=scope)
Args:
graph(IrGraph): the target graph.
"""
assert isinstance(graph,
IrGraph), 'graph must be the instance of IrGraph.'
collect_dict = collections.OrderedDict()
op_nodes = graph.all_op_nodes()
for op_node in op_nodes:
if op_node.name() in self._teller_set:
var_names = utils._get_op_output_var_names(op_node)
for var_name in var_names:
in_node = graph._find_node_by_name(op_node.outputs,
var_name)
if in_node.dtype() not in \
[core.VarDesc.VarType.FP64, core.VarDesc.VarType.FP32]:
continue
collect_dict[var_name] = {}
scale_name = self._scale_name(var_name)
scale_var = self._scope.find_var(scale_name)
assert scale_var is not None, \
"Can not find {} variable in the scope".format(scale_name)
scale_value = np.array(scale_var.get_tensor())[0]
collect_dict[var_name]['scale'] = float(scale_value)
return graph, collect_dict
def _scale_name(self, var_name):
"""
......@@ -328,7 +357,7 @@ def quant_aware(program,
quantizable_op_type=quant_dequant_ops)
quant_dequant_pass.apply(main_graph)
out_scale_training_pass = OutScaleForTrainingPassV2(
out_scale_training_pass = OutScaleForTrainingPass(
scope=scope, place=place, moving_rate=config['moving_rate'])
out_scale_training_pass.apply(main_graph)
......@@ -361,8 +390,8 @@ def quant_post_static(
data_loader=None,
model_filename=None,
params_filename=None,
save_model_filename='__model__',
save_params_filename='__params__',
save_model_filename='model.pdmodel',
save_params_filename='model.pdiparams',
batch_size=1,
batch_nums=None,
scope=None,
......@@ -410,9 +439,9 @@ def quant_post_static(
When all parameters are saved in a single file, set it
as filename. If parameters are saved in separate files,
set it as 'None'. Default : 'None'.
save_model_filename(str): The name of model file to save the quantized inference program. Default: '__model__'.
save_model_filename(str): The name of model file to save the quantized inference program. Default: 'model.pdmodel'.
save_params_filename(str): The name of file to save all related parameters.
If it is set None, parameters will be saved in separate files. Default: '__params__'.
If it is set None, parameters will be saved in separate files. Default: 'model.pdiparams'.
batch_size(int, optional): The batch size of DataLoader, default is 1.
batch_nums(int, optional): If batch_nums is not None, the number of calibrate
data is 'batch_size*batch_nums'. If batch_nums is None, use all data
......@@ -513,6 +542,22 @@ def quant_post_static(
quantize_model_path,
model_filename=save_model_filename,
params_filename=save_params_filename)
if onnx_format:
try:
collect_dict = post_training_quantization._calibration_scales
save_quant_table_path = os.path.join(quantize_model_path,
'calibration_table.txt')
with open(save_quant_table_path, 'w') as txt_file:
for tensor_name in collect_dict.keys():
write_line = '{} {}'.format(
tensor_name, collect_dict[tensor_name]['scale']) + '\n'
txt_file.write(write_line)
_logger.info("Quantization clip ranges of tensors is save in: {}".
format(save_quant_table_path))
except:
_logger.warning(
    "Unable to generate `calibration_table.txt`, please update PaddlePaddle to 2.3.3 or later"
)
# We have changed the quant_post to quant_post_static.
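For reference, the ``calibration_table.txt`` produced above contains one "tensor_name scale" pair per line. A minimal sketch of reading it back, assuming tensor names contain no spaces; the path is a placeholder:

    scales = {}
    with open('./save_path/calibration_table.txt') as table_file:
        for line in table_file:
            tensor_name, scale = line.split()   # "tensor_name scale" per line
            scales[tensor_name] = float(scale)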
......@@ -521,7 +566,12 @@ def quant_post_static(
quant_post = quant_post_static
def convert(program, place, config=None, scope=None, save_int8=False):
def convert(program,
place,
config=None,
scope=None,
save_int8=False,
save_clip_ranges_path='./'):
"""
convert quantized and well-trained ``program`` to final quantized
``program`` that can be used to save ``inference model``.
......@@ -543,6 +593,7 @@ def convert(program, place, config=None, scope=None, save_int8=False):
save_int8: Whether to return ``program`` which model parameters'
dtype is ``int8``. This parameter can only be used to
get model size. Default: ``False``.
save_clip_ranges_path: If config.onnx_format=True, quantization clip ranges (``calibration_table.txt``) will be saved into this directory. Default: ``'./'``.
Returns:
Tuple : freezed program which can be used for inference.
......@@ -563,8 +614,19 @@ def convert(program, place, config=None, scope=None, save_int8=False):
if config['onnx_format']:
quant_weight_pass = QuantWeightPass(scope, place)
quant_weight_pass.apply(test_graph)
else:
out_scale_infer_pass = OutScaleForInferencePassV2(scope=scope)
_, collect_dict = out_scale_infer_pass.apply(test_graph)
save_quant_table_path = os.path.join(save_clip_ranges_path,
'calibration_table.txt')
with open(save_quant_table_path, 'w') as txt_file:
for tensor_name in collect_dict.keys():
write_line = '{} {}'.format(
tensor_name, collect_dict[tensor_name]['scale']) + '\n'
txt_file.write(write_line)
_logger.info("Quantization clip ranges of tensors is save in: {}".
format(save_quant_table_path))
else:
out_scale_infer_pass = OutScaleForInferencePass(scope=scope)
out_scale_infer_pass.apply(test_graph)
# Freeze the graph after training by adjusting the quantize
# operators' order for the inference.
......
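Not part of the diff: a minimal usage sketch of ``convert`` with the new ``save_clip_ranges_path`` argument. ``val_program``, ``place`` and a quant ``config`` containing ``'onnx_format': True`` are assumed to exist already:

    import paddle
    from paddleslim.quant import convert

    quant_program = convert(
        val_program,                          # quantized, well-trained program
        place,                                # e.g. paddle.CPUPlace() or paddle.CUDAPlace(0)
        config=config,                        # assumed to contain 'onnx_format': True
        scope=paddle.static.global_scope(),
        save_clip_ranges_path='./output')     # calibration_table.txt is written into this directory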
......@@ -132,8 +132,8 @@ class TestQuantAwareCase1(StaticCase):
quant_post_prog, feed_target_names, fetch_targets = paddle.fluid.io.load_inference_model(
dirname='./test_quant_post_inference',
executor=exe,
model_filename='__model__',
params_filename='__params__')
model_filename='model.pdmodel',
params_filename='model.pdiparams')
top1_2, top5_2 = test(quant_post_prog, fetch_targets)
print("before quantization: top1: {}, top5: {}".format(top1_1, top5_1))
print("after quantization: top1: {}, top5: {}".format(top1_2, top5_2))
......