# eval.py (2.4 KB), committed by wanghaoshuang.
import paddle.v2 as paddle
import paddle.fluid as fluid
W
wanghaoshuang 已提交
3 4
from utility import add_arguments, print_arguments, to_lodtensor, get_feeder_data
from crnn_ctc_model import ctc_infer
W
wanghaoshuang 已提交
5 6
from crnn_ctc_model import ctc_eval
import ctc_reader
W
wanghaoshuang 已提交
7 8 9
import argparse
import functools
import os
W
wanghaoshuang 已提交
10

W
wanghaoshuang 已提交
11 12 13 14 15 16
# Command-line interface. `add_arguments` comes from the project's `utility`
# module and is called positionally as (name, type, default, help).
# NOTE(review): presumably each name becomes a `--name` flag and `bool` flags
# are parsed from strings by add_arguments -- confirm in utility.py.
# This module has no docstring, so `__doc__` is None; fall back to an explicit
# description so `--help` is informative (a future docstring still wins).
parser = argparse.ArgumentParser(
    description=__doc__ or "Evaluate a CRNN-CTC OCR model on a test set.")
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('model_path',         str,  None,   "The model path to be used for inference.")
add_arg('input_images_dir',   str,  None,   "The directory of images.")
add_arg('input_images_list',  str,  None,   "The list file of images.")
add_arg('use_gpu',            bool, True,   "Whether use GPU to eval.")
add_arg('use_mkldnn',         bool, False,  "Whether to use MKLDNN to eval.")
# yapf: enable
W
wanghaoshuang 已提交
20 21


W
wanghaoshuang 已提交
22
def evaluate(args, eval=ctc_eval, data_reader=ctc_reader):
    """Evaluate a trained OCR model on the test set and print its metrics.

    Args:
        args: Parsed command-line namespace. Reads ``model_path``,
            ``input_images_dir``, ``input_images_list``, ``use_gpu`` and
            ``use_mkldnn``.
        eval: Network builder called as
            ``eval(images, label, num_classes, use_mkldnn, use_gpu)``; it must
            return an ``(evaluator, cost)`` pair whose evaluator supports
            ``reset(exe)`` and ``eval(exe)``. (This parameter shadows the
            builtin ``eval``; the name is kept for keyword callers.)
        data_reader: Object providing ``num_classes()``, ``data_shape()`` and
            ``test(test_images_dir=..., test_list_file=...)``.

    Raises:
        ValueError: If ``args.model_path`` is empty or not set.
    """
    if not args.model_path:
        # Fail fast: otherwise the network is built and os.path.isdir(None)
        # below raises an unhelpful TypeError.
        raise ValueError("model_path is required for evaluation.")

    num_classes = data_reader.num_classes()
    data_shape = data_reader.data_shape()

    # define network
    images = fluid.layers.data(name='pixel', shape=data_shape, dtype='float32')
    label = fluid.layers.data(
        name='label', shape=[1], dtype='int32', lod_level=1)
    evaluator, _ = eval(images, label, num_classes, args.use_mkldnn,
                        bool(args.use_gpu))

    # data reader
    test_reader = data_reader.test(
        test_images_dir=args.input_images_dir,
        test_list_file=args.input_images_list)

    # prepare environment
    place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())

    # load init model: a directory path is used as-is; anything else is treated
    # as a file path and split into (directory, file name) for load_params.
    model_dir = args.model_path
    model_file_name = None
    if not os.path.isdir(args.model_path):
        model_dir = os.path.dirname(args.model_path)
        model_file_name = os.path.basename(args.model_path)
    fluid.io.load_params(exe, dirname=model_dir, filename=model_file_name)
    print("Init model from: %s." % args.model_path)

    evaluator.reset(exe)
    count = 0
    for data in test_reader():
        # NOTE(review): this counts items yielded by test_reader(); if the
        # reader yields multi-sample batches, the "samples" figure printed
        # below is really a batch count -- confirm against ctc_reader.test.
        count += 1
        exe.run(fluid.default_main_program(), feed=get_feeder_data(data, place))
    avg_distance, avg_seq_error = evaluator.eval(exe)
    print("Read %d samples; avg_distance: %s; avg_seq_error: %s" % (
        count, avg_distance, avg_seq_error))

W
wanghaoshuang 已提交
64 65

def main():
    """Script entry point: parse the CLI flags, echo them, then evaluate."""
    cli_args = parser.parse_args()
    print_arguments(cli_args)
    evaluate(cli_args, data_reader=ctc_reader)


# Run the evaluation only when executed as a script, not on import.
if __name__ == "__main__":
    main()