infer.py 5.0 KB
Newer Older
W
wanghaoshuang 已提交
1
import paddle.v2 as paddle
W
wanghaoshuang 已提交
2
import paddle.fluid as fluid
3
import paddle.fluid.profiler as profiler
W
wanghaoshuang 已提交
4
from utility import add_arguments, print_arguments, to_lodtensor, get_feeder_data
W
wanghaoshuang 已提交
5
from crnn_ctc_model import ctc_infer
W
wanghaoshuang 已提交
6
import numpy as np
W
wanghaoshuang 已提交
7
import ctc_reader
W
wanghaoshuang 已提交
8 9 10
import argparse
import functools
import os
11
import time
W
wanghaoshuang 已提交
12 13 14 15 16 17 18

# Command-line interface for this script. `add_arg` binds
# utility.add_arguments to `parser`; each call below presumably registers a
# `--<name>` option with the given type, default and help text (see
# utility.add_arguments to confirm).
# The module has no docstring, so `__doc__` would leave the help text without
# a description; state it explicitly instead.
parser = argparse.ArgumentParser(description="OCR inference.")
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('model_path',         str,  None,   "The model path to be used for inference.")
add_arg('input_images_dir',   str,  None,   "The directory of images.")
add_arg('input_images_list',  str,  None,   "The list file of images.")
add_arg('dict',               str,  None,   "The dictionary. The result of inference will be index sequence if the dictionary was None.")
add_arg('use_gpu',            bool,  True,      "Whether use GPU to infer.")
add_arg('iterations',         int,  0,      "The number of iterations. Zero or less means whole test set. More than 0 means the test set might be looped until # of iterations is reached.")
add_arg('profile',            bool, False,  "Whether to use profiling.")
add_arg('skip_batch_num',     int,  0,      "The number of first minibatches to skip as warm-up for better performance test.")
add_arg('batch_size',         int,  1,      "The minibatch size.")
# yapf: enable

W
wanghaoshuang 已提交
27 28

def _load_dict(dict_path):
    """Return a {line_index: stripped_line} map read from *dict_path*.

    Returns None when *dict_path* is None or is not an existing file. In that
    case inference results are printed as raw index sequences.
    """
    if dict_path is None:
        return None
    if not os.path.isfile(dict_path):
        # Previously ignored silently; tell the user why results are indexes.
        print("Warning: dict file %s not found; printing raw indexes." %
              dict_path)
        return None
    with open(dict_path) as dict_file:
        dict_map = {i: word.strip() for i, word in enumerate(dict_file)}
    print("Loaded dict from %s" % dict_path)
    return dict_map


def _load_model(exe, model_path):
    """Load model parameters with fluid.io.load_params.

    A directory path is passed as `dirname`. Any other path is split into its
    directory (`dirname`) and base name (`filename`).
    """
    model_dir = model_path
    model_file_name = None
    if not os.path.isdir(model_path):
        model_dir = os.path.dirname(model_path)
        model_file_name = os.path.basename(model_path)
    fluid.io.load_params(exe, dirname=model_dir, filename=model_file_name)
    print("Init model from: %s." % model_path)


def _print_benchmark(batch_times, skip_batch_num, batch_size, iters):
    """Print example count plus average / 99th-percentile latency and fps.

    The first *skip_batch_num* entries of *batch_times* are warm-up batches
    and are excluded from the statistics. If nothing remains, the statistics
    are skipped instead of producing nan values or raising inside numpy.
    """
    latencies = batch_times[skip_batch_num:]
    print('\nTotal examples (incl. warm-up): %d' % (iters * batch_size))
    if not latencies:
        print('No timed iterations after %d warm-up batch(es); '
              'latency/fps statistics skipped.' % skip_batch_num)
        return
    latency_avg = np.average(latencies)
    latency_pc99 = np.percentile(latencies, 99)
    fpses = np.divide(batch_size, latencies)
    fps_avg = np.average(fpses)
    # Low fps corresponds to high latency, so the 1st percentile of fps
    # matches the 99th-percentile latency.
    fps_pc99 = np.percentile(fpses, 1)
    print('average latency: %.5f s, 99pc latency: %.5f s' % (latency_avg,
                                                             latency_pc99))
    print('average fps: %.5f, fps for 99pc latency: %.5f' % (fps_avg, fps_pc99))


def inference(args, infer=ctc_infer, data_reader=ctc_reader):
    """OCR inference.

    Builds the inference network, loads parameters from ``args.model_path``,
    runs it over the images produced by ``data_reader.inference`` and prints
    per-batch results followed by benchmark statistics.

    Args:
        args: parsed command-line namespace. Reads model_path,
            input_images_dir, input_images_list, dict, use_gpu, iterations,
            skip_batch_num and batch_size.
        infer: network builder called as
            ``infer(images, num_classes, use_cudnn=...)``. Its return value
            is fetched after each batch.
        data_reader: module/object providing ``num_classes()``,
            ``data_shape()`` and ``inference(...)``.

    Raises:
        ValueError: if ``args.model_path`` is not set.
    """
    if args.model_path is None:
        # Fail early with a clear message instead of a TypeError from os.path.
        raise ValueError("--model_path is required for inference.")

    num_classes = data_reader.num_classes()
    data_shape = data_reader.data_shape()
    # define network
    images = fluid.layers.data(name='pixel', shape=data_shape, dtype='float32')
    sequence = infer(images, num_classes, use_cudnn=bool(args.use_gpu))

    # Cycle the reader only when a fixed iteration count is requested, so the
    # test set can be looped until that count is reached.
    infer_reader = data_reader.inference(
        batch_size=args.batch_size,
        infer_images_dir=args.input_images_dir,
        infer_list_file=args.input_images_list,
        cycle=args.iterations > 0)

    # prepare environment
    place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())

    dict_map = _load_dict(args.dict)
    _load_model(exe, args.model_path)

    batch_times = []
    iters = 0
    for data in infer_reader():
        if args.iterations > 0 and iters == args.iterations + args.skip_batch_num:
            break
        if iters < args.skip_batch_num:
            print("Warm-up iteration")
        if iters == args.skip_batch_num:
            # Presumably discards profiling records gathered during warm-up.
            profiler.reset_profiler()

        start = time.time()
        result = exe.run(fluid.default_main_program(),
                         feed=get_feeder_data(
                             data, place, need_label=False),
                         fetch_list=[sequence],
                         return_numpy=False)
        batch_time = time.time() - start
        fps = args.batch_size / batch_time
        batch_times.append(batch_time)

        indexes = np.array(result[0]).flatten()
        if dict_map is None:
            decoded = indexes
        else:
            decoded = [dict_map[index] for index in indexes]
        print("Iteration %d, latency: %.5f s, fps: %f, result: %s" % (
            iters, batch_time, fps, decoded))

        iters += 1

    # Benchmark output
    _print_benchmark(batch_times, args.skip_batch_num, args.batch_size, iters)
W
wanghaoshuang 已提交
115 116 117


def main():
    """Parse command-line arguments and run inference.

    With --profile, inference runs inside the CUDA profiler (output to
    cuda_profiler.txt, csv) when --use_gpu is set, and inside the CPU
    profiler (sorted by total time) otherwise.
    """
    args = parser.parse_args()
    print_arguments(args)
    if not args.profile:
        inference(args, data_reader=ctc_reader)
    elif args.use_gpu:
        with profiler.cuda_profiler("cuda_profiler.txt", 'csv'):
            inference(args, data_reader=ctc_reader)
    else:
        with profiler.profiler("CPU", sorted_key='total'):
            inference(args, data_reader=ctc_reader)
W
wanghaoshuang 已提交
129

W
wanghaoshuang 已提交
130

W
wanghaoshuang 已提交
131 132
if __name__ == "__main__":
    # Run only when executed as a script, not when imported as a module.
    main()