#  Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.

import os
import sys
import time
import logging
import argparse
import numpy as np
import paddle.fluid as fluid

from config import *
import models
from datareader import get_reader
from metrics import get_metrics

# logging.basicConfig() does nothing when the root logger already has
# handlers, so drop any that were installed earlier (e.g. during imports)
# before configuring a single INFO-level handler that writes to stdout.
logging.root.handlers = []
FORMAT = '[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT, stream=sys.stdout)
logger = logging.getLogger(__name__)


def _str2bool(value):
    """Convert a command-line string such as 'True', 'false' or '1' to a bool.

    argparse's ``type=bool`` maps every non-empty string (including 'False'
    and '0') to True, so boolean options need an explicit converter.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(
        'boolean value expected, got {!r}'.format(value))


def parse_args(argv=None):
    """Parse the command-line options of the evaluation script.

    Args:
        argv: optional list of argument strings. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, as before.

    Returns:
        argparse.Namespace with ``model_name``, ``config``, ``batch_size``,
        ``use_gpu``, ``weights`` and ``log_interval``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--model_name',
        type=str,
        default='AttentionCluster',
        help='name of model to test.')
    parser.add_argument(
        '--config',
        type=str,
        default='configs/attention_cluster.txt',
        help='path to config file of model')
    parser.add_argument(
        '--batch_size',
        type=int,
        default=None,
        help='test batch size. None to use config file setting.')
    parser.add_argument(
        '--use_gpu', type=_str2bool, default=True, help='default use gpu.')
    parser.add_argument(
        '--weights',
        type=str,
        default=None,
        help='weight path, None to use weights from Paddle.')
    parser.add_argument(
        '--log_interval',
        type=int,
        default=1,
        help='mini-batch interval to log.')
    return parser.parse_args(argv)


def _build_fetch_list(is_ctcn, test_loss, test_outputs, test_feeds):
    """Return the names of the variables fetched on every test batch.

    Order: loss variable name(s) (none when the model has no loss), then all
    model outputs, then the last feed variable, so the label (``fid`` for
    CTCN) is always at index -1 of the fetched results.
    """
    output_names = [x.name for x in test_outputs]
    label_names = [test_feeds[-1].name]
    if is_ctcn:
        # CTCN's loss() yields several loss variables.
        loss_names = [x.name for x in test_loss]
    elif test_loss is None:
        loss_names = []
    else:
        loss_names = [test_loss.name]
    return loss_names + output_names + label_names


def _unpack_test_outs(is_ctcn, test_loss, test_outs):
    """Split one batch of fetched values into ``(loss, pred, label)``.

    ``test_outs`` must follow the order produced by ``_build_fetch_list``.
    """
    if is_ctcn:
        # [total_loss, loc_loss, cls_loss, loc_preds, cls_preds, ..., fid]
        loss = [test_outs[0], test_outs[1], test_outs[2]]
        pred = [test_outs[3], test_outs[4]]
        return loss, pred, test_outs[-1]
    label = np.array(test_outs[-1])
    if test_loss is None:
        # No loss variable was fetched; hand the metrics a dummy zero loss.
        return np.zeros(1, dtype='float32'), np.array(test_outs[0]), label
    return np.array(test_outs[0]), np.array(test_outs[1]), label


def test(args):
    """Build the model in test mode, run it over the test reader and log metrics.

    Args:
        args: namespace from ``parse_args()``. Reads ``config``,
            ``model_name``, ``use_gpu``, ``weights`` and ``log_interval``,
            and passes ``vars(args)`` to ``merge_configs``.
    """
    # parse config
    config = parse_config(args.config)
    test_config = merge_configs(config, 'test', vars(args))
    print_configs(test_config, "Test")

    # build model
    test_model = models.get_model(args.model_name, test_config, mode='test')
    test_model.build_input(use_pyreader=False)
    test_model.build_model()
    test_feeds = test_model.feeds()
    test_outputs = test_model.outputs()
    test_loss = test_model.loss()

    place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)

    if args.weights:
        assert os.path.exists(
            args.weights), "Given weight dir {} not exist.".format(args.weights)
    weights = args.weights or test_model.get_weights()

    test_model.load_test_weights(exe, weights,
                                 fluid.default_main_program(), place)

    # get reader and metrics
    model_name = args.model_name.upper()
    is_ctcn = model_name == 'CTCN'
    test_reader = get_reader(model_name, 'test', test_config)
    test_metrics = get_metrics(model_name, 'test', test_config)

    test_feeder = fluid.DataFeeder(place=place, feed_list=test_feeds)
    fetch_list = _build_fetch_list(is_ctcn, test_loss, test_outputs,
                                   test_feeds)

    for test_iter, data in enumerate(test_reader()):
        test_outs = exe.run(fetch_list=fetch_list, feed=test_feeder.feed(data))
        loss, pred, label = _unpack_test_outs(is_ctcn, test_loss, test_outs)
        test_metrics.accumulate(loss, pred, label)

        # log running metrics every `log_interval` batches
        if args.log_interval > 0 and test_iter % args.log_interval == 0:
            info_str = '[EVAL] Batch {}'.format(test_iter)
            test_metrics.calculate_and_log_out(loss, pred, label, info_str)
    test_metrics.finalize_and_log_out("[EVAL] eval finished. ")


if __name__ == "__main__":
    # Script entry point: parse CLI options, echo them, then run evaluation.
    args = parse_args()
    logger.info(args)
    test(args)