# infer.py
from __future__ import print_function

import argparse
import functools
import os
import time

import numpy as np
import paddle.fluid as fluid
import paddle.fluid.profiler as profiler

import data_reader
from attention_model import attention_infer
from crnn_ctc_model import ctc_infer
from utility import (add_arguments, print_arguments, to_lodtensor,
                     get_ctc_feeder_data, get_ctc_feeder_for_infer,
                     get_attention_feeder_for_infer)

# Command-line interface for the inference script. `add_arg` is
# `add_arguments` with `argparser` pre-bound to `parser`.
# NOTE(review): the positional values look like (name, type, default, help
# text) -- confirm against utility.add_arguments.
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('model',    str,   "crnn_ctc",           "Which type of network to be used. 'crnn_ctc' or 'attention'")
add_arg('model_path',         str,  None,   "The model path to be used for inference.")
add_arg('input_images_dir',   str,  None,   "The directory of images.")
add_arg('input_images_list',  str,  None,   "The list file of images.")
add_arg('dict',               str,  None,   "The dictionary. The result of inference will be index sequence if the dictionary was None.")
add_arg('use_gpu',            bool,  True,      "Whether use GPU to infer.")
add_arg('iterations',         int,  0,      "The number of iterations. Zero or less means whole test set. More than 0 means the test set might be looped until # of iterations is reached.")
add_arg('profile',            bool, False,  "Whether to use profiling.")
add_arg('skip_batch_num',     int,  0,      "The number of first minibatches to skip as warm-up for better performance test.")
add_arg('batch_size',         int,  1,      "The minibatch size.")
# yapf: enable

def inference(args):
    """Run OCR inference and print per-batch results plus a timing summary.

    Builds the inference network selected by ``args.model``, loads its
    parameters from ``args.model_path``, runs every batch yielded by the
    data reader, and prints the decoded result, latency and fps of each
    iteration. Afterwards prints average / 99th-percentile latency and fps
    computed over the batches that follow the warm-up ones.

    Args:
        args: Parsed command-line namespace. Reads ``model``,
            ``model_path``, ``input_images_dir``, ``input_images_list``,
            ``dict``, ``use_gpu``, ``iterations``, ``skip_batch_num`` and
            ``batch_size``.
    """
    if args.model == "crnn_ctc":
        infer = ctc_infer
        get_feeder_data = get_ctc_feeder_for_infer
    else:
        infer = attention_infer
        get_feeder_data = get_attention_feeder_for_infer
    # Token ids that prune() uses as the start / end markers of a prediction.
    sos = 0
    eos = 1
    num_classes = data_reader.num_classes()
    data_shape = data_reader.data_shape()

    # define network
    images = fluid.layers.data(name='pixel', shape=data_shape, dtype='float32')
    ids = infer(images, num_classes, use_cudnn=bool(args.use_gpu))

    # data reader
    infer_reader = data_reader.inference(
        batch_size=args.batch_size,
        infer_images_dir=args.input_images_dir,
        infer_list_file=args.input_images_list,
        cycle=args.iterations > 0,
        model=args.model)

    # prepare environment
    place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())

    # load dictionary: line number -> stripped line content
    dict_map = None
    if args.dict is not None and os.path.isfile(args.dict):
        with open(args.dict) as dict_file:
            dict_map = {i: word.strip() for i, word in enumerate(dict_file)}
        print("Loaded dict from %s" % args.dict)

    # load init model; model_path may be a directory or a single file
    model_dir = args.model_path
    model_file_name = None
    if not os.path.isdir(args.model_path):
        model_dir = os.path.dirname(args.model_path)
        model_file_name = os.path.basename(args.model_path)
    fluid.io.load_params(exe, dirname=model_dir, filename=model_file_name)
    print("Init model from: %s." % args.model_path)

    batch_times = []
    iters = 0
    # Only meaningful when args.iterations > 0 (see the check in the loop).
    last_iter = args.iterations + args.skip_batch_num
    for data in infer_reader():
        # Stop before building a feed for a batch that will not be run.
        if args.iterations > 0 and iters == last_iter:
            break
        feed_dict = get_feeder_data(data, place)
        if iters < args.skip_batch_num:
            print("Warm-up itaration")
        if iters == args.skip_batch_num:
            # Discard profiling data collected during warm-up.
            profiler.reset_profiler()

        start = time.time()
        result = exe.run(fluid.default_main_program(),
                         feed=feed_dict,
                         fetch_list=[ids],
                         return_numpy=False)
        indexes = prune(np.array(result[0]).flatten(), sos, eos)
        batch_time = time.time() - start
        fps = args.batch_size / batch_time
        batch_times.append(batch_time)

        # Print raw indexes unless a dictionary was loaded to decode them.
        if dict_map is not None:
            decoded = [dict_map[index] for index in indexes]
        else:
            decoded = indexes
        print("Iteration %d, latency: %.5f s, fps: %f, result: %s" % (
            iters,
            batch_time,
            fps,
            decoded, ))

        iters += 1

    # Benchmark output
    print('\nTotal examples (incl. warm-up): %d' % (iters * args.batch_size))
    latencies = batch_times[args.skip_batch_num:]
    if not latencies:
        # Nothing ran after warm-up, so there is nothing to average.
        print('No batches were measured after warm-up; '
              'skipping latency/fps statistics.')
        return

    latency_avg = np.average(latencies)
    latency_pc99 = np.percentile(latencies, 99)
    fpses = np.divide(args.batch_size, latencies)
    fps_avg = np.average(fpses)
    # The slowest 1% of fps values corresponds to the 99th-percentile latency.
    fps_pc99 = np.percentile(fpses, 1)
    print('average latency: %.5f s, 99pc latency: %.5f s' % (latency_avg,
                                                             latency_pc99))
    print('average fps: %.5f, fps for 99pc latency: %.5f' % (fps_avg, fps_pc99))


def prune(words, sos, eos):
    """Strip the start/end markers from a predicted index sequence.

    Returns the slice of ``words`` that begins just after the first
    occurrence of ``sos`` (or at the beginning when ``sos`` is absent) and
    stops before the first occurrence of ``eos`` (or at the end when ``eos``
    is absent).
    """
    first = 0
    if sos in words:
        # nonzero() lists matching positions in ascending order.
        first = np.nonzero(words == sos)[0][0] + 1
    last = np.nonzero(words == eos)[0][0] if eos in words else len(words)
    return words[first:last]


def main():
    """Parse the command-line arguments and run inference.

    When ``args.profile`` is set, inference runs inside a profiler context:
    ``profiler.cuda_profiler`` if ``args.use_gpu`` is set, otherwise
    ``profiler.profiler`` in "CPU" mode.
    """
    args = parser.parse_args()
    print_arguments(args)

    if not args.profile:
        inference(args)
        return

    if args.use_gpu:
        with profiler.cuda_profiler("cuda_profiler.txt", 'csv'):
            inference(args)
    else:
        with profiler.profiler("CPU", sorted_key='total'):
            inference(args)


# Run only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()