infer.py 2.8 KB
Newer Older
P
peterzhang2029 已提交
1
import click
P
peterzhang2029 已提交
2
import gzip
3 4

import paddle.v2 as paddle
P
peterzhang2029 已提交
5
from network_conf import Model
P
peterzhang2029 已提交
6
from reader import DataGenerator
P
peterzhang2029 已提交
7
from decoder import ctc_greedy_decoder
P
peterzhang2029 已提交
8
from utils import get_file_list, load_dict, load_reverse_dict
P
peterzhang2029 已提交
9 10


P
peterzhang2029 已提交
11
def infer_batch(inferer, test_batch, labels, reversed_char_dict):
    """Run inference on one batch and print decoded vs. target transcriptions.

    Args:
        inferer: Object exposing ``infer(input=...)``; ``infer`` passes a
            ``paddle.inference.Inference`` built over ``model.log_probs``.
        test_batch: List of samples, each a one-element list ``[image]`` as
            assembled by ``infer``.
        labels: Target transcriptions, aligned one-to-one with
            ``test_batch``.
        reversed_char_dict: Passed to ``ctc_greedy_decoder`` as its
            ``vocabulary`` argument.

    An empty ``test_batch`` is a no-op: nothing is inferred or printed.
    """
    if not test_batch:
        # Avoid calling the inferer on nothing and dividing by zero below.
        return

    infer_results = inferer.infer(input=test_batch)
    # The slicing below treats infer_results as the outputs of all samples
    # concatenated along the first axis, each sample contributing the same
    # number of rows (assumes a fixed image_shape -- TODO confirm).
    num_steps = len(infer_results) // len(test_batch)
    probs_split = [
        infer_results[i * num_steps:(i + 1) * num_steps]
        for i in range(len(test_batch))
    ]

    # Best path decode.
    results = [
        ctc_greedy_decoder(probs_seq=probs, vocabulary=reversed_char_dict)
        for probs in probs_split
    ]

    for result, label in zip(results, labels):
        print("\nOutput Transcription: %s\nTarget Transcription: %s" %
              (result, label))
P
peterzhang2029 已提交
28 29


P
peterzhang2029 已提交
30 31 32 33 34 35 36 37 38 39 40 41 42
@click.command('infer')
@click.option(
    "--model_path", type=str, required=True, help=("The path of saved model."))
@click.option(
    "--image_shape",
    type=str,
    required=True,
    help=("The fixed size for image dataset (format is like: '173,46')."))
@click.option(
    "--batch_size",
    type=click.IntRange(min=1),
    default=10,
    help=("The number of examples in one batch (default: 10)."))
@click.option(
    "--label_dict_path",
    type=str,
    required=True,
    help=("The path of label dictionary. "))
@click.option(
    "--infer_file_list_path",
    type=str,
    required=True,
    help=("The path of the file which contains "
          "path list of image files for inference."))
def infer(model_path, image_shape, batch_size, label_dict_path,
          infer_file_list_path):
    """Load a trained model and print transcriptions for the listed images."""
    # NOTE: click shows this function's docstring as --help text, so
    # maintainer notes live in comments instead.
    image_shape = tuple(map(int, image_shape.split(',')))
    infer_file_list = get_file_list(infer_file_list_path)

    char_dict = load_dict(label_dict_path)
    reversed_char_dict = load_reverse_dict(label_dict_path)
    dict_size = len(char_dict)
    data_generator = DataGenerator(char_dict=char_dict, image_shape=image_shape)

    paddle.init(use_gpu=True, trainer_count=1)
    # Close the model archive once the parameters are loaded; this assumes
    # from_tar fully reads the archive before returning (TODO confirm).
    with gzip.open(model_path) as model_file:
        parameters = paddle.parameters.Parameters.from_tar(model_file)
    model = Model(dict_size, image_shape, is_infer=True)
    inferer = paddle.inference.Inference(
        output_layer=model.log_probs, parameters=parameters)

    test_batch = []
    labels = []
    for image, label in data_generator.infer_reader(infer_file_list)():
        test_batch.append([image])
        labels.append(label)
        if len(test_batch) == batch_size:
            infer_batch(inferer, test_batch, labels, reversed_char_dict)
            test_batch = []
            labels = []

    # Flush the trailing partial batch exactly once, after the reader is
    # exhausted (doing this inside the loop re-inferred partial batches).
    if test_batch:
        infer_batch(inferer, test_batch, labels, reversed_char_dict)
83 84 85


if __name__ == "__main__":
    # click parses sys.argv into infer()'s options and invokes it.
    infer()