From 1dbbce6c2557a3244eac9f5f123874f5dbf1b756 Mon Sep 17 00:00:00 2001
From: Liufang Sang
Date: Mon, 13 Apr 2020 03:11:17 -0500
Subject: [PATCH] Fix quant model export issues (#198)

* save inference model

* fix details

* add export model in quantization

* remove useless code

* remove useless code

* change fluid.layers.data to fluid.data
---
 pdseg/models/model_builder.py     | 38 +++++---
 pdseg/utils/config.py             |  9 +-
 slim/quantization/README.md       | 13 +++
 slim/quantization/export_model.py | 149 ++++++++++++++++++++++++++++++
 4 files changed, 189 insertions(+), 20 deletions(-)
 create mode 100644 slim/quantization/export_model.py

diff --git a/pdseg/models/model_builder.py b/pdseg/models/model_builder.py
index 4092a417..a81f2ff4 100644
--- a/pdseg/models/model_builder.py
+++ b/pdseg/models/model_builder.py
@@ -175,12 +175,16 @@ def build_model(main_prog, start_prog, phase=ModelPhase.TRAIN):
     # When exporting the model, add image normalization preprocessing to the graph
     # so that deployment only needs to add a batch_size dimension to the input image
     if ModelPhase.is_predict(phase):
-        origin_image = fluid.data(
-            name='image',
-            shape=[-1, -1, -1, cfg.DATASET.DATA_DIM],
-            dtype='float32')
-        image, valid_shape, origin_shape = export_preprocess(
-            origin_image)
+        if cfg.SLIM.PREPROCESS:
+            image = fluid.data(
+                name='image', shape=image_shape, dtype='float32')
+        else:
+            origin_image = fluid.data(
+                name='image',
+                shape=[-1, -1, -1, cfg.DATASET.DATA_DIM],
+                dtype='float32')
+            image, valid_shape, origin_shape = export_preprocess(
+                origin_image)
 
     else:
         image = fluid.data(
@@ -271,15 +275,19 @@ def build_model(main_prog, start_prog, phase=ModelPhase.TRAIN):
             logit = softmax(logit)
 
             # Get the valid region of the logits
-            logit = fluid.layers.slice(
-                logit, axes=[2, 3], starts=[0, 0], ends=valid_shape)
-
-            logit = fluid.layers.resize_bilinear(
-                logit,
-                out_shape=origin_shape,
-                align_corners=False,
-                align_mode=0)
-            logit = fluid.layers.argmax(logit, axis=1)
+            if cfg.SLIM.PREPROCESS:
+                return image, logit
+
+            else:
+                logit = fluid.layers.slice(
+                    logit, axes=[2, 3], starts=[0, 0], ends=valid_shape)
+
+                logit = fluid.layers.resize_bilinear(
+                    logit,
+                    out_shape=origin_shape,
+                    align_corners=False,
+                    align_mode=0)
+                logit = fluid.layers.argmax(logit, axis=1)
 
             return origin_image, logit
 
         if class_num == 1:
diff --git a/pdseg/utils/config.py b/pdseg/utils/config.py
index c3d84216..220b9e5e 100644
--- a/pdseg/utils/config.py
+++ b/pdseg/utils/config.py
@@ -155,10 +155,10 @@ cfg.SOLVER.BEGIN_EPOCH = 1
 cfg.SOLVER.NUM_EPOCHS = 30
 # Loss selection; supports softmax_loss, bce_loss, dice_loss
 cfg.SOLVER.LOSS = ["softmax_loss"]
-# Whether to enable the warmup learning rate strategy
-cfg.SOLVER.LR_WARMUP = False
+# Whether to enable the warmup learning rate strategy
+cfg.SOLVER.LR_WARMUP = False
 # Number of warmup iterations
-cfg.SOLVER.LR_WARMUP_STEPS = 2000
+cfg.SOLVER.LR_WARMUP_STEPS = 2000
 # Cross entropy weight; defaults to None. If set to 'dynamic', class weights
 # are adjusted dynamically according to the count of each class in a batch.
 # A static weight list can also be set, e.g. for 3 classes: [0.1, 2.0, 0.9]
@@ -228,7 +228,6 @@ cfg.MODEL.HRNET.STAGE3.NUM_CHANNELS = [40, 80, 160]
 cfg.MODEL.HRNET.STAGE4.NUM_MODULES = 3
 cfg.MODEL.HRNET.STAGE4.NUM_CHANNELS = [40, 80, 160, 320]
 
-
 ########################## Inference deployment model configuration ###################################
 # Filename of the saved inference model
 cfg.FREEZE.MODEL_FILENAME = '__model__'
@@ -251,4 +250,4 @@ cfg.SLIM.NAS_SPACE_NAME = ""
 
 cfg.SLIM.PRUNE_PARAMS = ''
 cfg.SLIM.PRUNE_RATIOS = []
-
+cfg.SLIM.PREPROCESS = False
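Note on `SLIM.PREPROCESS`: when it is `True`, the exported graph takes a fixed-shape, already-normalized `image` tensor and skips the in-graph `export_preprocess` step, so normalization moves to the caller. Below is a minimal caller-side sketch, assuming PaddleSeg's usual `(img / 255 - MEAN) / STD` convention; the `preprocess` helper name, the crop size, and the mean/std values are illustrative stand-ins for the values written to `deploy.yaml`, not part of this patch.

```python
import cv2
import numpy as np


def preprocess(image_path, crop_size=(2048, 1024),
               mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
    # Read an HWC BGR image and scale pixel values to [0, 1]
    img = cv2.imread(image_path).astype('float32') / 255.0
    # Resize to the fixed export input size (cv2 takes (width, height))
    img = cv2.resize(img, crop_size)
    # Normalize with the dataset MEAN/STD (illustrative defaults here)
    img = (img - np.array(mean, 'float32')) / np.array(std, 'float32')
    # HWC -> CHW, then add the batch dimension expected by the graph
    return img.transpose(2, 0, 1)[np.newaxis].astype('float32')
```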
diff --git a/slim/quantization/README.md b/slim/quantization/README.md
index 9af04033..28a74e01 100644
--- a/slim/quantization/README.md
+++ b/slim/quantization/README.md
@@ -133,7 +133,20 @@ TRAIN.SYNC_BATCH_NORM False \
 BATCH_SIZE 16 \
 ```
 
+## Export the model
+Use the script [slim/quantization/export_model.py](./export_model.py) to export the model.
+Export command:
+
+Run the following from the root directory of the segmentation library:
+```
+python -u ./slim/quantization/export_model.py --not_quant_pattern last_conv --cfg configs/deeplabv3p_mobilenetv2_cityscapes.yaml \
+TEST.TEST_MODEL "./snapshots/mobilenetv2_quant/best_model" \
+MODEL.DEEPLAB.ENCODER_WITH_ASPP False \
+MODEL.DEEPLAB.ENABLE_DECODER False \
+TRAIN.SYNC_BATCH_NORM False \
+SLIM.PREPROCESS True \
+```
 
 ## Quantization results
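With `SLIM.PREPROCESS True` the exported program also omits the in-graph `slice`/`resize_bilinear`/`argmax` postprocessing, so it returns softmax probabilities in NCHW layout rather than a label map. A minimal caller-side sketch follows, assuming the logits cover the whole crop (no padded region to slice off); the `postprocess` helper and the nearest-neighbor resize are illustrative choices, not part of this patch.

```python
import cv2
import numpy as np


def postprocess(logits, origin_height, origin_width):
    # logits: (N, C, H, W) softmax output from the exported graph
    label_map = np.argmax(logits[0], axis=0).astype('uint8')  # (H, W)
    # Nearest-neighbor interpolation keeps class ids intact when
    # resizing the label map back to the original image size
    return cv2.resize(label_map, (origin_width, origin_height),
                      interpolation=cv2.INTER_NEAREST)
```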
+ """ + print("Exporting inference model...") + startup_prog = fluid.Program() + infer_prog = fluid.Program() + image, logit_out = build_model( + infer_prog, startup_prog, phase=ModelPhase.PREDICT) + + # Use CPU for exporting inference model instead of GPU + place = fluid.CPUPlace() + exe = fluid.Executor(place) + exe.run(startup_prog) + infer_prog = infer_prog.clone(for_test=True) + not_quant_pattern_list = [] + if args.not_quant_pattern is not None: + not_quant_pattern_list = args.not_quant_pattern + + config = { + 'weight_quantize_type': 'channel_wise_abs_max', + 'activation_quantize_type': 'moving_average_abs_max', + 'quantize_op_types': ['depthwise_conv2d', 'mul', 'conv2d'], + 'not_quant_pattern': not_quant_pattern_list + } + + infer_prog = quant_aware(infer_prog, place, config, for_test=True) + if os.path.exists(cfg.TEST.TEST_MODEL): + fluid.io.load_persistables( + exe, cfg.TEST.TEST_MODEL, main_program=infer_prog) + else: + print("TEST.TEST_MODEL diretory is empty!") + exit(-1) + + infer_prog = convert(infer_prog, place, config) + + fluid.io.save_inference_model( + cfg.FREEZE.SAVE_DIR, + feeded_var_names=[image.name], + target_vars=[logit_out], + executor=exe, + main_program=infer_prog, + model_filename=cfg.FREEZE.MODEL_FILENAME, + params_filename=cfg.FREEZE.PARAMS_FILENAME) + print("Inference model exported!") + print("Exporting inference model config...") + deploy_cfg_path = export_inference_config() + print("Inference model saved : [%s]" % (deploy_cfg_path)) + + +def main(): + args = parse_args() + if args.cfg_file is not None: + cfg.update_from_file(args.cfg_file) + if args.opts: + cfg.update_from_list(args.opts) + cfg.check_and_infer() + print(pprint.pformat(cfg)) + export_inference_model(args) + + +if __name__ == '__main__': + main() -- GitLab