predict.py 5.3 KB
Newer Older
D
dengkaipeng 已提交
1
#  Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.

import os
import sys
import time
import logging
import argparse
20
import ast
21 22 23 24 25 26 27
import numpy as np
try:
    import cPickle as pickle
except:
    import pickle
import paddle.fluid as fluid

28
from utils.config_utils import *
29
import models
30 31 32
from reader import get_reader
from metrics import get_metrics
from utils.utility import check_cuda
33

D
dengkaipeng 已提交
34
# Discard any root handlers installed before this point (e.g. by imported
# libraries) so that the basicConfig call below actually takes effect;
# basicConfig does nothing when the root logger already has handlers.
logging.root.handlers = []
FORMAT = '[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s'
logging.basicConfig(
    format=FORMAT,
    level=logging.DEBUG,
    stream=sys.stdout,
)
logger = logging.getLogger(__name__)


def parse_args(argv=None):
    """Parse the command-line options of the inference script.

    Args:
        argv: optional list of argument strings to parse instead of
            ``sys.argv[1:]``. The default ``None`` keeps the original
            behavior of reading the real command line.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--model_name',
        type=str,
        default='AttentionCluster',
        help='name of model to run inference with.')
    parser.add_argument(
        '--config',
        type=str,
        default='configs/attention_cluster.txt',
        help='path to config file of model')
    # ast.literal_eval turns "True"/"False" into real booleans; plain
    # type=bool would treat any non-empty string (even "False") as True.
    parser.add_argument(
        '--use_gpu',
        type=ast.literal_eval,
        default=True,
        help='default use gpu.')
    parser.add_argument(
        '--weights',
        type=str,
        default=None,
        help='weight path, None to automatically download weights provided by Paddle.'
    )
    parser.add_argument(
        '--batch_size',
        type=int,
        default=1,
        help='sample number in a batch for inference.')
    parser.add_argument(
        '--filelist',
        type=str,
        default=None,
        help='path to inference data file lists file.')
    parser.add_argument(
        '--log_interval',
        type=int,
        default=1,
        help='mini-batch interval to log.')
    parser.add_argument(
        '--infer_topk',
        type=int,
        default=20,
        help='topk predictions to restore.')
    parser.add_argument(
        '--save_dir',
        type=str,
        default=os.path.join('data', 'predict_results'),
        help='directory to store results')
    parser.add_argument(
        '--video_path',
        type=str,
        default=None,
        help='path to the video to run inference on '
        '(takes precedence over INFER.filepath in the config)')
    args = parser.parse_args(argv)
    return args


D
dengkaipeng 已提交
97 98 99 100
def infer(args):
    """Run inference with a video model and save the collected results.

    Builds the model named by ``args.model_name`` in 'infer' mode, loads
    weights (``args.weights``, or the Paddle-provided weights returned by
    ``get_weights()`` when none are given), feeds every batch from the infer
    reader through the executor, accumulates the fetched outputs together
    with the video ids into the infer metrics, and finally asks the metrics
    object to write its results under ``args.save_dir``.

    Args:
        args: argparse.Namespace produced by ``parse_args``.

    Raises:
        IOError: if the input video / file list, or the explicitly given
            weights path, does not exist.
    """
    # parse config and merge the command-line args into its 'infer' section
    # (merge semantics are defined in utils.config_utils)
    config = parse_config(args.config)
    infer_config = merge_configs(config, 'infer', vars(args))
    print_configs(infer_config, "Infer")
    infer_model = models.get_model(args.model_name, infer_config, mode='infer')
    infer_model.build_input(use_pyreader=False)
    infer_model.build_model()
    infer_feeds = infer_model.feeds()
    infer_outputs = infer_model.outputs()

    place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)

    # When a video path is set (CLI or config) only it is validated;
    # otherwise a file list must be configured and exist. A None/empty value
    # is treated as "not set" instead of crashing inside os.path.exists.
    filelist = args.filelist or infer_config.INFER.filelist
    filepath = args.video_path or infer_config.INFER.get('filepath', '')
    if filepath:
        if not os.path.exists(filepath):
            raise IOError("{} not exist.".format(filepath))
    elif not filelist or not os.path.exists(filelist):
        raise IOError("{} not exist.".format(filelist))

    # get infer reader
    infer_reader = get_reader(args.model_name.upper(), 'infer', infer_config)

    if args.weights and not os.path.exists(args.weights):
        raise IOError("Given weight dir {} not exist.".format(args.weights))
    # if no weight files specified, download weights from paddle
    weights = args.weights or infer_model.get_weights()

    infer_model.load_test_weights(exe, weights,
                                  fluid.default_main_program(), place)

    infer_feeder = fluid.DataFeeder(place=place, feed_list=infer_feeds)
    fetch_list = infer_model.fetches()

    infer_metrics = get_metrics(args.model_name.upper(), 'infer', infer_config)
    infer_metrics.reset()

    periods = []
    num_samples = 0
    cur_time = time.time()
    for infer_iter, data in enumerate(infer_reader()):
        # The last element of every sample is its video id; everything
        # before it is fed to the network.
        data_feed_in = [items[:-1] for items in data]
        video_id = [items[-1] for items in data]
        infer_outs = exe.run(fetch_list=fetch_list,
                             feed=infer_feeder.feed(data_feed_in))
        # Wall-clock time since the previous batch finished, so it also
        # includes reading this batch, not only the forward pass.
        prev_time, cur_time = cur_time, time.time()
        periods.append(cur_time - prev_time)

        infer_metrics.accumulate(list(infer_outs) + [video_id])

        # Keep a running total: batches may differ in size (e.g. a short
        # final batch), so (infer_iter + 1) * len(video_id) can be wrong.
        num_samples += len(video_id)
        if args.log_interval > 0 and infer_iter % args.log_interval == 0:
            logger.info('Processed {} samples'.format(num_samples))

    if periods:
        logger.info('[INFER] infer finished. average time: {}'.format(
            np.mean(periods)))
    else:
        # np.mean([]) would log "nan" and emit a RuntimeWarning.
        logger.warning('[INFER] infer finished. no samples were read.')

    if not os.path.isdir(args.save_dir):
        os.makedirs(args.save_dir)

    infer_metrics.finalize_and_log_out(savedir=args.save_dir)
162

S
SunGaofeng 已提交
163

164 165
if __name__ == "__main__":
    args = parse_args()
    # Verify that the installed paddle build supports GPU when --use_gpu
    # is requested (see utils.utility.check_cuda).
    check_cuda(args.use_gpu)
    logger.info(args)
    infer(args)