#  Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.

import os
import sys
import time
import logging
import argparse
20
import ast
21 22 23 24 25 26 27
import numpy as np
try:
    import cPickle as pickle
except:
    import pickle
import paddle.fluid as fluid

28
from utils.config_utils import *
29
import models
30 31 32
from reader import get_reader
from metrics import get_metrics
from utils.utility import check_cuda
33
from utils.utility import check_version
34

D
dengkaipeng 已提交
35
# Drop any handlers already attached to the root logger: logging.basicConfig
# is a no-op when the root logger already has handlers, so without this the
# format/stream below might never be applied.
logging.root.handlers = []
FORMAT = '[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s'
logging.basicConfig(level=logging.DEBUG, format=FORMAT, stream=sys.stdout)
logger = logging.getLogger(__name__)


def parse_args(argv=None):
    """Parse the command-line options of the inference script.

    Args:
        argv: optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is what the
            script does when run from the command line.

    Returns:
        argparse.Namespace with the parsed options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--model_name',
        type=str,
        default='AttentionCluster',
        help='name of model to run inference with.')
    parser.add_argument(
        '--config',
        type=str,
        default='configs/attention_cluster.txt',
        help='path to config file of model')
    parser.add_argument(
        '--use_gpu',
        # literal_eval turns "False"/"True" into real booleans; type=bool
        # would treat any non-empty string (including "False") as True.
        type=ast.literal_eval,
        default=True,
        help='default use gpu.')
    parser.add_argument(
        '--weights',
        type=str,
        default=None,
        help='weight path, None to automatically download weights provided by Paddle.'
    )
    parser.add_argument(
        '--batch_size',
        type=int,
        default=1,
        help='sample number in a batch for inference.')
    parser.add_argument(
        '--filelist',
        type=str,
        default=None,
        help='path to inference data file lists file.')
    parser.add_argument(
        '--log_interval',
        type=int,
        default=1,
        help='mini-batch interval to log; <= 0 disables progress logging.')
    parser.add_argument(
        '--infer_topk',
        type=int,
        default=20,
        help='topk predictions to restore.')
    parser.add_argument(
        '--save_dir',
        type=str,
        default=os.path.join('data', 'predict_results'),
        help='directory to store results')
    parser.add_argument(
        '--video_path',
        type=str,
        default=None,
        help='path to a single video to run inference on.')
    args = parser.parse_args(argv)
    return args


D
dengkaipeng 已提交
98 99 100 101
def infer(args):
    """Run inference for ``args.model_name`` and save the metric results.

    Builds the inference program, validates the input paths, loads weights
    (falling back to ``infer_model.get_weights()`` when ``args.weights`` is
    not given), feeds every batch from the inference reader through the
    executor, accumulates the outputs into the inference metrics and finally
    writes them into ``args.save_dir``.

    Args:
        args: namespace produced by ``parse_args``.

    Raises:
        AssertionError: if the video path / file list or the given weight
            path does not exist.
    """
    # parse config
    config = parse_config(args.config)
    infer_config = merge_configs(config, 'infer', vars(args))
    print_configs(infer_config, "Infer")
    infer_model = models.get_model(args.model_name, infer_config, mode='infer')
    infer_model.build_input(use_dataloader=False)
    infer_model.build_model()
    infer_feeds = infer_model.feeds()
    infer_outputs = infer_model.outputs()

    place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)

    exe.run(fluid.default_startup_program())

    filelist = args.filelist or infer_config.INFER.filelist
    filepath = args.video_path or infer_config.INFER.get('filepath', '')
    # When a single video path is set only it is validated; otherwise the
    # file list must exist. Truthiness (rather than != '') also treats a
    # None value from the config as "not set".
    if filepath:
        assert os.path.exists(filepath), "{} not exist.".format(filepath)
    else:
        # Check for a missing file list first: os.path.exists(None) raises
        # an opaque TypeError instead of this readable assertion.
        assert filelist and os.path.exists(filelist), \
            "{} not exist.".format(filelist)

    # get infer reader
    infer_reader = get_reader(args.model_name.upper(), 'infer', infer_config)

    if args.weights:
        assert os.path.exists(
            args.weights), "Given weight dir {} not exist.".format(args.weights)
    # if no weight files specified, download weights from paddle
    weights = args.weights or infer_model.get_weights()

    infer_model.load_test_weights(exe, weights,
                                  fluid.default_main_program(), place)

    infer_feeder = fluid.DataFeeder(place=place, feed_list=infer_feeds)
    fetch_list = infer_model.fetches()

    infer_metrics = get_metrics(args.model_name.upper(), 'infer', infer_config)
    infer_metrics.reset()

    periods = []
    num_samples = 0
    cur_time = time.time()
    for infer_iter, data in enumerate(infer_reader()):
        if args.model_name == 'ETS':
            # ETS samples: the first 3 fields are fed to the network, the
            # remaining fields are video info whose first entry is the id.
            # Outputs are fetched without numpy conversion.
            data_feed_in = [items[:3] for items in data]
            vinfo = [items[3:] for items in data]
            video_id = [items[0] for items in vinfo]
            infer_outs = exe.run(fetch_list=fetch_list,
                                 feed=infer_feeder.feed(data_feed_in),
                                 return_numpy=False)
            infer_result_list = infer_outs + vinfo
        else:
            # Other models: the last field of each sample is its video id.
            data_feed_in = [items[:-1] for items in data]
            video_id = [items[-1] for items in data]
            infer_outs = exe.run(fetch_list=fetch_list,
                                 feed=infer_feeder.feed(data_feed_in))
            infer_result_list = list(infer_outs) + [video_id]

        # Wall time since the previous batch (includes reader time).
        prev_time = cur_time
        cur_time = time.time()
        periods.append(cur_time - prev_time)

        infer_metrics.accumulate(infer_result_list)

        # Keep a running total so the count stays correct even when batches
        # differ in size (e.g. a partial final batch); the former
        # (infer_iter + 1) * len(video_id) assumed every batch matched the
        # current one.
        num_samples += len(video_id)
        if args.log_interval > 0 and infer_iter % args.log_interval == 0:
            logger.info('Processed {} samples'.format(num_samples))

    if periods:
        logger.info('[INFER] infer finished. average time: {}'.format(
            np.mean(periods)))
    else:
        # np.mean([]) would emit a RuntimeWarning and report nan.
        logger.warning('[INFER] infer finished. no data was read.')

    if not os.path.isdir(args.save_dir):
        os.makedirs(args.save_dir)

    infer_metrics.finalize_and_log_out(savedir=args.save_dir)
175

S
SunGaofeng 已提交
176

177 178
if __name__ == "__main__":
    # Script entry point: parse CLI options, run environment checks, then
    # run inference.
    args = parse_args()
    # check whether the installed paddle is compiled with GPU
    check_cuda(args.use_gpu)
    # presumably validates the installed paddle version (see utils.utility)
    check_version()
    logger.info(args)

    infer(args)