#  Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.

import os
import sys
import time
import logging
import argparse
import ast
try:
    import cPickle as pickle
except ImportError:
    # Python 3 has no cPickle module; fall back to the standard pickle.
    import pickle

import numpy as np
import paddle.fluid as fluid

from utils.config_utils import *
import models
from reader import get_reader
from metrics import get_metrics
from utils.utility import check_cuda

# Drop any handlers already attached to the root logger so that the
# basicConfig call below takes effect (basicConfig does nothing when the
# root logger already has handlers).
logging.root.handlers = []
FORMAT = '[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s'
logging.basicConfig(level=logging.DEBUG, format=FORMAT, stream=sys.stdout)
logger = logging.getLogger(__name__)


def parse_args(argv=None):
    """Parse the command-line options for exporting an inference model.

    Args:
        argv: list of argument strings to parse. ``None`` (the default)
            parses ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace with the attributes ``model_name``, ``config``,
        ``use_gpu``, ``weights``, ``batch_size`` and ``save_dir``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--model_name',
        type=str,
        default='AttentionCluster',
        help='name of model to export as an inference model.')
    parser.add_argument(
        '--config',
        type=str,
        default='configs/attention_cluster.txt',
        help='path to config file of model')
    parser.add_argument(
        '--use_gpu',
        # literal_eval turns the strings "True"/"False" into real booleans;
        # type=bool would treat any non-empty string (even "False") as True.
        type=ast.literal_eval,
        default=True,
        help='default use gpu.')
    parser.add_argument(
        '--weights',
        type=str,
        default=None,
        help='weight path, None to automatically download weights provided by Paddle.'
    )
    parser.add_argument(
        '--batch_size',
        type=int,
        default=1,
        help='sample number in a batch for inference.')
    parser.add_argument(
        '--save_dir',
        type=str,
        default='./',
        help='directory to store model and params file')
    args = parser.parse_args(argv)
    return args


def save_inference_model(args):
    """Build ``args.model_name`` in inference mode, load weights and export it.

    The exported program and parameters are written into ``args.save_dir``
    as ``<model_name>_model.pdmodel`` and ``<model_name>_params.pdparams``.
    The directory is created if it does not exist yet.

    Args:
        args: parsed command-line options as returned by ``parse_args``.
            The attributes ``config``, ``model_name``, ``use_gpu``,
            ``weights`` and ``save_dir`` are read here. ``vars(args)`` is
            also passed to ``merge_configs`` together with the ``'infer'``
            mode.

    Raises:
        ValueError: if ``args.weights`` is given but the path does not exist.
    """
    # Fail fast on a bad weight path, before the graph is built. Use an
    # explicit raise instead of ``assert``, which ``python -O`` strips out.
    if args.weights and not os.path.exists(args.weights):
        raise ValueError("Given weight dir {} not exist.".format(args.weights))

    # parse config
    config = parse_config(args.config)
    infer_config = merge_configs(config, 'infer', vars(args))
    print_configs(infer_config, "Infer")

    infer_model = models.get_model(args.model_name, infer_config, mode='infer')
    infer_model.build_input(use_pyreader=False)
    infer_model.build_model()
    infer_feeds = infer_model.feeds()
    infer_outputs = infer_model.outputs()

    place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())

    # if no weight files specified, download weights from paddle
    weights = args.weights or infer_model.get_weights()
    infer_model.load_test_weights(exe, weights,
                                  fluid.default_main_program(), place)

    if not os.path.isdir(args.save_dir):
        os.makedirs(args.save_dir)

    # saving inference model
    fluid.io.save_inference_model(
        args.save_dir,
        feeded_var_names=[item.name for item in infer_feeds],
        target_vars=infer_outputs,
        executor=exe,
        main_program=fluid.default_main_program(),
        model_filename=args.model_name + "_model.pdmodel",
        params_filename=args.model_name + "_params.pdparams")

    print("save inference model at %s" % (args.save_dir))


if __name__ == "__main__":
    args = parse_args()
    # check whether the installed paddle is compiled with GPU
    check_cuda(args.use_gpu)
    logger.info(args)
    save_inference_model(args)