#  Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.

import os
import sys
import time
import logging
import argparse
import ast
import numpy as np
import paddle.fluid as fluid

from utils.config_utils import *
import models
from reader import get_reader
from metrics import get_metrics
from utils.utility import check_cuda
from utils.utility import check_version

# Drop any handlers already attached to the root logger: logging.basicConfig
# does nothing when the root logger already has handlers, so clearing them
# first guarantees that the format and stdout stream below take effect.
logging.root.handlers = []
FORMAT = '[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT, stream=sys.stdout)
logger = logging.getLogger(__name__)


def parse_args(argv=None):
    """Parse the command-line options of the evaluation script.

    Args:
        argv: optional list of argument strings to parse instead of the real
            command line (``sys.argv[1:]``). ``None`` (the default) keeps the
            original behaviour of reading the process arguments.

    Returns:
        argparse.Namespace with the fields ``model_name``, ``config``,
        ``batch_size``, ``use_gpu``, ``weights``, ``save_dir`` and
        ``log_interval``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--model_name',
        type=str,
        default='AttentionCluster',
        help='name of model to evaluate.')
    parser.add_argument(
        '--config',
        type=str,
        default='configs/attention_cluster.txt',
        help='path to config file of model')
    parser.add_argument(
        '--batch_size',
        type=int,
        default=None,
        help='test batch size. None to use config file setting.')
    parser.add_argument(
        '--use_gpu',
        # literal_eval turns "True"/"False" into real booleans; type=bool
        # would treat any non-empty string (even "False") as True.
        type=ast.literal_eval,
        default=True,
        help='default use gpu.')
    parser.add_argument(
        '--weights',
        type=str,
        default=None,
        help='weight path, None to automatically download weights provided by Paddle.'
    )
    parser.add_argument(
        '--save_dir',
        type=str,
        default=os.path.join('data', 'evaluate_results'),
        help='output dir path, default to use ./data/evaluate_results')
    parser.add_argument(
        '--log_interval',
        type=int,
        default=1,
        help='mini-batch interval to log.')
    return parser.parse_args(argv)


def _run_test_batch(exe, model_name, use_dali, data, feeder, fetch_list):
    """Run the network on one evaluation mini-batch and return its outputs.

    Args:
        exe: the executor created in ``test``.
        model_name: value of ``--model_name``; selects how ``data`` is fed.
        use_dali: the TEST-section ``use_dali`` flag (only consulted for TSN).
        data: one mini-batch yielded by the test reader.
        feeder: DataFeeder built over the model's feed variables.
        fetch_list: the model's fetch variables.

    Returns:
        The list returned by ``exe.run``. For ETS and TALL, the trailing
        per-sample fields (``vinfo``) that are not fed to the network are
        appended as one extra element, so they reach the metrics together
        with the network outputs.
    """
    if model_name == 'ETS':
        # Only the first 3 fields of each sample are fed to the network.
        feat_data = [items[:3] for items in data]
        vinfo = [items[3:] for items in data]
        outs = exe.run(fetch_list=fetch_list,
                       feed=feeder.feed(feat_data),
                       return_numpy=False)
        outs += [vinfo]
    elif model_name == 'TALL':
        # Only the first 2 fields of each sample are fed to the network.
        feat_data = [items[:2] for items in data]
        vinfo = [items[2:] for items in data]
        outs = exe.run(fetch_list=fetch_list,
                       feed=feeder.feed(feat_data),
                       return_numpy=True)
        outs += [vinfo]
    elif model_name == 'TSN' and use_dali:
        # DALI batches are fed by variable name, bypassing the DataFeeder.
        outs = exe.run(fetch_list=fetch_list,
                       feed={'image': data[0],
                             'label': data[1]})
    else:
        outs = exe.run(fetch_list=fetch_list, feed=feeder.feed(data))
    return outs


def test(args):
    """Evaluate a model on the test split and log/save its metrics.

    Args:
        args: namespace returned by ``parse_args``. ``model_name``,
            ``config``, ``use_gpu``, ``weights``, ``save_dir`` and
            ``log_interval`` are read directly, and the whole namespace is
            merged into the 'test' config.
    """
    # parse config
    config = parse_config(args.config)
    test_config = merge_configs(config, 'test', vars(args))
    print_configs(test_config, "Test")
    use_dali = test_config['TEST'].get('use_dali', False)

    # build model
    test_model = models.get_model(args.model_name, test_config, mode='test')
    test_model.build_input(use_dataloader=False)
    test_model.build_model()
    test_feeds = test_model.feeds()
    test_fetch_list = test_model.fetches()

    place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)

    exe.run(fluid.default_startup_program())

    if args.weights:
        assert os.path.exists(
            args.weights), "Given weight dir {} not exist.".format(args.weights)
    # Without --weights, fall back to the model's own weights (per the
    # --weights help text, these are downloaded automatically).
    weights = args.weights or test_model.get_weights()

    logger.info('load test weights from {}'.format(weights))

    test_model.load_test_weights(exe, weights,
                                 fluid.default_main_program(), place)

    # get reader and metrics
    test_reader = get_reader(args.model_name.upper(), 'test', test_config)
    test_metrics = get_metrics(args.model_name.upper(), 'test', test_config)

    test_feeder = fluid.DataFeeder(place=place, feed_list=test_feeds)

    epoch_period = []
    for test_iter, data in enumerate(test_reader()):
        # Timing covers feeding + exe.run only, not the time spent in the
        # reader producing ``data``.
        cur_time = time.time()
        test_outs = _run_test_batch(exe, args.model_name, use_dali, data,
                                    test_feeder, test_fetch_list)
        epoch_period.append(time.time() - cur_time)
        test_metrics.accumulate(test_outs)

        # metric here
        if args.log_interval > 0 and test_iter % args.log_interval == 0:
            info_str = '[EVAL] Batch {}'.format(test_iter)
            test_metrics.calculate_and_log_out(test_outs, info_str)

    # Report the collected per-batch timings (skipped if the reader was empty).
    if epoch_period:
        logger.info('[EVAL] average time per batch: {:.4f}s'.format(
            np.mean(epoch_period)))

    if not os.path.isdir(args.save_dir):
        os.makedirs(args.save_dir)
    test_metrics.finalize_and_log_out("[EVAL] eval finished. ", args.save_dir)


if __name__ == "__main__":
    args = parse_args()
    # check whether the installed paddle is compiled with GPU
    check_cuda(args.use_gpu)
    # NOTE(review): check_version comes from utils.utility; presumably it
    # validates the installed Paddle version — confirm there.
    check_version()
    logger.info(args)
    test(args)