Unverified commit 846f6d5c, authored by Guanghua Yu, committed by GitHub

add post quant (#4255)

Parent: f0bb99d8
@@ -4,10 +4,11 @@
- [Pruning](prune)
- [Quantization](quant)
- [Post-Training Quantization](post_quant)
- [Distillation](distill)
- [Combined Strategies](extensions)

We recommend compressing detection models with joint pruning and distillation training, or with pruning, quantization-aware training, and post-training quantization. The sections below use YOLOv3 as the example for the pruning, distillation, and quantization experiments.

## Experiment Environment
@@ -20,7 +21,8 @@
**Version relationship between PaddleDetection, PaddlePaddle, and PaddleSlim:**

| PaddleDetection version | PaddlePaddle version | PaddleSlim version | Notes |
| :------------------: | :---------------: | :-------: | :---------------: |
| release/2.3 | >= 2.1 | 2.1 | Post-training (offline) quantization requires Paddle 2.2 and PaddleSlim 2.2 |
| release/2.1 | >= 2.1.0 | 2.1 | Exporting quantized models depends on the latest Paddle develop branch, which can be downloaded and installed from the [PaddlePaddle nightly builds](https://www.paddlepaddle.org.cn/documentation/docs/zh/install/Tables.html#whl-dev) |
| release/2.0 | >= 2.0.1 | 2.0 | Quantization requires Paddle 2.1 and PaddleSlim 2.1 |
@@ -145,6 +147,16 @@ python tools/export_model.py -c configs/{MODEL.yml} --slim_config configs/slim/{
- For the V100 inference latencies above, non-quantized models were tested with TensorRT-FP32 and quantized models with TensorRT-INT8; both include NMS time.
- SD855 inference latency was measured with a PaddleLite deployment on the armv8 architecture using 4 threads.

### Post-Training Quantization

Prepare a validation set for calibrating the post-training quantized model, then run:
```shell
python tools/post_quant.py -c configs/{MODEL.yml} --slim_config configs/slim/post_quant/{SLIM_CONFIG.yml}
```
For example:
```shell
python3.7 tools/post_quant.py -c configs/ppyolo/ppyolo_mbv3_large_coco.yml --slim_config=configs/slim/post_quant/ppyolo_mbv3_large_ptq.yml
```
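The calibrated quantized inference model is written to a directory named after the config file under `--output_dir` (default `output_inference`), the same layout used by `tools/export_model.py`; the exact steps are implemented in the new `Trainer.post_quant` method further down in this commit.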
### Distillation

#### Benchmark on COCO
......
weights: https://paddledet.bj.bcebos.com/models/ppyolo_mbv3_large_coco.pdparams
slim: PTQ

PTQ:
  ptq_config: {
    'activation_quantizer': 'HistQuantizer',
    'upsample_bins': 127,
    'hist_percent': 0.999}
  quant_batch_num: 10
  fuse: True
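For orientation, each key in this config is consumed by the `PTQ` slim wrapper added in `ppdet/slim/quant.py` later in this diff: `ptq_config` is expanded into `paddleslim.PTQ(**ptq_config)`, `quant_batch_num` sets how many calibration batches `Trainer.post_quant` forwards, and `fuse`/`fuse_list` are passed to `ptq.quantize`. Below is a minimal sketch of those calls, assuming PaddleSlim 2.2 as noted in the version table above; the function name `apply_ptq` is illustrative and not part of the repo:

```python
import paddle
import paddleslim


def apply_ptq(model: paddle.nn.Layer):
    """Illustrative only: mirrors the calls made by the PTQ wrapper in ppdet/slim/quant.py."""
    # ptq_config from the YAML above is expanded into paddleslim.PTQ(**ptq_config)
    ptq = paddleslim.PTQ(
        activation_quantizer='HistQuantizer',
        upsample_bins=127,
        hist_percent=0.999)
    model.eval()
    # fuse / fuse_list are forwarded unchanged to ptq.quantize
    quant_model = ptq.quantize(model, fuse=True, fuse_list=None)
    return ptq, quant_model
```

After running the calibration batches through `quant_model`, `ptq.save_quantized_model(...)` exports the inference model; `Trainer.post_quant` below performs exactly this sequence.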
@@ -20,6 +20,7 @@ import os
import yaml
from collections import OrderedDict
import paddle
from ppdet.data.source.category import get_categories
from ppdet.utils.logger import setup_logger

@@ -50,6 +51,24 @@ KEYPOINT_ARCH = ['HigherHRNet', 'TopDownHRNet']
MOT_ARCH = ['DeepSORT', 'JDE', 'FairMOT']

def _prune_input_spec(input_spec, program, targets):
    # try to prune static program to figure out pruned input spec
    # so we perform following operations in static mode
    paddle.enable_static()
    pruned_input_spec = [{}]
    program = program.clone()
    program = program._prune(targets=targets)
    global_block = program.global_block()
    for name, spec in input_spec[0].items():
        try:
            v = global_block.var(name)
            pruned_input_spec[0][name] = spec
        except Exception:
            pass
    paddle.disable_static()
    return pruned_input_spec
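Note: `_prune_input_spec` was previously a private method of `Trainer`; this commit promotes it to a module-level helper in `export_utils.py`, and `trainer.py` below imports it so that both the export and the new post-quant paths can reuse it.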
def _parse_reader(reader_cfg, dataset_cfg, metric, arch, image_shape):
    preprocess_list = []

@@ -97,7 +116,7 @@ def _dump_infer_config(config, path, image_shape, model):
    arch_state = False
    from ppdet.core.config.yaml_helpers import setup_orderdict
    setup_orderdict()
    use_dynamic_shape = True if image_shape[2] == -1 else False
    infer_cfg = OrderedDict({
        'mode': 'fluid',
        'draw_threshold': 0.5,

@@ -141,7 +160,7 @@ def _dump_infer_config(config, path, image_shape, model):
        dataset_cfg = config['TestDataset']
    infer_cfg['Preprocess'], infer_cfg['label_list'] = _parse_reader(
        reader_cfg, dataset_cfg, config['metric'], label_arch, image_shape[1:])
    yaml.dump(infer_cfg, open(path, 'w'))
    logger.info("Export inference config file to {}".format(os.path.join(path)))
@@ -41,7 +41,7 @@ import ppdet.utils.stats as stats
from ppdet.utils import profiler
from .callbacks import Callback, ComposeCallback, LogPrinter, Checkpointer, WiferFaceEval, VisualDLWriter
from .export_utils import _dump_infer_config, _prune_input_spec
from ppdet.utils.logger import setup_logger
logger = setup_logger('ppdet.engine')

@@ -541,12 +541,7 @@ class Trainer(object):
        name, ext = os.path.splitext(image_name)
        return os.path.join(output_dir, "{}".format(name)) + ext
    def _get_infer_cfg_and_input_spec(self, save_dir, prune_input=True):
        image_shape = None
        if self.cfg.architecture in MOT_ARCH:
            test_reader_name = 'TestMOTReader'

@@ -555,9 +550,11 @@ class Trainer(object):
        if 'inputs_def' in self.cfg[test_reader_name]:
            inputs_def = self.cfg[test_reader_name]['inputs_def']
            image_shape = inputs_def.get('image_shape', None)
        # set image_shape=[None, 3, -1, -1] as default
        if image_shape is None:
            image_shape = [None, 3, -1, -1]
        if len(image_shape) == 3:
            image_shape = [None] + image_shape

        if hasattr(self.model, 'deploy'):
            self.model.deploy = True
@@ -574,7 +571,7 @@ class Trainer(object):
        input_spec = [{
            "image": InputSpec(
                shape=image_shape, name='image'),
            "im_shape": InputSpec(
                shape=[None, 2], name='im_shape'),
            "scale_factor": InputSpec(

@@ -585,13 +582,29 @@ class Trainer(object):
                "crops": InputSpec(
                    shape=[None, 3, 192, 64], name='crops')
            })
        if prune_input:
            static_model = paddle.jit.to_static(
                self.model, input_spec=input_spec)
            # NOTE: dy2st does not prune the program, but jit.save will prune the
            # input spec, so prune the input spec here and save with the pruned one
            pruned_input_spec = _prune_input_spec(
                input_spec, static_model.forward.main_program,
                static_model.forward.outputs)
        else:
            static_model = None
            pruned_input_spec = input_spec

        return static_model, pruned_input_spec
    def export(self, output_dir='output_inference'):
        self.model.eval()
        model_name = os.path.splitext(os.path.split(self.cfg.filename)[-1])[0]
        save_dir = os.path.join(output_dir, model_name)
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        static_model, pruned_input_spec = self._get_infer_cfg_and_input_spec(
            save_dir)
        # dy2st and save model
        if 'slim' not in self.cfg or self.cfg['slim_type'] != 'QAT':

@@ -606,22 +619,26 @@ class Trainer(object):
                input_spec=pruned_input_spec)
        logger.info("Export model and saved in {}".format(save_dir))
    def post_quant(self, output_dir='output_inference'):
        model_name = os.path.splitext(os.path.split(self.cfg.filename)[-1])[0]
        save_dir = os.path.join(output_dir, model_name)
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        for idx, data in enumerate(self.loader):
            self.model(data)
            if idx == int(self.cfg.get('quant_batch_num', 10)):
                break

        # TODO: support prune input_spec
        _, pruned_input_spec = self._get_infer_cfg_and_input_spec(
            save_dir, prune_input=False)

        self.cfg.slim.save_quantized_model(
            self.model,
            os.path.join(save_dir, 'model'),
            input_spec=pruned_input_spec)
        logger.info("Export Post-Quant model and saved in {}".format(save_dir))
    def _flops(self, loader):
        self.model.eval()
......
@@ -48,6 +48,14 @@ def build_slim_model(cfg, slim_cfg, mode='train'):
        load_pretrain_weight(model, weights)
        cfg['model'] = model
        cfg['slim_type'] = cfg.slim
    elif slim_load_cfg['slim'] == 'PTQ':
        model = create(cfg.architecture)
        load_config(slim_cfg)
        load_pretrain_weight(model, cfg.weights)
        slim = create(cfg.slim)
        cfg['slim_type'] = cfg.slim
        cfg['model'] = slim(model)
        cfg['slim'] = slim
    else:
        load_config(slim_cfg)
        model = create(cfg.architecture)
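Note: unlike the QAT branch, the new PTQ branch stores both the wrapped model (`cfg['model'] = slim(model)`) and the slim object itself (`cfg['slim'] = slim`), so that `Trainer.post_quant` can later call `cfg.slim.save_quantized_model(...)` after the calibration loop.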
......
@@ -49,3 +49,36 @@ class QAT(object):
    def save_quantized_model(self, layer, path, input_spec=None, **config):
        self.quanter.save_quantized_model(
            model=layer, path=path, input_spec=input_spec, **config)

@register
@serializable
class PTQ(object):
    def __init__(self,
                 ptq_config,
                 quant_batch_num=10,
                 output_dir='output_inference',
                 fuse=True,
                 fuse_list=None):
        super(PTQ, self).__init__()
        self.ptq_config = ptq_config
        self.quant_batch_num = quant_batch_num
        self.output_dir = output_dir
        self.fuse = fuse
        self.fuse_list = fuse_list

    def __call__(self, model):
        paddleslim = try_import('paddleslim')
        self.ptq = paddleslim.PTQ(**self.ptq_config)
        model.eval()
        quant_model = self.ptq.quantize(
            model, fuse=self.fuse, fuse_list=self.fuse_list)
        return quant_model

    def save_quantized_model(self,
                             quant_model,
                             quantize_model_path,
                             input_spec=None):
        self.ptq.save_quantized_model(quant_model, quantize_model_path,
                                      input_spec)
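Note that `self.ptq` is only created inside `__call__`, so `save_quantized_model` must be invoked on the same `PTQ` instance that wrapped the model; this is why `build_slim_model` above keeps the instance in `cfg['slim']` for `Trainer.post_quant` to use.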
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
# add python path of PaddleDetection to sys.path
parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2)))
sys.path.insert(0, parent_path)
# ignore warning log
import warnings
warnings.filterwarnings('ignore')
import paddle
from ppdet.core.workspace import load_config, merge_config
from ppdet.utils.check import check_gpu, check_version, check_config
from ppdet.utils.cli import ArgsParser
from ppdet.engine import Trainer
from ppdet.slim import build_slim_model
from ppdet.utils.logger import setup_logger
logger = setup_logger('post_quant')
def parse_args():
    parser = ArgsParser()
    parser.add_argument(
        "--output_dir",
        type=str,
        default="output_inference",
        help="Directory for storing the output model files.")
    parser.add_argument(
        "--slim_config",
        default=None,
        type=str,
        help="Configuration file of slim method.")
    args = parser.parse_args()
    return args
def run(FLAGS, cfg):
    # build detector
    trainer = Trainer(cfg, mode='eval')

    # load weights
    if cfg.architecture in ['DeepSORT']:
        if cfg.det_weights != 'None':
            trainer.load_weights_sde(cfg.det_weights, cfg.reid_weights)
        else:
            trainer.load_weights_sde(None, cfg.reid_weights)
    else:
        trainer.load_weights(cfg.weights)

    # post quant model
    trainer.post_quant(FLAGS.output_dir)
def main():
    FLAGS = parse_args()
    cfg = load_config(FLAGS.config)
    # TODO: to be refined in the future
    if 'norm_type' in cfg and cfg['norm_type'] == 'sync_bn':
        FLAGS.opt['norm_type'] = 'bn'
    merge_config(FLAGS.opt)

    if FLAGS.slim_config:
        cfg = build_slim_model(cfg, FLAGS.slim_config, mode='test')

    # FIXME: Temporarily solve the priority problem of FLAGS.opt
    merge_config(FLAGS.opt)
    check_config(cfg)
    check_gpu(cfg.use_gpu)
    check_version()

    run(FLAGS, cfg)


if __name__ == '__main__':
    main()