# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
import pprint
import argparse

import paddle.fluid as fluid

from utils.config import cfg
from models.model_builder import build_model
from models.model_builder import ModelPhase


def parse_args():
    parser = argparse.ArgumentParser(
        description='PaddleSeg Inference Model Exporter')
    parser.add_argument(
        '--cfg',
        dest='cfg_file',
        help='Config file of the model to export',
        default=None,
        type=str)
    parser.add_argument(
        'opts',
        help='See utils/config.py for all options',
        default=None,
        nargs=argparse.REMAINDER)
    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit(1)
    return parser.parse_args()


def export_inference_config():
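    """
    Generate a deploy.yaml describing how to run the exported model;
    all values are filled in from the global config. Returns the path
    of the written file.
    """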
    deploy_cfg = '''DEPLOY:
        USE_GPU : 1
        USE_PR : 0
        MODEL_PATH : "%s"
        MODEL_FILENAME : "%s"
        PARAMS_FILENAME : "%s"
        EVAL_CROP_SIZE : %s
        MEAN : %s
        STD : %s
        IMAGE_TYPE : "%s"
        NUM_CLASSES : %d
        CHANNELS : %d
        PRE_PROCESSOR : "SegPreProcessor"
        PREDICTOR_MODE : "ANALYSIS"
        BATCH_SIZE : 1
    ''' % (cfg.FREEZE.SAVE_DIR, cfg.FREEZE.MODEL_FILENAME,
           cfg.FREEZE.PARAMS_FILENAME, cfg.EVAL_CROP_SIZE, cfg.MEAN, cfg.STD,
           cfg.DATASET.IMAGE_TYPE, cfg.DATASET.NUM_CLASSES, len(cfg.STD))
    if not os.path.exists(cfg.FREEZE.SAVE_DIR):
        os.makedirs(cfg.FREEZE.SAVE_DIR)
    yaml_path = os.path.join(cfg.FREEZE.SAVE_DIR, 'deploy.yaml')
    with open(yaml_path, "w") as fp:
        fp.write(deploy_cfg)
    return yaml_path


def export_inference_model(args):
    """
    Export a PaddlePaddle inference model for prediction deployment and serving.
    """
    print("Exporting inference model...")
    startup_prog = fluid.Program()
    infer_prog = fluid.Program()
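    # PREDICT phase builds the inference graph only: `image` is the network
    # input variable and `logit_out` the prediction output to export.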
    image, logit_out = build_model(
        infer_prog, startup_prog, phase=ModelPhase.PREDICT)

    # Use CPU for exporting inference model instead of GPU
    place = fluid.CPUPlace()
    exe = fluid.Executor(place)
    exe.run(startup_prog)
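    # Clone in test mode so layers such as batch norm and dropout run with
    # inference behavior in the exported graph.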
    infer_prog = infer_prog.clone(for_test=True)

    if os.path.exists(cfg.TEST.TEST_MODEL):
        print('Loading test model:', cfg.TEST.TEST_MODEL)
        try:
            # Prefer the checkpoint format written by fluid.save()
            fluid.load(infer_prog, os.path.join(cfg.TEST.TEST_MODEL, 'model'),
                       exe)
        except Exception:
            # Fall back to per-variable parameter files from older checkpoints
            fluid.io.load_params(
                exe, cfg.TEST.TEST_MODEL, main_program=infer_prog)
    else:
        print("TEST.TEST_MODEL directory does not exist!")
        sys.exit(-1)

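    # save_inference_model prunes the program to the given feed/fetch
    # variables and serializes the pruned graph and weights under
    # FREEZE.SAVE_DIR using the configured file names.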
    fluid.io.save_inference_model(
        cfg.FREEZE.SAVE_DIR,
        feeded_var_names=[image.name],
        target_vars=[logit_out],
        executor=exe,
        main_program=infer_prog,
        model_filename=cfg.FREEZE.MODEL_FILENAME,
        params_filename=cfg.FREEZE.PARAMS_FILENAME)
    print("Inference model exported!")
    print("Exporting inference model config...")
    deploy_cfg_path = export_inference_config()
    print("Inference model saved : [%s]" % (deploy_cfg_path))


def main():
    args = parse_args()
    if args.cfg_file is not None:
        cfg.update_from_file(args.cfg_file)
    if args.opts:
        cfg.update_from_list(args.opts)
    cfg.check_and_infer()
    print(pprint.pformat(cfg))
    export_inference_model(args)


if __name__ == '__main__':
    main()
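
# Example invocation (config and model paths are illustrative, not fixed
# by this script; any key from utils/config.py can follow as an opt):
#   python export_model.py --cfg configs/unet_optic.yaml \
#       TEST.TEST_MODEL saved_model/unet_optic/final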