Unverified commit 83869609, authored by Guanghua Yu, committed by GitHub

add new format of quantization (#1029)

Parent 1bea8e18
......@@ -55,6 +55,7 @@ add_arg('l2_decay', float, 3e-5,
add_arg('ls_epsilon', float, 0.0, "Label smooth epsilon.")
add_arg('use_pact', bool, False, "Whether to use PACT method.")
add_arg('ce_test', bool, False, "Whether to CE test.")
add_arg('onnx_format', bool, False, "Whether to export the quantized model in ONNX format.")
add_arg('momentum_rate', float, 0.9, "The value of momentum_rate.")
add_arg('num_epochs', int, 1, "The number of total epochs.")
add_arg('total_images', int, 1281167, "The number of total training images.")
......@@ -359,7 +360,8 @@ def compress(args):
input_spec=[
paddle.static.InputSpec(
shape=[None, 3, 224, 224], dtype='float32')
])
],
onnx_format=args.onnx_format)
def main():
......
......@@ -41,6 +41,7 @@ add_arg('data', str, "imagenet", "Which data to use. 'mn
add_arg('log_period', int, 10, "Log period in batches.")
add_arg('checkpoint_dir', str, "output", "checkpoint save dir")
add_arg('ce_test', bool, False, "Whether to CE test.")
add_arg('onnx_format', bool, False, "Whether to export the quantized model in ONNX format.")
# yapf: enable
model_list = [m for m in dir(models) if "__" not in m]
......@@ -291,7 +292,8 @@ def compress(args):
############################################################################################################
float_program, int8_program = convert(val_program, place, quant_config, \
scope=None, \
save_int8=True)
save_int8=True,
onnx_format=args.onnx_format)
print("eval best_model after convert")
final_acc1 = test(best_epoch, float_program)
############################################################################################################
......
# Static Post-Training Quantization Example
This example shows how to use the post-training quantization API ``paddleslim.quant.quant_post_static`` to quantize a trained image classification model offline. A quantized model is obtained without any retraining, which reduces the model's storage size and GPU memory footprint. All models in this demo are downloaded from the [PaddleClas model zoo](https://github.com/PaddlePaddle/PaddleClas/blob/release/2.3/docs/zh_CN/algorithm_introduction/ImageNet_models.md).
## API Introduction
......@@ -8,6 +8,10 @@
## Post-Training Quantization Workflow for a Classification Model
### Prepare the Environment
PaddlePaddle >= 2.3, or the develop branch
### Prepare the Data
Create a ``data`` folder under the ``demo`` directory and extract the ``ImageNet`` dataset into it. After extraction, the ``data/ILSVRC2012`` folder should contain:
......@@ -17,56 +21,78 @@
- the ``'val_list.txt'`` file
### Prepare the Model to Be Quantized
The post-training quantization API loads static-graph inference models saved with ``paddle.static.save_inference_model`` or `paddle.jit.save`. If your model was saved with a different API, convert it first (a minimal export sketch is shown after the download commands below).
Inference models for image classification can be downloaded from the tables in the [PaddleClas model zoo](https://github.com/PaddlePaddle/PaddleClas/blob/release/2.3/docs/zh_CN/algorithm_introduction/ImageNet_models.md).
- Prepare the MobileNetV1 model:
```
wget -P inference_model https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/MobileNetV1_infer.tar
cd inference_model/
tar -xf MobileNetV1_infer.tar
```
- Prepare the ResNet50 model:
```
wget -P inference_model https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/ResNet50_infer.tar
cd inference_model/
tar -xf ResNet50_infer.tar
```
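If your own model only exists as a dygraph ``paddle.nn.Layer`` rather than one of the downloaded inference models, it can be exported to the static-graph format expected above with `paddle.jit.save`. The following is a minimal sketch and not part of this demo; the use of ``paddle.vision.models.mobilenet_v1`` and of the input name ``inputs`` are illustrative assumptions.
```python
# Minimal export sketch (illustrative, not part of this demo): save a dygraph
# model as a static-graph inference model that quant_post_static can load.
import paddle
from paddle.static import InputSpec

model = paddle.vision.models.mobilenet_v1(pretrained=True)  # stand-in model
model.eval()

# Produces ./inference_model/MyNet/inference.pdmodel and inference.pdiparams,
# matching the directory layout used by the commands in this README.
paddle.jit.save(
    model,
    path='./inference_model/MyNet/inference',
    input_spec=[InputSpec(shape=[None, 3, 224, 224], dtype='float32', name='inputs')])
```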
### Static Post-Training Quantization
Next, run post-training quantization on the downloaded model. The script is [quant_post.py](./quant_post.py), which uses the ``paddleslim.quant.quant_post_static`` API to quantize the model offline. Run:
```
# MobileNetV1
python quant_post.py --model_path ./inference_model/MobileNetV1_infer/ --save_path ./quant_model/MobileNet
# ResNet50
python quant_post.py --model_path ./inference_model/ResNet50_infer/ --save_path ./quant_model/ResNet50
```
- Parameter list:
| Parameter | Description |
| :-------- | :--------: |
| model_path | Path of the model to be quantized |
| save_path | Path where the quantized model is saved |
| model_filename | If the model to be quantized is stored in a single model file, set this to that file's name; if it is stored in multiple files, leave it unset. |
| params_filename | If the parameters of the model to be quantized are stored in a single file, set this to that file's name; if they are stored in multiple files, leave it unset. |
| algo | Algorithm used for activation quantization; defaults to `hist` |
| batch_size | Batch size used for model calibration |
| batch_num | Total number of batches used for model calibration |
| round_type | Rounding method used during quantization; `round` or `adaround`, defaults to `round` |
| onnx_format | Whether to save the quantized model in the ONNX-compatible format; defaults to False |
| is_full_quantize | Whether to fully quantize the model |
| input_name | Name of the model input. For models downloaded from the PaddleClas model zoo, keep the default `inputs`; for a model you exported yourself, set it accordingly, e.g. `--input_name='x'`. The correct input name can be checked with VisualDL or Netron. |
> The quantization algorithm used is ``'hist'``; 32 images from the training set are used to calibrate the quantization parameters.
After running the command above, the quantized model and parameter files can be found under ``${save_path}``.
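The same step can also be scripted directly in Python instead of going through ``quant_post.py``. Below is a minimal sketch of calling ``paddleslim.quant.quant_post_static``; the calibration reader is only a stand-in, and the exact keyword arguments should be checked against the installed PaddleSlim version.
```python
# Minimal sketch of calling quant_post_static directly (mirrors quant_post.py).
# The random sample_generator is a placeholder for a real reader that yields
# preprocessed calibration images.
import numpy as np
import paddle
from paddleslim.quant import quant_post_static

paddle.enable_static()
place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() else paddle.CPUPlace()
exe = paddle.static.Executor(place)

def sample_generator():
    for _ in range(32):  # 32 calibration samples, as in the note above
        yield [np.random.random((3, 224, 224)).astype('float32')]

quant_post_static(
    executor=exe,
    model_dir='./inference_model/MobileNetV1_infer/',
    quantize_model_path='./quant_model/MobileNet/',
    sample_generator=sample_generator,
    model_filename='inference.pdmodel',
    params_filename='inference.pdiparams',
    batch_size=32,
    batch_nums=1,
    algo='hist',
    onnx_format=False)  # set True to export the ONNX-compatible quant format
```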
### Evaluate Accuracy
Use the [eval.py](./eval.py) script to evaluate the models before and after quantization and compare their classification accuracy.
- First, evaluate the accuracy of the model before quantization:
```shell
# MobileNetV1
python eval.py --model_path=./inference_model/MobileNetV1_infer --model_name=inference.pdmodel --params_name=inference.pdiparams
# ResNet50
python eval.py --model_path=./inference_model/ResNet50_infer --model_name=inference.pdmodel --params_name=inference.pdiparams
```
- Evaluate the accuracy of the model after post-training quantization:
```shell
# MobileNetV1
python eval.py --model_path ./quant_model/MobileNet/
# ResNet50
python eval.py --model_path ./quant_model/ResNet50/
```
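As a quick sanity check beyond ``eval.py``, the quantized model can also be loaded and run through the regular ``paddle.inference`` API. The sketch below assumes the default file names (``__model__``/``__params__``) written by the quantization step; adjust the paths if your save settings differ.
```python
# Rough smoke test: load the quantized MobileNetV1 model with paddle.inference
# and run one random batch. Assumes the default '__model__'/'__params__' names.
import numpy as np
import paddle.inference as paddle_infer

config = paddle_infer.Config('./quant_model/MobileNet/__model__',
                             './quant_model/MobileNet/__params__')
predictor = paddle_infer.create_predictor(config)

input_handle = predictor.get_input_handle(predictor.get_input_names()[0])
fake_batch = np.random.random((1, 3, 224, 224)).astype('float32')
input_handle.reshape([1, 3, 224, 224])
input_handle.copy_from_cpu(fake_batch)

predictor.run()
output = predictor.get_output_handle(predictor.get_output_names()[0]).copy_to_cpu()
print(output.shape)  # expected: (1, 1000) class scores
```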
### Benchmark
| Model | FP32 acc-top1 | INT8 acc-top1 | INT8 acc-top1 (adaround) |
| :-------- | :--------: | :--------: | :--------: |
| MobileNetV1 | 0.7092 | 0.7036 | 0.7063 |
| ResNet50 | 0.7633 | 0.7615 | 0.7625 |
......@@ -29,8 +29,8 @@ parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
add_arg('use_gpu', bool, True, "Whether to use GPU or not.")
add_arg('model_path', str, "./pruning/checkpoints/resnet50/2/eval_model/", "Whether to use pretrained model.")
add_arg('model_name', str, None, "model filename for inference model")
add_arg('params_name', str, None, "params filename for inference model")
add_arg('model_name', str, '__model__', "model filename for inference model")
add_arg('params_name', str, '__params__', "params filename for inference model")
add_arg('batch_size', int, 64, "Minibatch size.")
# yapf: enable
......
import os
import sys
import logging
import paddle
import argparse
import functools
import math
import time
import numpy as np
sys.path[0] = os.path.join(
os.path.dirname("__file__"), os.path.pardir, os.path.pardir)
from paddleslim.common import get_logger
import models
from utility import add_arguments, print_arguments
_logger = get_logger(__name__, level=logging.INFO)
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('use_gpu', bool, True, "Whether to use GPU or not.")
add_arg('model', str, "MobileNet", "The target model.")
add_arg('pretrained_model', str, "../pretrained_model/MobileNetV1_pretained", "Whether to use pretrained model.")
add_arg('data', str, "mnist", "Which data to use. 'mnist' or 'imagenet'")
add_arg('test_period', int, 10, "Test period in epoches.")
# yapf: enable
model_list = [m for m in dir(models) if "__" not in m]
def export_model(args):
if args.data == "mnist":
import paddle.dataset.mnist as reader
train_reader = reader.train()
val_reader = reader.test()
class_dim = 10
image_shape = "1,28,28"
elif args.data == "imagenet":
import imagenet_reader as reader
train_reader = reader.train()
val_reader = reader.val()
class_dim = 1000
image_shape = "3,224,224"
else:
raise ValueError("{} is not supported.".format(args.data))
image_shape = [int(m) for m in image_shape.split(",")]
image = paddle.static.data(
name='image', shape=[None] + image_shape, dtype='float32')
assert args.model in model_list, "{} is not in lists: {}".format(args.model,
model_list)
# model definition
model = models.__dict__[args.model]()
out = model.net(input=image, class_dim=class_dim)
val_program = paddle.static.default_main_program().clone(for_test=True)
place = paddle.CUDAPlace(0) if args.use_gpu else paddle.CPUPlace()
exe = paddle.static.Executor(place)
exe.run(paddle.static.default_startup_program())
if args.pretrained_model:
paddle.static.load(val_program, args.pretrained_model, exe)
else:
assert False, "args.pretrained_model must set"
paddle.fluid.io.save_inference_model(
'./inference_model/' + args.model,
feeded_var_names=[image.name],
target_vars=[out],
executor=exe,
main_program=val_program,
model_filename='model',
params_filename='weights')
def main():
args = parser.parse_args()
print_arguments(args)
export_model(args)
if __name__ == '__main__':
paddle.enable_static()
main()
......@@ -24,15 +24,18 @@ add_arg = functools.partial(add_arguments, argparser=parser)
add_arg('batch_size', int, 32, "Minibatch size.")
add_arg('batch_num', int, 1, "Batch number")
add_arg('use_gpu', bool, True, "Whether to use GPU or not.")
add_arg('model_path', str, "./inference_model/MobileNet/", "model dir")
add_arg('model_path', str, "./inference_model/MobileNetV1_infer/", "model dir")
add_arg('save_path', str, "./quant_model/MobileNet/", "model dir to save quanted model")
add_arg('model_filename', str, None, "model file name")
add_arg('params_filename', str, None, "params file name")
add_arg('model_filename', str, 'inference.pdmodel', "model file name")
add_arg('params_filename', str, 'inference.pdiparams', "params file name")
add_arg('algo', str, 'hist', "calibration algorithm")
add_arg('round_type', str, 'round', "The method of converting the quantized weights.")
add_arg('hist_percent', float, 0.9999, "The percentile of algo:hist")
add_arg('is_full_quantize', bool, False, "Whether is full quantization or not.")
add_arg('bias_correction', bool, False, "Whether to use bias correction")
add_arg('ce_test', bool, False, "Whether to CE test.")
add_arg('onnx_format', bool, False, "Whether to export the quantized model in ONNX format.")
add_arg('input_name', str, 'inputs', "The name of model input.")
# yapf: enable
......@@ -51,7 +54,7 @@ def quantize(args):
val_dataset = reader.ImageNetDataset(mode='test')
image_shape = [3, 224, 224]
image = paddle.static.data(
name='image', shape=[None] + image_shape, dtype='float32')
name=args.input_name, shape=[None] + image_shape, dtype='float32')
data_loader = paddle.io.DataLoader(
val_dataset,
places=place,
......@@ -77,7 +80,9 @@ def quantize(args):
algo=args.algo,
round_type=args.round_type,
hist_percent=args.hist_percent,
bias_correction=args.bias_correction)
is_full_quantize=args.is_full_quantize,
bias_correction=args.bias_correction,
onnx_format=args.onnx_format)
def main():
......
......@@ -3,6 +3,14 @@ from paddle.fluid.framework import IrGraph
from paddle.fluid import core
from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass, AddQuantDequantPass, QuantizationFreezePass
try:
from paddle.fluid.contrib.slim.quantization import utils
TRANSFORM_PASS_OP_TYPES = utils._weight_supported_quantizable_op_type
QUANT_DEQUANT_PASS_OP_TYPES = utils._act_supported_quantizable_op_type
except:
TRANSFORM_PASS_OP_TYPES = QuantizationTransformPass._supported_quantizable_op_type
QUANT_DEQUANT_PASS_OP_TYPES = AddQuantDequantPass._supported_quantizable_op_type
def post_quant_fake(executor,
model_dir,
......@@ -29,8 +37,8 @@ def post_quant_fake(executor,
activation_quantize_type = 'range_abs_max'
weight_quantize_type = 'channel_wise_abs_max'
_dynamic_quantize_op_type = ['lstm']
_weight_supported_quantizable_op_type = QuantizationTransformPass._supported_quantizable_op_type
_act_supported_quantizable_op_type = AddQuantDequantPass._supported_quantizable_op_type
_weight_supported_quantizable_op_type = TRANSFORM_PASS_OP_TYPES
_act_supported_quantizable_op_type = QUANT_DEQUANT_PASS_OP_TYPES
_support_quantize_op_type = list(
set(_weight_supported_quantizable_op_type +
_act_supported_quantizable_op_type + _dynamic_quantize_op_type))
......
......@@ -232,7 +232,11 @@ class QAT(object):
return quant_model
def save_quantized_model(self, model, path, input_spec=None):
def save_quantized_model(self,
model,
path,
input_spec=None,
onnx_format=False):
"""
Save the quantized inference model.
......@@ -258,7 +262,10 @@ class QAT(object):
model.eval()
self.imperative_qat.save_quantized_model(
layer=model, path=path, input_spec=input_spec)
layer=model,
path=path,
input_spec=input_spec,
onnx_format=onnx_format)
def _remove_preprocess(self, model):
state_dict = model.state_dict()
......
......@@ -27,6 +27,12 @@ from paddle.fluid.contrib.slim.quantization import PostTrainingQuantization
from paddle.fluid.contrib.slim.quantization import AddQuantDequantPass
from paddle.fluid.contrib.slim.quantization import OutScaleForTrainingPass
from paddle.fluid.contrib.slim.quantization import OutScaleForInferencePass
try:
from paddle.fluid.contrib.slim.quantization import QuantizationTransformPassV2
from paddle.fluid.contrib.slim.quantization import QuantWeightPass
from paddle.fluid.contrib.slim.quantization import AddQuantDequantPassV2
except:
pass
from paddle.fluid import core
from paddle.fluid.contrib.slim.quantization import WeightQuantization
from paddle.fluid.layer_helper import LayerHelper
......@@ -48,8 +54,13 @@ ACTIVATION_QUANTIZATION_TYPES_TENSORRT = [
]
VALID_DTYPES = ['int8']
TRANSFORM_PASS_OP_TYPES = QuantizationTransformPass._supported_quantizable_op_type
QUANT_DEQUANT_PASS_OP_TYPES = AddQuantDequantPass._supported_quantizable_op_type
try:
from paddle.fluid.contrib.slim.quantization import utils
TRANSFORM_PASS_OP_TYPES = utils._weight_supported_quantizable_op_type
QUANT_DEQUANT_PASS_OP_TYPES = utils._act_supported_quantizable_op_type
except:
TRANSFORM_PASS_OP_TYPES = QuantizationTransformPass._supported_quantizable_op_type
QUANT_DEQUANT_PASS_OP_TYPES = AddQuantDequantPass._supported_quantizable_op_type
TENSORRT_OP_TYPES = [
'mul', 'conv2d', 'pool2d', 'depthwise_conv2d', 'elementwise_add',
......@@ -186,6 +197,7 @@ def quant_aware(program,
act_preprocess_func=None,
optimizer_func=None,
executor=None,
onnx_format=False,
return_program=False):
"""Add quantization and dequantization operators to "program"
for quantization training or testing.
......@@ -251,7 +263,8 @@ def quant_aware(program,
elif op_type in QUANT_DEQUANT_PASS_OP_TYPES:
quant_dequant_ops.append(op_type)
if len(transform_pass_ops) > 0:
transform_pass = QuantizationTransformPass(
transform_func = 'QuantizationTransformPassV2' if onnx_format else 'QuantizationTransformPass'
transform_pass = eval(transform_func)(
scope=scope,
place=place,
weight_bits=config['weight_bits'],
......@@ -272,7 +285,8 @@ def quant_aware(program,
transform_pass.apply(main_graph)
if len(quant_dequant_ops) > 0:
quant_dequant_pass = AddQuantDequantPass(
qdq_func = 'AddQuantDequantPassV2' if onnx_format else 'AddQuantDequantPass'
quant_dequant_pass = eval(qdq_func)(
scope=scope,
place=place,
moving_rate=config['moving_rate'],
......@@ -335,6 +349,7 @@ def quant_post_static(
activation_quantize_type='range_abs_max',
weight_quantize_type='channel_wise_abs_max',
optimize_model=False,
onnx_format=False,
is_use_cache_file=False,
cache_dir="./temp_post_training"):
"""
......@@ -433,6 +448,7 @@ def quant_post_static(
activation_bits=activation_bits,
activation_quantize_type=activation_quantize_type,
weight_quantize_type=weight_quantize_type,
onnx_format=onnx_format,
optimize_model=optimize_model)
post_training_quantization.quantize()
post_training_quantization.save_quantized_model(
......@@ -447,7 +463,12 @@ def quant_post_static(
quant_post = quant_post_static
def convert(program, place, config=None, scope=None, save_int8=False):
def convert(program,
place,
config=None,
scope=None,
save_int8=False,
onnx_format=False):
"""
convert quantized and well-trained ``program`` to final quantized
``program`` that can be used to save ``inference model``.
......@@ -486,22 +507,24 @@ def convert(program, place, config=None, scope=None, save_int8=False):
_logger.info("convert config {}".format(config))
test_graph = IrGraph(core.Graph(program.desc), for_test=True)
out_scale_infer_pass = OutScaleForInferencePass(scope=scope)
out_scale_infer_pass.apply(test_graph)
# Freeze the graph after training by adjusting the quantize
# operators' order for the inference.
freeze_pass = QuantizationFreezePass(
scope=scope,
place=place,
weight_bits=config['weight_bits'],
activation_bits=config['activation_bits'],
weight_quantize_type=config['weight_quantize_type'])
if os.path.exists(VARS_MAPPING_TABLE):
test_graph.out_node_mapping_table = load_dict()
if onnx_format:
quant_weight_pass = QuantWeightPass(scope, place)
quant_weight_pass.apply(test_graph)
else:
out_scale_infer_pass = OutScaleForInferencePass(scope=scope)
out_scale_infer_pass.apply(test_graph)
# Freeze the graph after training by adjusting the quantize
# operators' order for the inference.
freeze_pass = QuantizationFreezePass(
scope=scope,
place=place,
weight_bits=config['weight_bits'],
activation_bits=config['activation_bits'],
weight_quantize_type=config['weight_quantize_type'])
if os.path.exists(VARS_MAPPING_TABLE):
test_graph.out_node_mapping_table = load_dict()
freeze_pass.apply(test_graph)
freeze_pass.apply(test_graph)
freezed_program = test_graph.to_program()
if save_int8:
......