Unverified commit 1dbbce6c, authored by Liufang Sang, committed by GitHub

Fix quant model export issues (#198)

* save inference model

* fix details

* add export model in quantization

* remove useless code

* remove useless code

* change fluid.layers.data to fluid.data (see the sketch below)
Parent 4e361b18
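For context on the last commit note, a minimal sketch of the `fluid.layers.data` → `fluid.data` change in the PaddlePaddle 1.x fluid API (the shape below is illustrative, not the one PaddleSeg uses):

```
import paddle.fluid as fluid

# Old API: fluid.layers.data prepends a batch dimension to `shape` by default.
# image = fluid.layers.data(name='image', shape=[3, -1, -1], dtype='float32')

# New API: fluid.data takes the full shape explicitly; -1 marks variable
# dimensions (including the batch dimension), and dtype/shape are enforced.
image = fluid.data(name='image', shape=[-1, 3, -1, -1], dtype='float32')
```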
@@ -175,12 +175,16 @@ def build_model(main_prog, start_prog, phase=ModelPhase.TRAIN):
    # When exporting the model, image normalization preprocessing is added to reduce the image processing pipeline at deployment time
    # At deployment time, only a batch_size dimension needs to be added to the input image
if ModelPhase.is_predict(phase):
origin_image = fluid.data(
name='image',
shape=[-1, -1, -1, cfg.DATASET.DATA_DIM],
dtype='float32')
image, valid_shape, origin_shape = export_preprocess(
origin_image)
if cfg.SLIM.PREPROCESS:
image = fluid.data(
name='image', shape=image_shape, dtype='float32')
else:
origin_image = fluid.data(
name='image',
shape=[-1, -1, -1, cfg.DATASET.DATA_DIM],
dtype='float32')
image, valid_shape, origin_shape = export_preprocess(
origin_image)
else:
image = fluid.data(
@@ -271,15 +275,19 @@ def build_model(main_prog, start_prog, phase=ModelPhase.TRAIN):
logit = softmax(logit)
            # Extract the valid region
logit = fluid.layers.slice(
logit, axes=[2, 3], starts=[0, 0], ends=valid_shape)
logit = fluid.layers.resize_bilinear(
logit,
out_shape=origin_shape,
align_corners=False,
align_mode=0)
logit = fluid.layers.argmax(logit, axis=1)
if cfg.SLIM.PREPROCESS:
return image, logit
else:
logit = fluid.layers.slice(
logit, axes=[2, 3], starts=[0, 0], ends=valid_shape)
logit = fluid.layers.resize_bilinear(
logit,
out_shape=origin_shape,
align_corners=False,
align_mode=0)
logit = fluid.layers.argmax(logit, axis=1)
return origin_image, logit
if class_num == 1:
......
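The two hunks above add the `cfg.SLIM.PREPROCESS` switch. With it off (the default), `export_preprocess` bakes normalization and shape handling into the exported graph, so deployment only needs to add a batch dimension to the raw image; with it on (used for the quantized export), the graph expects an already preprocessed, fixed-shape `image` input. A minimal sketch of the default feeding path (the file name and HWC layout are illustrative assumptions):

```
import cv2
import numpy as np

# SLIM.PREPROCESS == False: normalization lives inside the exported graph,
# so the caller only adds a batch dimension to the raw H x W x C image.
img = cv2.imread('demo.png').astype('float32')   # H x W x C
feed_img = np.expand_dims(img, axis=0)           # 1 x H x W x C, fed as 'image'
```

With `SLIM.PREPROCESS True`, the caller is instead expected to resize and normalize the image to the configured input shape before feeding it.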
@@ -155,10 +155,10 @@ cfg.SOLVER.BEGIN_EPOCH = 1
cfg.SOLVER.NUM_EPOCHS = 30
# Loss selection; supports softmax_loss, bce_loss, dice_loss
cfg.SOLVER.LOSS = ["softmax_loss"]
# Whether to enable the warmup learning rate strategy
cfg.SOLVER.LR_WARMUP = False
# Number of warmup iterations
cfg.SOLVER.LR_WARMUP_STEPS = 2000
# Cross entropy weight. Defaults to None. If set to 'dynamic', class weights are
# adjusted dynamically according to the number of pixels of each class in every batch.
# A static weight can also be set (as a list); for example, with 3 classes the
# per-class weights could be set to [0.1, 2.0, 0.9]
@@ -228,7 +228,6 @@ cfg.MODEL.HRNET.STAGE3.NUM_CHANNELS = [40, 80, 160]
cfg.MODEL.HRNET.STAGE4.NUM_MODULES = 3
cfg.MODEL.HRNET.STAGE4.NUM_CHANNELS = [40, 80, 160, 320]
########################## Inference/deployment model configuration ###################################
# Filename of the exported inference model
cfg.FREEZE.MODEL_FILENAME = '__model__'
@@ -251,4 +250,4 @@ cfg.SLIM.NAS_SPACE_NAME = ""
cfg.SLIM.PRUNE_PARAMS = ''
cfg.SLIM.PRUNE_RATIOS = []
cfg.SLIM.PREPROCESS = False
@@ -133,7 +133,20 @@ TRAIN.SYNC_BATCH_NORM False \
BATCH_SIZE 16 \
```
## Exporting the model
Use the script [slim/quantization/export_model.py](./export_model.py) to export the model.
Export command (run from the root directory of the segmentation library):
```
python -u ./slim/quantization/export_model.py --not_quant_pattern last_conv --cfg configs/deeplabv3p_mobilenetv2_cityscapes.yaml \
TEST.TEST_MODEL "./snapshots/mobilenetv2_quant/best_model" \
MODEL.DEEPLAB.ENCODER_WITH_ASPP False \
MODEL.DEEPLAB.ENABLE_DECODER False \
TRAIN.SYNC_BATCH_NORM False \
SLIM.PREPROCESS True
```
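After export, the inference program and a `deploy.yaml` are written to `cfg.FREEZE.SAVE_DIR`. A minimal sketch of loading the exported model with the fluid API (the directory name and params filename below are assumptions; use the values from your `FREEZE` settings):

```
import paddle.fluid as fluid

place = fluid.CPUPlace()
exe = fluid.Executor(place)
# 'freeze_model' and '__params__' are assumed defaults here; they correspond
# to cfg.FREEZE.SAVE_DIR and cfg.FREEZE.PARAMS_FILENAME respectively.
prog, feed_names, fetch_targets = fluid.io.load_inference_model(
    dirname='freeze_model',
    executor=exe,
    model_filename='__model__',
    params_filename='__params__')
```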
## Quantization results
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
import time
import pprint
import cv2
import argparse
import numpy as np
import paddle.fluid as fluid
from utils.config import cfg
from models.model_builder import build_model
from models.model_builder import ModelPhase
from paddleslim.quant import quant_aware, convert
def parse_args():
parser = argparse.ArgumentParser(
description='PaddleSeg Inference Model Exporter')
parser.add_argument(
'--cfg',
dest='cfg_file',
help='Config file for training (and optionally testing)',
default=None,
type=str)
parser.add_argument(
"--not_quant_pattern",
nargs='+',
type=str,
help=
"Layers which name_scope contains string in not_quant_pattern will not be quantized"
)
parser.add_argument(
'opts',
help='See utils/config.py for all options',
default=None,
nargs=argparse.REMAINDER)
if len(sys.argv) == 1:
parser.print_help()
sys.exit(1)
return parser.parse_args()
def export_inference_config():
deploy_cfg = '''DEPLOY:
USE_GPU : 1
MODEL_PATH : "%s"
MODEL_FILENAME : "%s"
PARAMS_FILENAME : "%s"
EVAL_CROP_SIZE : %s
MEAN : %s
STD : %s
IMAGE_TYPE : "%s"
NUM_CLASSES : %d
CHANNELS : %d
PRE_PROCESSOR : "SegPreProcessor"
PREDICTOR_MODE : "ANALYSIS"
BATCH_SIZE : 1
''' % (cfg.FREEZE.SAVE_DIR, cfg.FREEZE.MODEL_FILENAME,
cfg.FREEZE.PARAMS_FILENAME, cfg.EVAL_CROP_SIZE, cfg.MEAN, cfg.STD,
cfg.DATASET.IMAGE_TYPE, cfg.DATASET.NUM_CLASSES, len(cfg.STD))
if not os.path.exists(cfg.FREEZE.SAVE_DIR):
os.mkdir(cfg.FREEZE.SAVE_DIR)
yaml_path = os.path.join(cfg.FREEZE.SAVE_DIR, 'deploy.yaml')
with open(yaml_path, "w") as fp:
fp.write(deploy_cfg)
return yaml_path
def export_inference_model(args):
"""
    Export PaddlePaddle inference model for prediction deployment and serving.
"""
print("Exporting inference model...")
startup_prog = fluid.Program()
infer_prog = fluid.Program()
image, logit_out = build_model(
infer_prog, startup_prog, phase=ModelPhase.PREDICT)
# Use CPU for exporting inference model instead of GPU
place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(startup_prog)
infer_prog = infer_prog.clone(for_test=True)
not_quant_pattern_list = []
if args.not_quant_pattern is not None:
not_quant_pattern_list = args.not_quant_pattern
config = {
'weight_quantize_type': 'channel_wise_abs_max',
'activation_quantize_type': 'moving_average_abs_max',
'quantize_op_types': ['depthwise_conv2d', 'mul', 'conv2d'],
'not_quant_pattern': not_quant_pattern_list
}
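    # quant_aware() rewrites the program by inserting fake quantize/dequantize
    # ops according to `config`; convert() below turns the program (with the
    # trained quantization parameters loaded) into the final quantized
    # inference program for deployment.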
infer_prog = quant_aware(infer_prog, place, config, for_test=True)
if os.path.exists(cfg.TEST.TEST_MODEL):
fluid.io.load_persistables(
exe, cfg.TEST.TEST_MODEL, main_program=infer_prog)
else:
print("TEST.TEST_MODEL diretory is empty!")
exit(-1)
infer_prog = convert(infer_prog, place, config)
fluid.io.save_inference_model(
cfg.FREEZE.SAVE_DIR,
feeded_var_names=[image.name],
target_vars=[logit_out],
executor=exe,
main_program=infer_prog,
model_filename=cfg.FREEZE.MODEL_FILENAME,
params_filename=cfg.FREEZE.PARAMS_FILENAME)
print("Inference model exported!")
print("Exporting inference model config...")
deploy_cfg_path = export_inference_config()
print("Inference model saved : [%s]" % (deploy_cfg_path))
def main():
args = parse_args()
if args.cfg_file is not None:
cfg.update_from_file(args.cfg_file)
if args.opts:
cfg.update_from_list(args.opts)
cfg.check_and_infer()
print(pprint.pformat(cfg))
export_inference_model(args)
if __name__ == '__main__':
main()