eval.py 2.3 KB
Newer Older
W
wanghaoshuang 已提交
1 2
import paddle.v2 as paddle
import paddle.fluid as fluid
W
wanghaoshuang 已提交
3 4
from utility import add_arguments, print_arguments, to_lodtensor, get_feeder_data
from crnn_ctc_model import ctc_infer
W
wanghaoshuang 已提交
5 6
from crnn_ctc_model import ctc_eval
import ctc_reader
W
wanghaoshuang 已提交
7 8 9
import argparse
import functools
import os
W
wanghaoshuang 已提交
10

W
wanghaoshuang 已提交
11 12 13 14 15 16
# NOTE(review): no module docstring is visible in this file, so `__doc__` is
# probably None and the parser gets no description -- confirm.
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)

# Command-line options, registered in this order as
# (name, type, default, help) via the project's `add_arguments` helper.
# yapf: disable
for _arg_spec in (
    ('model_path',        str,  None, "The model path to be used for inference."),
    ('input_images_dir',  str,  None, "The directory of images."),
    ('input_images_list', str,  None, "The list file of images."),
    ('use_gpu',           bool, True, "Whether use GPU to eval."),
):
    add_arg(*_arg_spec)
# yapf: enable
del _arg_spec
W
wanghaoshuang 已提交
19 20


W
wanghaoshuang 已提交
21
def evaluate(args, eval=ctc_eval, data_reader=ctc_reader):
    """Evaluate a trained OCR (CRNN-CTC) model on a test image set.

    Builds the evaluation network, loads parameters from ``args.model_path``,
    feeds every item produced by the test reader through the default main
    program, and prints the metrics returned by the evaluator.

    Args:
        args: Parsed command-line namespace. Reads ``model_path``,
            ``input_images_dir``, ``input_images_list`` and ``use_gpu``.
        eval: Callable ``eval(images, label, num_classes)`` returning an
            ``(evaluator, cost)`` pair. The name shadows the builtin ``eval``
            but is kept so existing keyword callers keep working.
        data_reader: Object providing ``num_classes()``, ``data_shape()`` and
            ``test(test_images_dir=..., test_list_file=...)``.
    """
    num_classes = data_reader.num_classes()
    data_shape = data_reader.data_shape()

    # Define network.
    images = fluid.layers.data(name='pixel', shape=data_shape, dtype='float32')
    label = fluid.layers.data(
        name='label', shape=[1], dtype='int32', lod_level=1)
    # Only the evaluator is needed here; the cost output is unused.
    evaluator, _ = eval(images, label, num_classes)

    # Data reader.
    test_reader = data_reader.test(
        test_images_dir=args.input_images_dir,
        test_list_file=args.input_images_list)

    # Prepare environment. `use_gpu` is a command-line flag, so it must be
    # read from `args` (a bare `use_gpu` name is not defined anywhere).
    place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())

    # Load init model. A directory is passed to load_params as-is; any other
    # path is split into its directory and file name.
    model_dir = args.model_path
    model_file_name = None
    if not os.path.isdir(args.model_path):
        model_dir = os.path.dirname(args.model_path)
        model_file_name = os.path.basename(args.model_path)
    fluid.io.load_params(exe, dirname=model_dir, filename=model_file_name)
    # Single-argument print(...) gives identical output on Python 2 and 3.
    print("Init model from: %s." % args.model_path)

    evaluator.reset(exe)
    # NOTE(review): `count` is incremented once per item yielded by
    # test_reader(); if the reader yields batches, this counts batches rather
    # than samples -- confirm against ctc_reader.test.
    count = 0
    for data in test_reader():
        count += 1
        exe.run(fluid.default_main_program(), feed=get_feeder_data(data, place))
    avg_distance, avg_seq_error = evaluator.eval(exe)
    print("Read %d samples; avg_distance: %s; avg_seq_error: %s" % (
        count, avg_distance, avg_seq_error))

W
wanghaoshuang 已提交
62 63

def main():
    """Script entry point: parse and echo the CLI flags, then run evaluation."""
    cli_args = parser.parse_args()
    print_arguments(cli_args)
    evaluate(cli_args, data_reader=ctc_reader)


# Run only when executed as a script, not when imported.
if __name__ == "__main__":
    main()