#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
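
"""OCR inference for the CRNN-CTC and attention models."""
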
from __future__ import print_function
import paddle.fluid as fluid
from utility import add_arguments, print_arguments, to_lodtensor, get_ctc_feeder_data, get_attention_feeder_for_infer, get_ctc_feeder_for_infer
from utility import check_gpu, check_version
import paddle.fluid.profiler as profiler
from crnn_ctc_model import ctc_infer
from attention_model import attention_infer
import numpy as np
import data_reader
import argparse
import functools
import os
import time

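# Command-line arguments; add_arg registers each flag on the parser via utility.add_arguments.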
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('model',    str,   "crnn_ctc",           "Which type of network to use: 'crnn_ctc' or 'attention'.")
add_arg('model_path',         str,  None,   "The model path to be used for inference.")
add_arg('input_images_dir',   str,  None,   "The directory of images.")
add_arg('input_images_list',  str,  None,   "The list file of images.")
add_arg('dict',               str,  None,   "The dictionary file. If None, the inference result will be an index sequence instead of text.")
add_arg('use_gpu',            bool,  True,      "Whether to use GPU for inference.")
add_arg('iterations',         int,  0,      "The number of iterations. Zero or less means the whole test set. More than 0 means the test set may be looped until the number of iterations is reached.")
add_arg('profile',            bool, False,  "Whether to use profiling.")
add_arg('skip_batch_num',     int,  0,      "The number of initial minibatches to skip as warm-up for a more stable performance test.")
add_arg('batch_size',         int,  1,      "The minibatch size.")
# yapf: enable


def inference(args):
    """OCR inference"""
    if args.model == "crnn_ctc":
        infer = ctc_infer
        get_feeder_data = get_ctc_feeder_for_infer
    else:
        infer = attention_infer
        get_feeder_data = get_attention_feeder_for_infer
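    # start- and end-of-sequence token ids, used by prune() to trim the decoded output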
    eos = 1
    sos = 0
    num_classes = data_reader.num_classes()
    data_shape = data_reader.data_shape()
    # define network
    images = fluid.layers.data(name='pixel', shape=data_shape, dtype='float32')
    ids = infer(images, num_classes, use_cudnn=args.use_gpu)
    # data reader
    infer_reader = data_reader.inference(
        batch_size=args.batch_size,
        infer_images_dir=args.input_images_dir,
        infer_list_file=args.input_images_list,
        cycle=args.iterations > 0,
        model=args.model)
    # prepare environment
    place = fluid.CPUPlace()
    if args.use_gpu:
        place = fluid.CUDAPlace(0)

    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())

    # load dictionary
    dict_map = None
    if args.dict is not None and os.path.isfile(args.dict):
        dict_map = {}
        with open(args.dict) as dict_file:
            for i, word in enumerate(dict_file):
                dict_map[i] = word.strip()
        print("Loaded dict from %s" % args.dict)

    # load init model
    model_dir = args.model_path
    model_file_name = None
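    # model_path may be a directory of separate parameter files or a single combined file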
    if not os.path.isdir(args.model_path):
        model_dir = os.path.dirname(args.model_path)
        model_file_name = os.path.basename(args.model_path)
    fluid.io.load_params(exe, dirname=model_dir, filename=model_file_name)
90
    print("Init model from: %s." % args.model_path)

    batch_times = []
    iters = 0
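    # run inference batch by batch; the first skip_batch_num batches are warm-up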
    for data in infer_reader():
        feed_dict = get_feeder_data(data, place)
        if args.iterations > 0 and iters == args.iterations + args.skip_batch_num:
            break
        if iters < args.skip_batch_num:
            print("Warm-up itaration")
        if iters == args.skip_batch_num:
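            # reset profiling statistics so they exclude the warm-up batches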
            profiler.reset_profiler()

        start = time.time()
        result = exe.run(fluid.default_main_program(),
                         feed=feed_dict,
                         fetch_list=[ids],
                         return_numpy=False)
        indexes = prune(np.array(result[0]).flatten(), sos, eos)
        batch_time = time.time() - start
        fps = args.batch_size / batch_time
        batch_times.append(batch_time)
        if dict_map is not None:
            print("Iteration %d, latency: %.5f s, fps: %f, result: %s" % (
                iters,
                batch_time,
                fps,
                [dict_map[index] for index in indexes], ))
        else:
            print("Iteration %d, latency: %.5f s, fps: %f, result: %s" % (
                iters,
                batch_time,
                fps,
                indexes, ))

        iters += 1

    latencies = batch_times[args.skip_batch_num:]
    latency_avg = np.average(latencies)
    latency_pc99 = np.percentile(latencies, 99)
    fpses = np.divide(args.batch_size, latencies)
    fps_avg = np.average(fpses)
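    # the 1st fps percentile corresponds to the 99th latency percentile (fps = batch_size / latency)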
    fps_pc99 = np.percentile(fpses, 1)

    # Benchmark output
    print('\nTotal examples (incl. warm-up): %d' % (iters * args.batch_size))
    print('average latency: %.5f s, 99pc latency: %.5f s' % (latency_avg,
                                                             latency_pc99))
    print('average fps: %.5f, fps for 99pc latency: %.5f' % (fps_avg, fps_pc99))


def prune(words, sos, eos):
    """Remove unused tokens in prediction result."""
    start_index = 0
    end_index = len(words)
    if sos in words:
        start_index = np.where(words == sos)[0][0] + 1
    if eos in words:
        end_index = np.where(words == eos)[0][0]
    return words[start_index:end_index]


def main():
    args = parser.parse_args()
    print_arguments(args)
    check_gpu(args.use_gpu)
    check_version()
    if args.profile:
        if args.use_gpu:
            with profiler.cuda_profiler("cuda_profiler.txt", 'csv'):
                inference(args)
        else:
            with profiler.profiler("CPU", sorted_key='total'):
                inference(args)
    else:
        inference(args)


if __name__ == "__main__":
    main()
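
# Example invocation (paths below are illustrative placeholders, not files shipped with the repo):
#   python infer.py --model crnn_ctc --model_path ./models/ocr_ctc_params \
#       --input_images_dir ./data/test_images --input_images_list ./data/test.list \
#       --dict ./data/dict.txt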