# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function

import argparse
import functools

import paddle

from paddle.static import InputSpec as Input
from paddle.vision.transforms import BatchCompose

from utility import add_arguments, print_arguments
from utility import SeqAccuracy, LoggerCallBack, SeqBeamAccuracy
from utility import postprocess
from seq2seq_attn import Seq2SeqAttModel, Seq2SeqAttInferModel, WeightCrossEntropy
import data

# Command-line flags for evaluation. ``add_arg(name, type, default, help)``
# registers one ``--name`` option on ``parser`` via utility.add_arguments.
# NOTE(review): argparse's ``type=bool`` would treat the string "False" as
# True; this presumably relies on utility.add_arguments converting bool
# flags from strings -- confirm before relying on ``--use_gpu False``.
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('batch_size',        int,   32,                 "Minibatch size.")
add_arg('test_images',       str,   None,               "The directory of images to be used for test.")
add_arg('test_list',         str,   None,               "The list file of images to be used for test.")
add_arg('init_model',        str,   'checkpoint/final', "The init model file or directory.")
add_arg('use_gpu',           bool,  True,               "Whether to use GPU for evaluation.")
add_arg('encoder_size',      int,   200,                "Encoder size.")
add_arg('decoder_size',      int,   128,                "Decoder size.")
add_arg('embedding_dim',     int,   128,                "Word embedding dimension.")
add_arg('num_classes',       int,   95,                 "Number of classes.")
add_arg('beam_size',         int,   0,                  "If non-zero, evaluate with beam search using this beam size.")
add_arg('dynamic',           bool,  False,              "Whether to use dygraph.")
# yapf: enable


def main(FLAGS):
    """Evaluate ``Seq2SeqAttModel`` on the test set (no beam search).

    Builds the model from the size flags, loads the weights stored at
    ``FLAGS.init_model`` and runs ``paddle.Model.evaluate`` with the
    ``WeightCrossEntropy`` loss and the ``SeqAccuracy`` metric.

    Args:
        FLAGS: parsed command-line arguments; reads ``use_gpu``, ``dynamic``,
            ``encoder_size``, ``decoder_size``, ``embedding_dim``,
            ``num_classes``, ``init_model`` and ``batch_size``.
    """
    device = paddle.set_device("gpu" if FLAGS.use_gpu else "cpu")
    # Only switch to dygraph mode when explicitly requested; otherwise the
    # framework's current mode is left untouched.
    if FLAGS.dynamic:
        paddle.disable_static(device)

    # Input specs: "pixel" is presumably a single-channel 48x384 image batch
    # and label_in / label_out / mask per-step target tensors -- confirm
    # against the transforms in data.py.
    # yapf: disable
    inputs = [
        Input([None, 1, 48, 384], "float32", name="pixel"),
        Input([None, None], "int64", name="label_in")
    ]
    labels = [
        Input([None, None], "int64", name="label_out"),
        Input([None, None], "float32", name="mask")
    ]
    # yapf: enable

    model = paddle.Model(
        Seq2SeqAttModel(
            encoder_size=FLAGS.encoder_size,
            decoder_size=FLAGS.decoder_size,
            emb_dim=FLAGS.embedding_dim,
            num_classes=FLAGS.num_classes),
        inputs=inputs,
        labels=labels)

    model.prepare(loss=WeightCrossEntropy(), metrics=SeqAccuracy())
    model.load(FLAGS.init_model)

    test_dataset = data.test()
    test_collate_fn = BatchCompose(
        [data.Resize(), data.Normalize(), data.PadTarget()])
    # No shuffling and no dropped tail batch, so every test sample is
    # evaluated in a fixed order (assuming data.BatchSampler honours these).
    test_sampler = data.BatchSampler(
        test_dataset,
        batch_size=FLAGS.batch_size,
        drop_last=False,
        shuffle=False)
    test_loader = paddle.io.DataLoader(
        test_dataset,
        batch_sampler=test_sampler,
        places=device,
        num_workers=0,
        return_list=True,
        collate_fn=test_collate_fn)

    # NOTE(review): the meaning of the positional args (10, 2) is defined by
    # utility.LoggerCallBack -- presumably log frequency / verbosity; confirm.
    model.evaluate(
        eval_data=test_loader,
        callbacks=[LoggerCallBack(10, 2, FLAGS.batch_size)])


def beam_search(FLAGS):
    """Evaluate ``Seq2SeqAttInferModel`` with beam search on the test set.

    Builds the inference model with ``FLAGS.beam_size`` beams, loads the
    weights stored at ``FLAGS.init_model`` and runs ``paddle.Model.evaluate``
    scored only by the ``SeqBeamAccuracy`` metric (no loss is configured).

    Args:
        FLAGS: parsed command-line arguments; reads ``use_gpu``, ``dynamic``,
            ``encoder_size``, ``decoder_size``, ``embedding_dim``,
            ``num_classes``, ``beam_size``, ``init_model`` and ``batch_size``.
    """
    # ``set_device`` must be qualified: only the ``paddle`` module is imported.
    device = paddle.set_device("gpu" if FLAGS.use_gpu else "cpu")
    # Only switch to dygraph mode when explicitly requested.
    if FLAGS.dynamic:
        paddle.disable_static(device)

    # Same input/label specs as ``main``; the labels are presumably consumed
    # by SeqBeamAccuracy -- confirm in utility.py.
    # yapf: disable
    inputs = [
        Input([None, 1, 48, 384], "float32", name="pixel"),
        Input([None, None], "int64", name="label_in")
    ]
    labels = [
        Input([None, None], "int64", name="label_out"),
        Input([None, None], "float32", name="mask")
    ]
    # yapf: enable

    model = paddle.Model(
        Seq2SeqAttInferModel(
            encoder_size=FLAGS.encoder_size,
            decoder_size=FLAGS.decoder_size,
            emb_dim=FLAGS.embedding_dim,
            num_classes=FLAGS.num_classes,
            beam_size=FLAGS.beam_size),
        inputs=inputs,
        labels=labels)

    # ``paddle.Model.prepare`` takes ``loss=`` (as used in ``main``);
    # ``loss_function`` is not an accepted keyword. No loss is needed here.
    model.prepare(loss=None, metrics=SeqBeamAccuracy())
    model.load(FLAGS.init_model)

    test_dataset = data.test()
    test_collate_fn = BatchCompose(
        [data.Resize(), data.Normalize(), data.PadTarget()])
    # No shuffling and no dropped tail batch, so every test sample is
    # evaluated in a fixed order (assuming data.BatchSampler honours these).
    test_sampler = data.BatchSampler(
        test_dataset,
        batch_size=FLAGS.batch_size,
        drop_last=False,
        shuffle=False)
    test_loader = paddle.io.DataLoader(
        test_dataset,
        batch_sampler=test_sampler,
        places=device,
        num_workers=0,
        return_list=True,
        collate_fn=test_collate_fn)

    # NOTE(review): the meaning of the positional args (10, 2) is defined by
    # utility.LoggerCallBack -- presumably log frequency / verbosity; confirm.
    model.evaluate(
        eval_data=test_loader,
        callbacks=[LoggerCallBack(10, 2, FLAGS.batch_size)])


if __name__ == '__main__':
    # Parse and echo the flags, then dispatch: a non-zero beam size selects
    # beam-search evaluation, zero selects the plain ``main`` evaluation.
    FLAGS = parser.parse_args()
    print_arguments(FLAGS)
    entry = beam_search if FLAGS.beam_size else main
    entry(FLAGS)