from __future__ import print_function
import paddle.v2 as paddle
import paddle.fluid as fluid
from utility import add_arguments, print_arguments, to_lodtensor, get_ctc_feeder_data, get_attention_feeder_for_infer
import paddle.fluid.profiler as profiler
from crnn_ctc_model import ctc_infer
from attention_model import attention_infer
import numpy as np
import data_reader
import argparse
import functools
import os
import time

parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('model',    str,   "crnn_ctc",           "Which network to use: 'crnn_ctc' or 'attention'.")
add_arg('model_path',         str,  None,   "The model path to be used for inference.")
add_arg('input_images_dir',   str,  None,   "The directory of images.")
add_arg('input_images_list',  str,  None,   "The list file of images.")
add_arg('dict',               str,  None,   "The dictionary file. If None, the inference result will be an index sequence rather than text.")
add_arg('use_gpu',            bool,  True,      "Whether to use GPU for inference.")
add_arg('iterations',         int,  0,      "The number of iterations to run. Zero or less means the whole test set; more than zero means the test set may be cycled until this many iterations are reached.")
add_arg('profile',            bool, False,  "Whether to use profiling.")
add_arg('skip_batch_num',     int,  0,      "The number of first minibatches to skip as warm-up for better performance test.")
add_arg('batch_size',         int,  1,      "The minibatch size.")
# yapf: enable
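# Example invocation (the paths below are illustrative, not shipped with the repo):
#   python infer.py --model crnn_ctc --model_path ./models/ocr_ctc_params \
#       --input_images_dir ./data/test_images --input_images_list ./data/test.list \
#       --dict ./data/dict.txt --batch_size 32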


def inference(args):
    """OCR inference"""
    if args.model == "crnn_ctc":
        infer = ctc_infer
        get_feeder_data = get_ctc_feeder_data
    else:
        infer = attention_infer
        get_feeder_data = get_attention_feeder_for_infer
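    # ids assumed for the special start/end-of-sequence tokens; prune() strips
    # them from the decoded output below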
    eos = 1
    sos = 0
    num_classes = data_reader.num_classes()
    data_shape = data_reader.data_shape()
    # define network
    images = fluid.layers.data(name='pixel', shape=data_shape, dtype='float32')
    ids = infer(images, num_classes, use_cudnn=args.use_gpu)
    # data reader
    infer_reader = data_reader.inference(
        batch_size=args.batch_size,
        infer_images_dir=args.input_images_dir,
        infer_list_file=args.input_images_list,
        cycle=args.iterations > 0,
        model=args.model)
    # prepare environment
    place = fluid.CPUPlace()
    if args.use_gpu:
        place = fluid.CUDAPlace(0)

    exe = fluid.Executor(place)
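    # run the startup program once so parameter variables exist before weights are loaded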
    exe.run(fluid.default_startup_program())

    # load dictionary
    dict_map = None
    if args.dict is not None and os.path.isfile(args.dict):
        dict_map = {}
        with open(args.dict) as dict_file:
            for i, word in enumerate(dict_file):
                dict_map[i] = word.strip()
        print("Loaded dict from %s" % args.dict)

    # load init model
    model_dir = args.model_path
    model_file_name = None
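    # args.model_path may point to a directory of separate parameter files or to
    # a single combined parameter file; fluid.io.load_params handles both layouts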
    if not os.path.isdir(args.model_path):
        model_dir = os.path.dirname(args.model_path)
        model_file_name = os.path.basename(args.model_path)
    fluid.io.load_params(exe, dirname=model_dir, filename=model_file_name)
    print("Init model from: %s." % args.model_path)

    batch_times = []
    iters = 0
    for data in infer_reader():
        feed_dict = get_feeder_data(data, place)
        if args.iterations > 0 and iters == args.iterations + args.skip_batch_num:
            break
        if iters < args.skip_batch_num:
            print("Warm-up itaration")
        if iters == args.skip_batch_num:
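            # discard profiler events recorded during the warm-up batches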
            profiler.reset_profiler()

        start = time.time()
        result = exe.run(fluid.default_main_program(),
                         feed=feed_dict,
                         fetch_list=[ids],
                         return_numpy=False)
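        # result[0] is a LoDTensor of predicted token ids; flatten it and strip
        # the special sos/eos tokens before lookup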
        indexes = prune(np.array(result[0]).flatten(), sos, eos)
        batch_time = time.time() - start
        fps = args.batch_size / batch_time
        batch_times.append(batch_time)
        if dict_map is not None:
            print("Iteration %d, latency: %.5f s, fps: %f, result: %s" % (
                iters,
                batch_time,
                fps,
                [dict_map[index] for index in indexes], ))
        else:
            print("Iteration %d, latency: %.5f s, fps: %f, result: %s" % (
                iters,
                batch_time,
                fps,
                indexes, ))

        iters += 1

    latencies = batch_times[args.skip_batch_num:]
    latency_avg = np.average(latencies)
    latency_pc99 = np.percentile(latencies, 99)
    fpses = np.divide(args.batch_size, latencies)
    fps_avg = np.average(fpses)
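    # the 1st percentile of fps corresponds to the 99th percentile of latency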
    fps_pc99 = np.percentile(fpses, 1)

    # Benchmark output
    print('\nTotal examples (incl. warm-up): %d' % (iters * args.batch_size))
    print('average latency: %.5f s, 99pc latency: %.5f s' % (latency_avg,
                                                             latency_pc99))
    print('average fps: %.5f, fps for 99pc latency: %.5f' % (fps_avg, fps_pc99))


def prune(words, sos, eos):
    """Remove unused tokens in prediction result."""
    start_index = 0
    end_index = len(words)
    if sos in words:
        start_index = np.where(words == sos)[0][0] + 1
    if eos in words:
        end_index = np.where(words == eos)[0][0]
    return words[start_index:end_index]


def main():
    args = parser.parse_args()
    print_arguments(args)
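    # optionally wrap inference in a profiler: the CUDA profiler for GPU runs,
    # the fluid CPU profiler (sorted by total time) otherwise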
    if args.profile:
        if args.use_gpu:
            with profiler.cuda_profiler("cuda_profiler.txt", 'csv') as nvprof:
                inference(args)
        else:
            with profiler.profiler("CPU", sorted_key='total') as cpuprof:
                inference(args)
    else:
        inference(args)


if __name__ == "__main__":
    main()