diff --git a/demo/dygraph/quant/train.py b/demo/dygraph/quant/train.py
index 6b31d428b95faf29e7e0ed1d6ddc8e9dac6077da..202b3eb40f44e9063fce45868e2cf318bdc73a3e 100644
--- a/demo/dygraph/quant/train.py
+++ b/demo/dygraph/quant/train.py
@@ -55,6 +55,7 @@ add_arg('l2_decay', float, 3e-5,
 add_arg('ls_epsilon', float, 0.0, "Label smooth epsilon.")
 add_arg('use_pact', bool, False, "Whether to use PACT method.")
 add_arg('ce_test', bool, False, "Whether to CE test.")
+add_arg('onnx_format', bool, False, "Whether to export the quantized model in ONNX format.")
 add_arg('momentum_rate', float, 0.9, "The value of momentum_rate.")
 add_arg('num_epochs', int, 1, "The number of total epochs.")
 add_arg('total_images', int, 1281167, "The number of total training images.")
@@ -359,7 +360,8 @@ def compress(args):
             input_spec=[
                 paddle.static.InputSpec(
                     shape=[None, 3, 224, 224], dtype='float32')
-            ])
+            ],
+            onnx_format=args.onnx_format)
 
 
 def main():
diff --git a/demo/quant/quant_aware/train.py b/demo/quant/quant_aware/train.py
index 455d607cecaef2a5269733a2cae05d991b40f939..c42700adf24d62246eb48969c611795cfdf655c6 100644
--- a/demo/quant/quant_aware/train.py
+++ b/demo/quant/quant_aware/train.py
@@ -41,6 +41,7 @@ add_arg('data', str, "imagenet", "Which data to use. 'mn
 add_arg('log_period', int, 10, "Log period in batches.")
 add_arg('checkpoint_dir', str, "output", "checkpoint save dir")
 add_arg('ce_test', bool, False, "Whether to CE test.")
+add_arg('onnx_format', bool, False, "Whether to export the quantized model in ONNX format.")
 # yapf: enable
 
 model_list = [m for m in dir(models) if "__" not in m]
@@ -291,7 +292,8 @@ def compress(args):
     ############################################################################################################
     float_program, int8_program = convert(val_program, place, quant_config, \
                                         scope=None, \
-                                        save_int8=True)
+                                        save_int8=True,
+                                        onnx_format=args.onnx_format)
     print("eval best_model after convert")
     final_acc1 = test(best_epoch, float_program)
     ############################################################################################################
diff --git a/demo/quant/quant_post/README.md b/demo/quant/quant_post/README.md
index 0f442bd46a358c0c51b8b572ef0bd673d34d6a87..2220cebc8452c55cb75ac5f880faa37037a0613f 100755
--- a/demo/quant/quant_post/README.md
+++ b/demo/quant/quant_post/README.md
@@ -1,6 +1,6 @@
 # 静态离线量化示例
 
-本示例将介绍如何使用离线量化接口``paddleslim.quant.quant_post_static``来对训练好的分类模型进行离线量化, 无需对模型进行训练即可得到量化模型,减少模型的存储空间和显存占用。
+本示例将介绍如何使用离线量化接口``paddleslim.quant.quant_post_static``来对训练好的分类模型进行离线量化, 无需对模型进行训练即可得到量化模型,减少模型的存储空间和显存占用。 本demo中模型均从[PaddleClas模型库](https://github.com/PaddlePaddle/PaddleClas/blob/release/2.3/docs/zh_CN/algorithm_introduction/ImageNet_models.md) 中下载。
 
 ## 接口介绍
 
@@ -8,6 +8,10 @@
 
 ## 分类模型的离线量化流程
 
+### 环境准备
+
+PaddlePaddle >= 2.3 或develop版本
+
 ### 准备数据
 
 在``demo``文件夹下创建``data``文件夹,将``ImageNet``数据集解压在``data``文件夹下,解压后``data/ILSVRC2012``文件夹下应包含以下文件:
@@ -17,56 +21,78 @@
 - ``'val_list.txt'``文件
 
 ### 准备需要量化的模型
-离线量化接口只支持加载通过``paddle.static.save_inference_model``接口保存的模型。因此如果您的模型是通过其他接口保存的,需要先将模型进行转化。本示例将以分类模型为例进行说明。
+离线量化接口支持加载通过``paddle.static.save_inference_model``接口或者`paddle.jit.save`保存的静态图Inference模型。因此如果您的模型是通过其他接口保存的,需要先将模型进行转化。
 
-首先在[imagenet分类模型](https://github.com/PaddlePaddle/models/tree/develop/PaddleCV/image_classification#%E5%B7%B2%E5%8F%91%E5%B8%83%E6%A8%A1%E5%9E%8B%E5%8F%8A%E5%85%B6%E6%80%A7%E8%83%BD)中下载训练好的``mobilenetv1``模型。
+图像分类的Inference模型均可从[PaddleClas模型库](https://github.com/PaddlePaddle/PaddleClas/blob/release/2.3/docs/zh_CN/algorithm_introduction/ImageNet_models.md)的表格中下载得到。
 
-在当前文件夹下创建``'pretrain'``文件夹,将``mobilenetv1``模型在该文件夹下解压,解压后的目录为``pretrain/MobileNetV1_pretrained``
+- MobileNetV1模型准备:
+```
+wget -P inference_model https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/MobileNetV1_infer.tar
+cd inference_model/
+tar -xf MobileNetV1_infer.tar
+```
 
-### 导出模型
-通过运行以下命令可将模型转化为离线量化接口可用的模型:
+- ResNet50模型准备:
 ```
-python export_model.py --model "MobileNet" --pretrained_model ./pretrain/MobileNetV1_pretrained --data imagenet
+wget -P inference_model https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/ResNet50_infer.tar
+cd inference_model/
+tar -xf ResNet50_infer.tar
 ```
-转化之后的模型存储在``inference_model/MobileNet/``文件夹下,可看到该文件夹下有``'model'``, ``'weights'``两个文件。
 
 ### 静态离线量化
 接下来对导出的模型文件进行静态离线量化,静态离线量化的脚本为[quant_post.py](./quant_post.py),脚本中使用接口``paddleslim.quant.quant_post_static``对模型进行离线量化。运行命令为:
+
 ```
-python quant_post.py --model_path ./inference_model/MobileNet --save_path ./quant_model_train/MobileNet --model_filename model --params_filename weights
+# MobileNetV1
+python quant_post.py --model_path ./inference_model/MobileNetV1_infer/ --save_path ./quant_model/MobileNet
+# ResNet50
+python quant_post.py --model_path ./inference_model/ResNet50_infer/ --save_path ./quant_model/ResNet50
 ```
-- ``model_path``: 需要量化的模型所在路径
-- ``save_path``: 量化后的模型保存的路径
-- ``model_filename``: 如果需要量化的模型的参数文件保存在一个文件中,则设置为该模型的模型文件名称,如果参数文件保存在多个文件中,则不需要设置。
-- ``params_filename``: 如果需要量化的模型的参数文件保存在一个文件中,则设置为该模型的参数文件名称,如果参数文件保存在多个文件中,则不需要设置。
+- 参数列表:
 
-运行以上命令后,可在``${save_path}``下看到量化后的模型文件和参数文件。
+| 参数名 | 解释 |
+| :-------- | :--------: |
+| model_path | 需要量化的模型所在路径 |
+| save_path | 量化后的模型保存的路径 |
+| model_filename | 如果需要量化的模型的参数文件保存在一个文件中,则设置为该模型的模型文件名称,如果参数文件保存在多个文件中,则不需要设置。 |
+| params_filename | 如果需要量化的模型的参数文件保存在一个文件中,则设置为该模型的参数文件名称,如果参数文件保存在多个文件中,则不需要设置。 |
+| algo | 激活量化使用的算法,默认是`hist` |
+| batch_size | 模型校准使用的batch size大小 |
+| batch_num | 模型校准时的总batch数量 |
+| round_type | 模型量化时四舍五入的方法,可选择`round`和`adaround`,默认是`round` |
+| onnx_format | 保存量化模型时的格式是否是ONNX通配格式,默认False |
+| is_full_quantize | 是否对模型进行全量化 |
+| input_name | 量化时模型输入的name,如果使用PaddleClas模型库中下载好的模型,保持默认为inputs,如果是自己导出模型,应设置:`--input_name='x'`,可用VisualDL或Netron查看模型输入正确name |
 
-> 使用的量化算法为``'hist'``, 使用训练集中的32张图片进行量化参数的校正。
+运行以上命令后,可在``${save_path}``下看到量化后的模型文件和参数文件。
 
 ### 测试精度
 
 使用[eval.py](./eval.py)脚本对量化前后的模型进行测试,得到模型的分类精度进行对比。
 
-首先测试量化前的模型的精度,运行以下命令:
-```
-python eval.py --model_path ./inference_model/MobileNet --model_name model --params_name weights
-```
-精度输出为:
-```
-top1_acc/top5_acc= [0.70913923 0.89548034]
+- 首先测试量化前的模型的精度,运行以下命令:
+```shell
+# MobileNetV1
+python eval.py --model_path=./inference_model/MobileNetV1_infer --model_name=inference.pdmodel --params_name=inference.pdiparams
+# ResNet50
+python eval.py --model_path=./inference_model/ResNet50_infer --model_name=inference.pdmodel --params_name=inference.pdiparams
 ```
 
-使用以下命令测试离线量化后的模型的精度:
+- 测试离线量化后的模型的精度:
 
-```
-python eval.py --model_path ./quant_model_train/MobileNet --model_name __model__ --params_name __params__
+```shell
+# MobileNetV1
+python eval.py --model_path ./quant_model/MobileNet/
+# ResNet50
+python eval.py --model_path ./quant_model/ResNet50/
 ```
 
-精度输出为
-```
-top1_acc/top5_acc= [0.70328485 0.89183184]
-```
-从以上精度对比可以看出,对``mobilenet``在``imagenet``上的分类模型进行离线量化后 ``top1``精度损失为``0.59%``, ``top5``精度损失为``0.36%``.
+
+### benchmark
+
+| 模型 | FP32 acc-top1 | INT8 acc-top1 | INT8 acc(adaround) |
+| :-------- | :--------: | :--------: | :--------: |
+| MobileNetV1 | 0.7092 | 0.7036 | 0.7063 |
+| ResNet50 | 0.7633 | 0.7615 | 0.7625 |
diff --git a/demo/quant/quant_post/eval.py b/demo/quant/quant_post/eval.py
index 02020424cce4b12426451ddef681610ad617b472..3547c4fa55fef8882909872e3c0d6a8c015e5475 100755
--- a/demo/quant/quant_post/eval.py
+++ b/demo/quant/quant_post/eval.py
@@ -29,8 +29,8 @@ parser = argparse.ArgumentParser(description=__doc__)
 add_arg = functools.partial(add_arguments, argparser=parser)
 add_arg('use_gpu', bool, True, "Whether to use GPU or not.")
 add_arg('model_path', str, "./pruning/checkpoints/resnet50/2/eval_model/", "Whether to use pretrained model.")
-add_arg('model_name', str, None, "model filename for inference model")
-add_arg('params_name', str, None, "params filename for inference model")
+add_arg('model_name', str, '__model__', "model filename for inference model")
+add_arg('params_name', str, '__params__', "params filename for inference model")
 add_arg('batch_size', int, 64, "Minibatch size.")
 # yapf: enable
 
diff --git a/demo/quant/quant_post/export_model.py b/demo/quant/quant_post/export_model.py
deleted file mode 100755
index e8b16db54cb3b93b8ef2b9b2f67524e2b4407847..0000000000000000000000000000000000000000
--- a/demo/quant/quant_post/export_model.py
+++ /dev/null
@@ -1,83 +0,0 @@
-import os
-import sys
-import logging
-import paddle
-import argparse
-import functools
-import math
-import time
-import numpy as np
-sys.path[0] = os.path.join(
-    os.path.dirname("__file__"), os.path.pardir, os.path.pardir)
-from paddleslim.common import get_logger
-import models
-from utility import add_arguments, print_arguments
-
-_logger = get_logger(__name__, level=logging.INFO)
-
-parser = argparse.ArgumentParser(description=__doc__)
-add_arg = functools.partial(add_arguments, argparser=parser)
-# yapf: disable
-add_arg('use_gpu', bool, True, "Whether to use GPU or not.")
-add_arg('model', str, "MobileNet", "The target model.")
-add_arg('pretrained_model', str, "../pretrained_model/MobileNetV1_pretained", "Whether to use pretrained model.")
-add_arg('data', str, "mnist", "Which data to use. 'mnist' or 'imagenet'")
-add_arg('test_period', int, 10, "Test period in epoches.")
-# yapf: enable
-
-model_list = [m for m in dir(models) if "__" not in m]
-
-
-def export_model(args):
-    if args.data == "mnist":
-        import paddle.dataset.mnist as reader
-        train_reader = reader.train()
-        val_reader = reader.test()
-        class_dim = 10
-        image_shape = "1,28,28"
-    elif args.data == "imagenet":
-        import imagenet_reader as reader
-        train_reader = reader.train()
-        val_reader = reader.val()
-        class_dim = 1000
-        image_shape = "3,224,224"
-    else:
-        raise ValueError("{} is not supported.".format(args.data))
-
-    image_shape = [int(m) for m in image_shape.split(",")]
-    image = paddle.static.data(
-        name='image', shape=[None] + image_shape, dtype='float32')
-    assert args.model in model_list, "{} is not in lists: {}".format(args.model,
-                                                                     model_list)
-    # model definition
-    model = models.__dict__[args.model]()
-    out = model.net(input=image, class_dim=class_dim)
-    val_program = paddle.static.default_main_program().clone(for_test=True)
-    place = paddle.CUDAPlace(0) if args.use_gpu else paddle.CPUPlace()
-    exe = paddle.static.Executor(place)
-    exe.run(paddle.static.default_startup_program())
-
-    if args.pretrained_model:
-        paddle.static.load(val_program, args.pretrained_model, exe)
-    else:
-        assert False, "args.pretrained_model must set"
-
-    paddle.fluid.io.save_inference_model(
-        './inference_model/' + args.model,
-        feeded_var_names=[image.name],
-        target_vars=[out],
-        executor=exe,
-        main_program=val_program,
-        model_filename='model',
-        params_filename='weights')
-
-
-def main():
-    args = parser.parse_args()
-    print_arguments(args)
-    export_model(args)
-
-
-if __name__ == '__main__':
-    paddle.enable_static()
-    main()
diff --git a/demo/quant/quant_post/quant_post.py b/demo/quant/quant_post/quant_post.py
index c7a682dfbbde2b9956267541938ccd6b8ea643bc..2e3f7b40f9a8afda86cf8bbd9c001816a8e2ea04 100755
--- a/demo/quant/quant_post/quant_post.py
+++ b/demo/quant/quant_post/quant_post.py
@@ -24,15 +24,18 @@ add_arg = functools.partial(add_arguments, argparser=parser)
 add_arg('batch_size', int, 32, "Minibatch size.")
 add_arg('batch_num', int, 1, "Batch number")
 add_arg('use_gpu', bool, True, "Whether to use GPU or not.")
-add_arg('model_path', str, "./inference_model/MobileNet/", "model dir")
+add_arg('model_path', str, "./inference_model/MobileNetV1_infer/", "model dir")
 add_arg('save_path', str, "./quant_model/MobileNet/", "model dir to save quanted model")
-add_arg('model_filename', str, None, "model file name")
-add_arg('params_filename', str, None, "params file name")
+add_arg('model_filename', str, 'inference.pdmodel', "model file name")
+add_arg('params_filename', str, 'inference.pdiparams', "params file name")
 add_arg('algo', str, 'hist', "calibration algorithm")
 add_arg('round_type', str, 'round', "The method of converting the quantized weights.")
 add_arg('hist_percent', float, 0.9999, "The percentile of algo:hist")
+add_arg('is_full_quantize', bool, False, "Whether to use full quantization or not.")
 add_arg('bias_correction', bool, False, "Whether to use bias correction")
 add_arg('ce_test', bool, False, "Whether to CE test.")
+add_arg('onnx_format', bool, False, "Whether to export the quantized model in ONNX format.")
+add_arg('input_name', str, 'inputs', "The name of model input.")
 # yapf: enable
 
 
@@ -51,7 +54,7 @@ def quantize(args):
     val_dataset = reader.ImageNetDataset(mode='test')
     image_shape = [3, 224, 224]
     image = paddle.static.data(
-        name='image', shape=[None] + image_shape, dtype='float32')
+        name=args.input_name, shape=[None] + image_shape, dtype='float32')
     data_loader = paddle.io.DataLoader(
         val_dataset,
         places=place,
@@ -77,7 +80,9 @@ def quantize(args):
         algo=args.algo,
         round_type=args.round_type,
         hist_percent=args.hist_percent,
-        bias_correction=args.bias_correction)
+        is_full_quantize=args.is_full_quantize,
+        bias_correction=args.bias_correction,
+        onnx_format=args.onnx_format)
 
 
 def main():
diff --git a/paddleslim/auto_compression/utils/fake_ptq.py b/paddleslim/auto_compression/utils/fake_ptq.py
index 80beeac46f727a03aad6c7c35221d9da662eafc8..83d3600689b9a808564cdc85dc30abd7098d4819 100644
--- a/paddleslim/auto_compression/utils/fake_ptq.py
+++ b/paddleslim/auto_compression/utils/fake_ptq.py
@@ -3,6 +3,14 @@ from paddle.fluid.framework import IrGraph
 from paddle.fluid import core
 from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass, AddQuantDequantPass, QuantizationFreezePass
 
+try:
+    from paddle.fluid.contrib.slim.quantization import utils
+    TRANSFORM_PASS_OP_TYPES = utils._weight_supported_quantizable_op_type
+    QUANT_DEQUANT_PASS_OP_TYPES = utils._act_supported_quantizable_op_type
+except:
+    TRANSFORM_PASS_OP_TYPES = QuantizationTransformPass._supported_quantizable_op_type
+    QUANT_DEQUANT_PASS_OP_TYPES = AddQuantDequantPass._supported_quantizable_op_type
+
 
 def post_quant_fake(executor,
                     model_dir,
@@ -29,8 +37,8 @@ def post_quant_fake(executor,
         activation_quantize_type = 'range_abs_max'
     weight_quantize_type = 'channel_wise_abs_max'
     _dynamic_quantize_op_type = ['lstm']
-    _weight_supported_quantizable_op_type = QuantizationTransformPass._supported_quantizable_op_type
-    _act_supported_quantizable_op_type = AddQuantDequantPass._supported_quantizable_op_type
+    _weight_supported_quantizable_op_type = TRANSFORM_PASS_OP_TYPES
+    _act_supported_quantizable_op_type = QUANT_DEQUANT_PASS_OP_TYPES
     _support_quantize_op_type = list(
         set(_weight_supported_quantizable_op_type +
             _act_supported_quantizable_op_type + _dynamic_quantize_op_type))
diff --git a/paddleslim/dygraph/quant/qat.py b/paddleslim/dygraph/quant/qat.py
index a7a85164116a45a54e5b87063fcbbf434c12996d..66ff295f6922f70477e4ab8f6f10e55e12644b88 100644
--- a/paddleslim/dygraph/quant/qat.py
+++ b/paddleslim/dygraph/quant/qat.py
@@ -232,7 +232,11 @@ class QAT(object):
 
         return quant_model
 
-    def save_quantized_model(self, model, path, input_spec=None):
+    def save_quantized_model(self,
+                             model,
+                             path,
+                             input_spec=None,
+                             onnx_format=False):
         """
         Save the quantized inference model.
@@ -258,7 +262,10 @@ class QAT(object):
         model.eval()
 
         self.imperative_qat.save_quantized_model(
-            layer=model, path=path, input_spec=input_spec)
+            layer=model,
+            path=path,
+            input_spec=input_spec,
+            onnx_format=onnx_format)
 
     def _remove_preprocess(self, model):
         state_dict = model.state_dict()
diff --git a/paddleslim/quant/quanter.py b/paddleslim/quant/quanter.py
index b77efa15bb18a22aed1dbd497bc37770aacc4690..a982529b91ebd9564e1e42ee20cb2e4469a21a8f 100755
--- a/paddleslim/quant/quanter.py
+++ b/paddleslim/quant/quanter.py
@@ -27,6 +27,12 @@ from paddle.fluid.contrib.slim.quantization import PostTrainingQuantization
 from paddle.fluid.contrib.slim.quantization import AddQuantDequantPass
 from paddle.fluid.contrib.slim.quantization import OutScaleForTrainingPass
 from paddle.fluid.contrib.slim.quantization import OutScaleForInferencePass
+try:
+    from paddle.fluid.contrib.slim.quantization import QuantizationTransformPassV2
+    from paddle.fluid.contrib.slim.quantization import QuantWeightPass
+    from paddle.fluid.contrib.slim.quantization import AddQuantDequantPassV2
+except:
+    pass
 from paddle.fluid import core
 from paddle.fluid.contrib.slim.quantization import WeightQuantization
 from paddle.fluid.layer_helper import LayerHelper
@@ -48,8 +54,13 @@ ACTIVATION_QUANTIZATION_TYPES_TENSORRT = [
 ]
 VALID_DTYPES = ['int8']
-TRANSFORM_PASS_OP_TYPES = QuantizationTransformPass._supported_quantizable_op_type
-QUANT_DEQUANT_PASS_OP_TYPES = AddQuantDequantPass._supported_quantizable_op_type
+try:
+    from paddle.fluid.contrib.slim.quantization import utils
+    TRANSFORM_PASS_OP_TYPES = utils._weight_supported_quantizable_op_type
+    QUANT_DEQUANT_PASS_OP_TYPES = utils._act_supported_quantizable_op_type
+except:
+    TRANSFORM_PASS_OP_TYPES = QuantizationTransformPass._supported_quantizable_op_type
+    QUANT_DEQUANT_PASS_OP_TYPES = AddQuantDequantPass._supported_quantizable_op_type
 
 TENSORRT_OP_TYPES = [
     'mul', 'conv2d', 'pool2d', 'depthwise_conv2d', 'elementwise_add',
@@ -186,6 +197,7 @@ def quant_aware(program,
                 act_preprocess_func=None,
                 optimizer_func=None,
                 executor=None,
+                onnx_format=False,
                 return_program=False):
     """Add quantization and dequantization operators to "program"
     for quantization training or testing.
@@ -251,7 +263,8 @@
         elif op_type in QUANT_DEQUANT_PASS_OP_TYPES:
             quant_dequant_ops.append(op_type)
     if len(transform_pass_ops) > 0:
-        transform_pass = QuantizationTransformPass(
+        transform_func = 'QuantizationTransformPassV2' if onnx_format else 'QuantizationTransformPass'
+        transform_pass = eval(transform_func)(
             scope=scope,
             place=place,
             weight_bits=config['weight_bits'],
@@ -272,7 +285,8 @@
         transform_pass.apply(main_graph)
 
     if len(quant_dequant_ops) > 0:
-        quant_dequant_pass = AddQuantDequantPass(
+        qdq_func = 'AddQuantDequantPassV2' if onnx_format else 'AddQuantDequantPass'
+        quant_dequant_pass = eval(qdq_func)(
             scope=scope,
             place=place,
             moving_rate=config['moving_rate'],
@@ -335,6 +349,7 @@ def quant_post_static(
         activation_quantize_type='range_abs_max',
         weight_quantize_type='channel_wise_abs_max',
         optimize_model=False,
+        onnx_format=False,
         is_use_cache_file=False,
         cache_dir="./temp_post_training"):
     """
@@ -433,6 +448,7 @@ def quant_post_static(
         activation_bits=activation_bits,
         activation_quantize_type=activation_quantize_type,
         weight_quantize_type=weight_quantize_type,
+        onnx_format=onnx_format,
         optimize_model=optimize_model)
     post_training_quantization.quantize()
     post_training_quantization.save_quantized_model(
@@ -447,7 +463,12 @@
 quant_post = quant_post_static
 
 
-def convert(program, place, config=None, scope=None, save_int8=False):
+def convert(program,
+            place,
+            config=None,
+            scope=None,
+            save_int8=False,
+            onnx_format=False):
     """
     convert quantized and well-trained ``program`` to final quantized
     ``program``that can be used to save ``inference model``.
@@ -486,22 +507,24 @@ def convert(program, place, config=None, scope=None, save_int8=False):
     _logger.info("convert config {}".format(config))
     test_graph = IrGraph(core.Graph(program.desc), for_test=True)
 
-    out_scale_infer_pass = OutScaleForInferencePass(scope=scope)
-    out_scale_infer_pass.apply(test_graph)
-
-    # Freeze the graph after training by adjusting the quantize
-    # operators' order for the inference.
-    freeze_pass = QuantizationFreezePass(
-        scope=scope,
-        place=place,
-        weight_bits=config['weight_bits'],
-        activation_bits=config['activation_bits'],
-        weight_quantize_type=config['weight_quantize_type'])
-
-    if os.path.exists(VARS_MAPPING_TABLE):
-        test_graph.out_node_mapping_table = load_dict()
+    if onnx_format:
+        quant_weight_pass = QuantWeightPass(scope, place)
+        quant_weight_pass.apply(test_graph)
+    else:
+        out_scale_infer_pass = OutScaleForInferencePass(scope=scope)
+        out_scale_infer_pass.apply(test_graph)
+        # Freeze the graph after training by adjusting the quantize
+        # operators' order for the inference.
+        freeze_pass = QuantizationFreezePass(
+            scope=scope,
+            place=place,
+            weight_bits=config['weight_bits'],
+            activation_bits=config['activation_bits'],
+            weight_quantize_type=config['weight_quantize_type'])
+        if os.path.exists(VARS_MAPPING_TABLE):
+            test_graph.out_node_mapping_table = load_dict()
+        freeze_pass.apply(test_graph)
 
-    freeze_pass.apply(test_graph)
     freezed_program = test_graph.to_program()
 
     if save_int8: