Unverified commit 846f6d5c, authored by Guanghua Yu, committed by GitHub

add post quant (#4255)

Parent f0bb99d8
......@@ -4,10 +4,11 @@
- [Pruning](prune)
- [Quantization](quant)
- [Post-Training Quantization](post_quant)
- [Distillation](distill)
- [Combined Strategies](extensions)
We recommend compressing detection models by combining pruning with distillation training, or by combining pruning with quantization. The following uses YOLOv3 as an example for pruning, distillation, and quantization experiments.
We recommend compressing detection models by combining pruning with distillation training, or by using pruning, quantization-aware training, and post-training (offline) quantization. The following uses YOLOv3 as an example for pruning, distillation, and quantization experiments.
## Experimental Environment
......@@ -20,7 +21,8 @@
**Version relationship between PaddleDetection, PaddlePaddle, and PaddleSlim:**
| PaddleDetection version | PaddlePaddle version | PaddleSlim version | Notes |
| :------------------: | :---------------: | :-------: |:---------------: |
| release/2.1 | >= 2.1.0 | 2.1 | Exporting quantized models depends on the latest Paddle develop branch, which can be downloaded and installed from the [PaddlePaddle nightly builds](https://www.paddlepaddle.org.cn/documentation/docs/zh/install/Tables.html#whl-dev) |
| release/2.3 | >= 2.1 | 2.1 | Post-training (offline) quantization depends on Paddle 2.2 and PaddleSlim 2.2 |
| release/2.1, 2.2 | >= 2.1.0 | 2.1 | Exporting quantized models depends on the latest Paddle develop branch, which can be downloaded and installed from the [PaddlePaddle nightly builds](https://www.paddlepaddle.org.cn/documentation/docs/zh/install/Tables.html#whl-dev) |
| release/2.0 | >= 2.0.1 | 2.0 | Quantization depends on Paddle 2.1 and PaddleSlim 2.1 |
......@@ -145,6 +147,16 @@ python tools/export_model.py -c configs/{MODEL.yml} --slim_config configs/slim/{
- The V100 inference latencies above are measured with TensorRT-FP32 for the non-quantized models and TensorRT-INT8 for the quantized models; both include NMS time.
- The SD855 inference latency is measured with a PaddleLite deployment on the armv8 architecture using 4 threads.
### Post-Training (Offline) Quantization
A validation set needs to be prepared to calibrate the post-training quantized model. Run:
```shell
python tools/post_quant.py -c configs/{MODEL.yml} --slim_config configs/slim/post_quant/{SLIM_CONFIG.yml}
```
For example:
```shell
python3.7 tools/post_quant.py -c configs/ppyolo/ppyolo_mbv3_large_coco.yml --slim_config=configs/slim/post_quant/ppyolo_mbv3_large_ptq.yml
```
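
The command above is a thin wrapper around `Trainer.post_quant`; the sketch below condenses the flow implemented by `tools/post_quant.py` (added in this commit), using the config paths from the example command. It omits the DeepSORT weight-loading branch and the sync_bn workaround handled by the script.

```python
# Condensed sketch of the flow implemented by tools/post_quant.py (added in this commit).
from ppdet.core.workspace import load_config
from ppdet.engine import Trainer
from ppdet.slim import build_slim_model

cfg = load_config('configs/ppyolo/ppyolo_mbv3_large_coco.yml')
# Wrap the detector with the PTQ strategy described by the slim config.
cfg = build_slim_model(
    cfg, 'configs/slim/post_quant/ppyolo_mbv3_large_ptq.yml', mode='test')

trainer = Trainer(cfg, mode='eval')      # the eval loader supplies calibration data
trainer.load_weights(cfg.weights)        # load the pretrained FP32 weights
trainer.post_quant('output_inference')   # calibrate, quantize, and export the model
```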
### Distillation
#### Benchmark on COCO
......
weights: https://paddledet.bj.bcebos.com/models/ppyolo_mbv3_large_coco.pdparams
slim: PTQ
PTQ:
ptq_config: {
'activation_quantizer': 'HistQuantizer',
'upsample_bins': 127,
'hist_percent': 0.999}
quant_batch_num: 10
fuse: True
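# Added notes (not part of the original config): `ptq_config` is forwarded as keyword
# arguments to paddleslim.PTQ(...) by the PTQ wrapper in ppdet/slim/quant.py (added
# below); `quant_batch_num` bounds how many eval batches Trainer.post_quant feeds
# through the model for calibration, and `fuse` is passed through to paddleslim's
# PTQ.quantize(). `HistQuantizer` is a histogram-based activation quantizer;
# `upsample_bins` and `hist_percent` tune its calibration histogram.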
......@@ -20,6 +20,7 @@ import os
import yaml
from collections import OrderedDict
import paddle
from ppdet.data.source.category import get_categories
from ppdet.utils.logger import setup_logger
......@@ -50,6 +51,24 @@ KEYPOINT_ARCH = ['HigherHRNet', 'TopDownHRNet']
MOT_ARCH = ['DeepSORT', 'JDE', 'FairMOT']
def _prune_input_spec(input_spec, program, targets):
# try to prune static program to figure out pruned input spec
# so we perform following operations in static mode
paddle.enable_static()
pruned_input_spec = [{}]
program = program.clone()
program = program._prune(targets=targets)
global_block = program.global_block()
for name, spec in input_spec[0].items():
try:
v = global_block.var(name)
pruned_input_spec[0][name] = spec
except Exception:
pass
paddle.disable_static()
return pruned_input_spec
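# Added illustration (comment only): `input_spec` is a one-element list holding a
# dict of name -> paddle.static.InputSpec, e.g.
#   [{'image': InputSpec(shape=[None, 3, -1, -1], name='image'),
#     'im_shape': InputSpec(shape=[None, 2], name='im_shape'),
#     'scale_factor': InputSpec(shape=[None, 2], name='scale_factor')}]
# Entries whose variables disappear when the program is pruned against its fetch
# targets are silently dropped, so the saved model only declares inputs it actually
# consumes.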
def _parse_reader(reader_cfg, dataset_cfg, metric, arch, image_shape):
preprocess_list = []
......@@ -97,7 +116,7 @@ def _dump_infer_config(config, path, image_shape, model):
arch_state = False
from ppdet.core.config.yaml_helpers import setup_orderdict
setup_orderdict()
use_dynamic_shape = True if image_shape[1] == -1 else False
use_dynamic_shape = True if image_shape[2] == -1 else False
infer_cfg = OrderedDict({
'mode': 'fluid',
'draw_threshold': 0.5,
......@@ -141,7 +160,7 @@ def _dump_infer_config(config, path, image_shape, model):
dataset_cfg = config['TestDataset']
infer_cfg['Preprocess'], infer_cfg['label_list'] = _parse_reader(
reader_cfg, dataset_cfg, config['metric'], label_arch, image_shape)
reader_cfg, dataset_cfg, config['metric'], label_arch, image_shape[1:])
yaml.dump(infer_cfg, open(path, 'w'))
logger.info("Export inference config file to {}".format(os.path.join(path)))
......@@ -41,7 +41,7 @@ import ppdet.utils.stats as stats
from ppdet.utils import profiler
from .callbacks import Callback, ComposeCallback, LogPrinter, Checkpointer, WiferFaceEval, VisualDLWriter
from .export_utils import _dump_infer_config
from .export_utils import _dump_infer_config, _prune_input_spec
from ppdet.utils.logger import setup_logger
logger = setup_logger('ppdet.engine')
......@@ -541,12 +541,7 @@ class Trainer(object):
name, ext = os.path.splitext(image_name)
return os.path.join(output_dir, "{}".format(name)) + ext
def export(self, output_dir='output_inference'):
self.model.eval()
model_name = os.path.splitext(os.path.split(self.cfg.filename)[-1])[0]
save_dir = os.path.join(output_dir, model_name)
if not os.path.exists(save_dir):
os.makedirs(save_dir)
def _get_infer_cfg_and_input_spec(self, save_dir, prune_input=True):
image_shape = None
if self.cfg.architecture in MOT_ARCH:
test_reader_name = 'TestMOTReader'
......@@ -555,9 +550,11 @@ class Trainer(object):
if 'inputs_def' in self.cfg[test_reader_name]:
inputs_def = self.cfg[test_reader_name]['inputs_def']
image_shape = inputs_def.get('image_shape', None)
# set image_shape=[3, -1, -1] as default
# set image_shape=[None, 3, -1, -1] as default
if image_shape is None:
image_shape = [3, -1, -1]
image_shape = [None, 3, -1, -1]
if len(image_shape) == 3:
image_shape = [None] + image_shape
if hasattr(self.model, 'deploy'):
self.model.deploy = True
......@@ -574,7 +571,7 @@ class Trainer(object):
input_spec = [{
"image": InputSpec(
shape=[None] + image_shape, name='image'),
shape=image_shape, name='image'),
"im_shape": InputSpec(
shape=[None, 2], name='im_shape'),
"scale_factor": InputSpec(
......@@ -585,13 +582,29 @@ class Trainer(object):
"crops": InputSpec(
shape=[None, 3, 192, 64], name='crops')
})
static_model = paddle.jit.to_static(self.model, input_spec=input_spec)
if prune_input:
static_model = paddle.jit.to_static(
self.model, input_spec=input_spec)
# NOTE: dy2st do not pruned program, but jit.save will prune program
# input spec, prune input spec here and save with pruned input spec
pruned_input_spec = self._prune_input_spec(
pruned_input_spec = _prune_input_spec(
input_spec, static_model.forward.main_program,
static_model.forward.outputs)
else:
static_model = None
pruned_input_spec = input_spec
return static_model, pruned_input_spec
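# Added note: `image_shape` is normalized to a 4-D spec above, so a configured value
# of [3, 640, 640] becomes [None, 3, 640, 640], and the default is [None, 3, -1, -1]
# (dynamic height/width). When prune_input=False (as used by post_quant below) no
# static graph is built and the raw input_spec is returned unpruned.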
def export(self, output_dir='output_inference'):
self.model.eval()
model_name = os.path.splitext(os.path.split(self.cfg.filename)[-1])[0]
save_dir = os.path.join(output_dir, model_name)
if not os.path.exists(save_dir):
os.makedirs(save_dir)
static_model, pruned_input_spec = self._get_infer_cfg_and_input_spec(
save_dir)
# dy2st and save model
if 'slim' not in self.cfg or self.cfg['slim_type'] != 'QAT':
......@@ -606,22 +619,26 @@ class Trainer(object):
input_spec=pruned_input_spec)
logger.info("Export model and saved in {}".format(save_dir))
def _prune_input_spec(self, input_spec, program, targets):
# try to prune static program to figure out pruned input spec
# so we perform following operations in static mode
paddle.enable_static()
pruned_input_spec = [{}]
program = program.clone()
program = program._prune(targets=targets)
global_block = program.global_block()
for name, spec in input_spec[0].items():
try:
v = global_block.var(name)
pruned_input_spec[0][name] = spec
except Exception:
pass
paddle.disable_static()
return pruned_input_spec
def post_quant(self, output_dir='output_inference'):
model_name = os.path.splitext(os.path.split(self.cfg.filename)[-1])[0]
save_dir = os.path.join(output_dir, model_name)
if not os.path.exists(save_dir):
os.makedirs(save_dir)
for idx, data in enumerate(self.loader):
self.model(data)
if idx == int(self.cfg.get('quant_batch_num', 10)):
break
# TODO: support prune input_spec
_, pruned_input_spec = self._get_infer_cfg_and_input_spec(
save_dir, prune_input=False)
self.cfg.slim.save_quantized_model(
self.model,
os.path.join(save_dir, 'model'),
input_spec=pruned_input_spec)
logger.info("Export Post-Quant model and saved in {}".format(save_dir))
def _flops(self, loader):
self.model.eval()
......
......@@ -48,6 +48,14 @@ def build_slim_model(cfg, slim_cfg, mode='train'):
load_pretrain_weight(model, weights)
cfg['model'] = model
cfg['slim_type'] = cfg.slim
elif slim_load_cfg['slim'] == 'PTQ':
model = create(cfg.architecture)
load_config(slim_cfg)
load_pretrain_weight(model, cfg.weights)
slim = create(cfg.slim)
cfg['slim_type'] = cfg.slim
cfg['model'] = slim(model)
cfg['slim'] = slim
else:
load_config(slim_cfg)
model = create(cfg.architecture)
......
......@@ -49,3 +49,36 @@ class QAT(object):
def save_quantized_model(self, layer, path, input_spec=None, **config):
self.quanter.save_quantized_model(
model=layer, path=path, input_spec=input_spec, **config)
@register
@serializable
class PTQ(object):
def __init__(self,
ptq_config,
quant_batch_num=10,
output_dir='output_inference',
fuse=True,
fuse_list=None):
super(PTQ, self).__init__()
self.ptq_config = ptq_config
self.quant_batch_num = quant_batch_num
self.output_dir = output_dir
self.fuse = fuse
self.fuse_list = fuse_list
def __call__(self, model):
paddleslim = try_import('paddleslim')
self.ptq = paddleslim.PTQ(**self.ptq_config)
model.eval()
quant_model = self.ptq.quantize(
model, fuse=self.fuse, fuse_list=self.fuse_list)
return quant_model
def save_quantized_model(self,
quant_model,
quantize_model_path,
input_spec=None):
self.ptq.save_quantized_model(quant_model, quantize_model_path,
input_spec)
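# Example usage of this PTQ wrapper (illustrative sketch; `model`, `loader`, and
# `input_spec` are assumed to exist, and the values mirror the
# ppyolo_mbv3_large_ptq.yml config and the Trainer.post_quant flow above):
#
#   ptq = PTQ(ptq_config={'activation_quantizer': 'HistQuantizer',
#                         'upsample_bins': 127,
#                         'hist_percent': 0.999},
#             quant_batch_num=10,
#             fuse=True)
#   quant_model = ptq(model)              # wrap the dygraph detector via paddleslim
#   for idx, data in enumerate(loader):   # feed calibration batches
#       quant_model(data)
#       if idx == 10:
#           break
#   ptq.save_quantized_model(quant_model, 'output_inference/model',
#                            input_spec=input_spec)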
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
# add python path of PaddleDetection to sys.path
parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2)))
sys.path.insert(0, parent_path)
# ignore warning log
import warnings
warnings.filterwarnings('ignore')
import paddle
from ppdet.core.workspace import load_config, merge_config
from ppdet.utils.check import check_gpu, check_version, check_config
from ppdet.utils.cli import ArgsParser
from ppdet.engine import Trainer
from ppdet.slim import build_slim_model
from ppdet.utils.logger import setup_logger
logger = setup_logger('post_quant')
def parse_args():
parser = ArgsParser()
parser.add_argument(
"--output_dir",
type=str,
default="output_inference",
help="Directory for storing the output model files.")
parser.add_argument(
"--slim_config",
default=None,
type=str,
help="Configuration file of slim method.")
args = parser.parse_args()
return args
def run(FLAGS, cfg):
# build detector
trainer = Trainer(cfg, mode='eval')
# load weights
if cfg.architecture in ['DeepSORT']:
if cfg.det_weights != 'None':
trainer.load_weights_sde(cfg.det_weights, cfg.reid_weights)
else:
trainer.load_weights_sde(None, cfg.reid_weights)
else:
trainer.load_weights(cfg.weights)
# post quant model
trainer.post_quant(FLAGS.output_dir)
def main():
FLAGS = parse_args()
cfg = load_config(FLAGS.config)
# TODO: to be refined in the future
if 'norm_type' in cfg and cfg['norm_type'] == 'sync_bn':
FLAGS.opt['norm_type'] = 'bn'
merge_config(FLAGS.opt)
if FLAGS.slim_config:
cfg = build_slim_model(cfg, FLAGS.slim_config, mode='test')
# FIXME: Temporarily solve the priority problem of FLAGS.opt
merge_config(FLAGS.opt)
check_config(cfg)
check_gpu(cfg.use_gpu)
check_version()
run(FLAGS, cfg)
if __name__ == '__main__':
main()