import paddle.v2 as paddle
import paddle.fluid as fluid
from utility import add_arguments, print_arguments, to_lodtensor, get_feeder_data
from crnn_ctc_model import ctc_infer
import numpy as np
import ctc_reader
import argparse
import functools
import os

# Command-line interface: all options default to None/True and are validated
# by the code that consumes them.
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('model_path',         str,  None,   "The model path to be used for inference.")
add_arg('input_images_dir',   str,  None,   "The directory of images.")
add_arg('input_images_list',  str,  None,   "The list file of images.")
add_arg('use_gpu',            bool,  True,      "Whether use GPU to infer.")
# yapf: enable

def inference(args, infer=ctc_infer, data_reader=ctc_reader):
    """Run OCR inference and print the predicted label sequences.

    Builds the CRNN-CTC inference network, loads parameters from
    ``args.model_path``, then runs the network over every sample yielded by
    the inference reader and prints the flattened prediction of each one.

    Args:
        args: parsed command-line arguments; uses ``model_path``,
            ``input_images_dir``, ``input_images_list`` and ``use_gpu``.
        infer: network builder called as ``infer(images, num_classes)``;
            returns the variable to fetch (the decoded sequence).
        data_reader: module providing ``num_classes()``, ``data_shape()``
            and an ``inference(...)`` reader factory.
    """
    num_classes = data_reader.num_classes()
    data_shape = data_reader.data_shape()
    # define network
    images = fluid.layers.data(name='pixel', shape=data_shape, dtype='float32')
    sequence = infer(images, num_classes)
    # data reader
    infer_reader = data_reader.inference(
        infer_images_dir=args.input_images_dir,
        infer_list_file=args.input_images_list)
    # prepare environment
    place = fluid.CPUPlace()
    # BUG FIX: original read the undefined name `use_gpu` (NameError);
    # the flag lives on the parsed arguments.
    if args.use_gpu:
        place = fluid.CUDAPlace(0)

    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())

    # Load the init model. A directory path means separately-saved parameter
    # files (filename=None); a file path means one combined parameter file.
    model_dir = args.model_path
    model_file_name = None
    if not os.path.isdir(args.model_path):
        model_dir = os.path.dirname(args.model_path)
        model_file_name = os.path.basename(args.model_path)
    fluid.io.load_params(exe, dirname=model_dir, filename=model_file_name)
    print("Init model from: %s." % args.model_path)

    for data in infer_reader():
        # return_numpy=False keeps the LoDTensor result; convert explicitly.
        result = exe.run(fluid.default_main_program(),
                         feed=get_feeder_data(
                             data, place, need_label=False),
                         fetch_list=[sequence],
                         return_numpy=False)
        print("result: %s" % (np.array(result[0]).flatten(), ))


def main():
    """Entry point: parse CLI arguments, echo them, and run inference."""
    args = parser.parse_args()
    print_arguments(args)
    inference(args, data_reader=ctc_reader)

if __name__ == "__main__":
    main()