diff --git a/demo/quant/quant_post/eval.py b/demo/quant/quant_post/eval.py
index 310eacd08d833bdcae2fc9ede5323f602be0f2c3..e8184e848c3e0785c398ef3dcdf53a1903e3cf86 100755
--- a/demo/quant/quant_post/eval.py
+++ b/demo/quant/quant_post/eval.py
@@ -21,8 +21,7 @@ import functools
 import paddle
 sys.path[0] = os.path.join(
     os.path.dirname("__file__"), os.path.pardir, os.path.pardir)
-sys.path[1] = os.path.join(
-    os.path.dirname("__file__"), os.path.pardir)
+sys.path[1] = os.path.join(os.path.dirname("__file__"), os.path.pardir)
 import imagenet_reader as reader
 from utility import add_arguments, print_arguments
 
@@ -31,8 +30,8 @@ parser = argparse.ArgumentParser(description=__doc__)
 add_arg = functools.partial(add_arguments, argparser=parser)
 add_arg('use_gpu', bool, True, "Whether to use GPU or not.")
 add_arg('model_path', str, "./pruning/checkpoints/resnet50/2/eval_model/", "Whether to use pretrained model.")
-add_arg('model_name', str, '__model__', "model filename for inference model")
-add_arg('params_name', str, '__params__', "params filename for inference model")
+add_arg('model_name', str, 'model.pdmodel', "model filename for inference model")
+add_arg('params_name', str, 'model.pdiparams', "params filename for inference model")
 add_arg('batch_size', int, 64, "Minibatch size.")
 # yapf: enable
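
[Note, not part of the patch] With the defaults above renamed, a quantized inference model is now loaded via the new filenames. A minimal sketch, mirroring the load_inference_model call used in tests/test_quant_post.py further down; the directory './quant_model_path' is a placeholder:

    import paddle

    paddle.enable_static()
    exe = paddle.static.Executor(paddle.CPUPlace())
    # Load the quantized inference model saved as model.pdmodel / model.pdiparams.
    quant_prog, feed_names, fetch_targets = paddle.fluid.io.load_inference_model(
        dirname='./quant_model_path',
        executor=exe,
        model_filename='model.pdmodel',
        params_filename='model.pdiparams')
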
diff --git a/docs/zh_cn/api_cn/static/quant/quantization_api.rst b/docs/zh_cn/api_cn/static/quant/quantization_api.rst
index a12e4e9b574ec672b2611c23f6d8aeb1785248fd..f2d7b77d54e87a7ae897f02fa391b77c2fae5e2f 100644
--- a/docs/zh_cn/api_cn/static/quant/quantization_api.rst
+++ b/docs/zh_cn/api_cn/static/quant/quantization_api.rst
@@ -118,7 +118,7 @@ quant_post_dynamic
 quant_post_static
 ---------------
 
-.. py:function:: paddleslim.quant.quant_post_static(executor,model_dir, quantize_model_path, batch_generator=None, sample_generator=None, model_filename=None, params_filename=None, save_model_filename='__model__', save_params_filename='__params__', batch_size=16, batch_nums=None, scope=None, algo='KL', round_type='round', quantizable_op_type=["conv2d","depthwise_conv2d","mul"], is_full_quantize=False, weight_bits=8, activation_bits=8, activation_quantize_type='range_abs_max', weight_quantize_type='channel_wise_abs_max', onnx_format=False, skip_tensor_list=None, optimize_model=False)
+.. py:function:: paddleslim.quant.quant_post_static(executor,model_dir, quantize_model_path, batch_generator=None, sample_generator=None, model_filename=None, params_filename=None, save_model_filename='model.pdmodel', save_params_filename='model.pdiparams', batch_size=16, batch_nums=None, scope=None, algo='KL', round_type='round', quantizable_op_type=["conv2d","depthwise_conv2d","mul"], is_full_quantize=False, weight_bits=8, activation_bits=8, activation_quantize_type='range_abs_max', weight_quantize_type='channel_wise_abs_max', onnx_format=False, skip_tensor_list=None, optimize_model=False)
 
 `源代码 `_
 
@@ -217,15 +217,15 @@ quant_post_static
             target_vars=[out],
             main_program=val_prog,
             executor=exe,
-            model_filename='__model__',
-            params_filename='__params__')
+            model_filename='model.pdmodel',
+            params_filename='model.pdiparams')
 
         quant_post_static(
             executor=exe,
             model_dir='./model_path',
             quantize_model_path='./save_path',
             sample_generator=val_reader,
-            model_filename='__model__',
-            params_filename='__params__',
+            model_filename='model.pdmodel',
+            params_filename='model.pdiparams',
             batch_size=16,
             batch_nums=10)
diff --git a/paddleslim/analysis/_utils.py b/paddleslim/analysis/_utils.py
index 0b6fd1b855c02e94deb41c4ca0f5b6884996a9e7..82bedee53da83c8ce2fbf47f6cb27ee4cf080efd 100644
--- a/paddleslim/analysis/_utils.py
+++ b/paddleslim/analysis/_utils.py
@@ -135,8 +135,8 @@ def save_cls_model(model, input_shape, save_dir, data_type):
             weight_bits=8,
             activation_bits=8)
 
-        model_file = os.path.join(quantize_model_path, '__model__')
-        param_file = os.path.join(quantize_model_path, '__params__')
+        model_file = os.path.join(quantize_model_path, 'model.pdmodel')
+        param_file = os.path.join(quantize_model_path, 'model.pdiparams')
 
     return model_file, param_file
 
diff --git a/paddleslim/auto_compression/compressor.py b/paddleslim/auto_compression/compressor.py
index 0ccb480f59671393ec54a9ca4b3210e2827bef5f..dc825d08b1a5dc904006df9f50bf7e4b0f27438d 100644
--- a/paddleslim/auto_compression/compressor.py
+++ b/paddleslim/auto_compression/compressor.py
@@ -804,7 +804,8 @@ class AutoCompression:
                 test_program,
                 self._places,
                 self._quant_config,
-                scope=paddle.static.global_scope())
+                scope=paddle.static.global_scope(),
+                save_clip_ranges_path=self.final_dir)
 
             feed_vars = [
                 test_program.global_block().var(name)
diff --git a/paddleslim/quant/quanter.py b/paddleslim/quant/quanter.py
index 3ecff9bf31341d59b8f85ae9b087d0459b2ac8de..9f8ed323e3ab6eb86f9a476d9895d89dddf31eca 100755
--- a/paddleslim/quant/quanter.py
+++ b/paddleslim/quant/quanter.py
@@ -16,9 +16,14 @@ import os
 import copy
 import json
 import logging
+import collections
+import numpy as np
 
 import paddle
+from paddle.fluid import core
+from paddle.fluid.layer_helper import LayerHelper
 from paddle.fluid.framework import IrGraph
+from paddle.fluid.contrib.slim.quantization import WeightQuantization
 from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass
 from paddle.fluid.contrib.slim.quantization import QuantizationFreezePass
 from paddle.fluid.contrib.slim.quantization import ConvertToInt8Pass
@@ -27,18 +32,17 @@ from paddle.fluid.contrib.slim.quantization import PostTrainingQuantization
 from paddle.fluid.contrib.slim.quantization import AddQuantDequantPass
 from paddle.fluid.contrib.slim.quantization import OutScaleForTrainingPass
 from paddle.fluid.contrib.slim.quantization import OutScaleForInferencePass
+from ..common import get_logger
+_logger = get_logger(__name__, level=logging.INFO)
+
 try:
     from paddle.fluid.contrib.slim.quantization import QuantizationTransformPassV2
     from paddle.fluid.contrib.slim.quantization import QuantWeightPass
     from paddle.fluid.contrib.slim.quantization import AddQuantDequantPassV2
 except:
-    pass
-from paddle.fluid import core
-from paddle.fluid.contrib.slim.quantization import WeightQuantization
-from paddle.fluid.layer_helper import LayerHelper
-
-from ..common import get_logger
-_logger = get_logger(__name__, level=logging.INFO)
+    _logger.warning(
+        "Some functions fail to import, please update PaddlePaddle version to 2.3+"
+    )
 
 WEIGHT_QUANTIZATION_TYPES = [
     'abs_max', 'channel_wise_abs_max', 'range_abs_max', 'moving_average_abs_max'
@@ -97,23 +101,48 @@ _quant_config_default = {
 }
 
 
-# TODO: Hard-code, remove it when Paddle 2.3.1
-class OutScaleForTrainingPassV2(OutScaleForTrainingPass):
-    def __init__(self, scope=None, place=None, moving_rate=0.9):
-        OutScaleForTrainingPass.__init__(
-            self, scope=scope, place=place, moving_rate=moving_rate)
-
-    def _scale_name(self, var_name):
+class OutScaleForInferencePassV2(object):
+    def __init__(self, scope=None):
         """
-        Return the scale name for the var named `var_name`.
+        This pass is used for setting output scales of some operators.
+        These output scales may be used by tensorRT or some other inference engines.
+
+        Args:
+            scope(fluid.Scope): The scope is used to initialize these new parameters.
         """
-        return "%s@scale" % (var_name)
+        self._scope = scope
+        self._teller_set = utils._out_scale_op_list
 
+    def apply(self, graph):
+        """
+        Get output scales from the scope and set these scales in op_descs
+        of operators in the teller_set.
 
-# TODO: Hard-code, remove it when Paddle 2.3.1
-class OutScaleForInferencePassV2(OutScaleForInferencePass):
-    def __init__(self, scope=None):
-        OutScaleForInferencePass.__init__(self, scope=scope)
+        Args:
+            graph(IrGraph): the target graph.
+        """
+        assert isinstance(graph,
+                          IrGraph), 'graph must be the instance of IrGraph.'
+        collect_dict = collections.OrderedDict()
+        op_nodes = graph.all_op_nodes()
+        for op_node in op_nodes:
+            if op_node.name() in self._teller_set:
+                var_names = utils._get_op_output_var_names(op_node)
+                for var_name in var_names:
+                    in_node = graph._find_node_by_name(op_node.outputs,
+                                                       var_name)
+                    if in_node.dtype() not in \
+                        [core.VarDesc.VarType.FP64, core.VarDesc.VarType.FP32]:
+                        continue
+
+                    collect_dict[var_name] = {}
+                    scale_name = self._scale_name(var_name)
+                    scale_var = self._scope.find_var(scale_name)
+                    assert scale_var is not None, \
+                        "Can not find {} variable in the scope".format(scale_name)
+                    scale_value = np.array(scale_var.get_tensor())[0]
+                    collect_dict[var_name]['scale'] = float(scale_value)
+        return graph, collect_dict
 
     def _scale_name(self, var_name):
         """
@@ -328,7 +357,7 @@
             quantizable_op_type=quant_dequant_ops)
         quant_dequant_pass.apply(main_graph)
 
-    out_scale_training_pass = OutScaleForTrainingPassV2(
+    out_scale_training_pass = OutScaleForTrainingPass(
         scope=scope, place=place, moving_rate=config['moving_rate'])
     out_scale_training_pass.apply(main_graph)
 
@@ -361,8 +390,8 @@
         data_loader=None,
         model_filename=None,
         params_filename=None,
-        save_model_filename='__model__',
-        save_params_filename='__params__',
+        save_model_filename='model.pdmodel',
+        save_params_filename='model.pdiparams',
         batch_size=1,
         batch_nums=None,
         scope=None,
@@ -410,9 +439,9 @@
                 When all parameters are saved in a single file, set it
                 as filename. If parameters are saved in separate files,
                 set it as 'None'. Default : 'None'.
-        save_model_filename(str): The name of model file to save the quantized inference program. Default: '__model__'.
+        save_model_filename(str): The name of model file to save the quantized inference program. Default: 'model.pdmodel'.
         save_params_filename(str): The name of file to save all related parameters.
-                If it is set None, parameters will be saved in separate files. Default: '__params__'.
+                If it is set None, parameters will be saved in separate files. Default: 'model.pdiparams'.
         batch_size(int, optional): The batch size of DataLoader, default is 1.
         batch_nums(int, optional): If batch_nums is not None, the number of calibrate data is
                 'batch_size*batch_nums'. If batch_nums is None, use all data
@@ -513,6 +542,22 @@
         quantize_model_path,
         model_filename=save_model_filename,
         params_filename=save_params_filename)
+    if onnx_format:
+        try:
+            collect_dict = post_training_quantization._calibration_scales
+            save_quant_table_path = os.path.join(quantize_model_path,
+                                                 'calibration_table.txt')
+            with open(save_quant_table_path, 'w') as txt_file:
+                for tensor_name in collect_dict.keys():
+                    write_line = '{} {}'.format(
+                        tensor_name, collect_dict[tensor_name]['scale']) + '\n'
+                    txt_file.write(write_line)
+            _logger.info("Quantization clip ranges of tensors is save in: {}".
+                         format(save_quant_table_path))
+        except:
+            _logger.warning(
+                "Unable to generate `calibration_table.txt`, please update PaddlePaddle >= 2.3.3"
+            )
 
 
 # We have changed the quant_post to quant_post_static.
@@ -521,7 +566,12 @@
 quant_post = quant_post_static
 
 
-def convert(program, place, config=None, scope=None, save_int8=False):
+def convert(program,
+            place,
+            config=None,
+            scope=None,
+            save_int8=False,
+            save_clip_ranges_path='./'):
     """
     convert quantized and well-trained ``program`` to final quantized
     ``program``that can be used to save ``inference model``.
@@ -543,6 +593,7 @@
         save_int8: Whether to return ``program`` which model parameters'
                    dtype is ``int8``. This parameter can only be used to
                    get model size. Default: ``False``.
+        save_clip_ranges_path: If config.onnx_format=True, quantization clip ranges will be saved locally.
 
     Returns:
         Tuple : freezed program which can be used for inference.
@@ -563,8 +614,19 @@
     if config['onnx_format']:
         quant_weight_pass = QuantWeightPass(scope, place)
         quant_weight_pass.apply(test_graph)
-    else:
         out_scale_infer_pass = OutScaleForInferencePassV2(scope=scope)
+        _, collect_dict = out_scale_infer_pass.apply(test_graph)
+        save_quant_table_path = os.path.join(save_clip_ranges_path,
+                                             'calibration_table.txt')
+        with open(save_quant_table_path, 'w') as txt_file:
+            for tensor_name in collect_dict.keys():
+                write_line = '{} {}'.format(
+                    tensor_name, collect_dict[tensor_name]['scale']) + '\n'
+                txt_file.write(write_line)
+        _logger.info("Quantization clip ranges of tensors is save in: {}".
+                     format(save_quant_table_path))
+    else:
+        out_scale_infer_pass = OutScaleForInferencePass(scope=scope)
         out_scale_infer_pass.apply(test_graph)
     # Freeze the graph after training by adjusting the quantize
     # operators' order for the inference.
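
[Note, not part of the patch] For the quant-aware path, the new behaviour is that convert() with config['onnx_format'] enabled also dumps the collected output scales to <save_clip_ranges_path>/calibration_table.txt, one "tensor_name scale" pair per line. A minimal sketch, assuming `quant_program`, `place` and the quant `config` dict come from an earlier quant_aware() run:

    import paddle
    from paddleslim.quant import convert

    # `config` is the same dict used for quant_aware(), with 'onnx_format': True.
    final_program = convert(
        quant_program,                         # program returned by quant_aware() after training
        place,
        config=config,
        scope=paddle.static.global_scope(),
        save_clip_ranges_path='./output_dir')  # writes ./output_dir/calibration_table.txt
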
diff --git a/tests/test_quant_post.py b/tests/test_quant_post.py
index 39c3a2fe08531c00cc1d42b74269c12f1ea11962..31eed36e2786f6566d26853e7b6f577fa8b2c417 100644
--- a/tests/test_quant_post.py
+++ b/tests/test_quant_post.py
@@ -132,8 +132,8 @@ class TestQuantAwareCase1(StaticCase):
         quant_post_prog, feed_target_names, fetch_targets = paddle.fluid.io.load_inference_model(
             dirname='./test_quant_post_inference',
             executor=exe,
-            model_filename='__model__',
-            params_filename='__params__')
+            model_filename='model.pdmodel',
+            params_filename='model.pdiparams')
         top1_2, top5_2 = test(quant_post_prog, fetch_targets)
         print("before quantization: top1: {}, top5: {}".format(top1_1, top5_1))
         print("after quantization: top1: {}, top5: {}".format(top1_2, top5_2))
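
[Note, not part of the patch] For the post-training path, quant_post_static(..., onnx_format=True) now writes the same kind of table next to the quantized model, under <quantize_model_path>/calibration_table.txt. A minimal sketch, reusing the names from the docs example above (exe, val_reader, './model_path'); the parsing loop simply mirrors the "tensor_name scale" line format written by the patch:

    from paddleslim.quant import quant_post_static

    quant_post_static(
        executor=exe,
        model_dir='./model_path',
        quantize_model_path='./save_path',
        sample_generator=val_reader,
        model_filename='model.pdmodel',
        params_filename='model.pdiparams',
        onnx_format=True,   # triggers the calibration_table.txt dump added above
        batch_size=16,
        batch_nums=10)

    # Read the clip ranges back: one "tensor_name scale" pair per line.
    scales = {}
    with open('./save_path/calibration_table.txt') as f:
        for line in f:
            name, scale = line.split()
            scales[name] = float(scale)
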