infer.py
import logging
import argparse
import gzip

import paddle.v2 as paddle
from model import Model
from data_provider import get_file_list, AsciiDic, ImageDataset
from decoder import ctc_greedy_decoder


def infer_batch(inferer, test_batch, labels):
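    """Run CTC inference on one batch of images and print the results.

    The flattened softmax output is split back into per-sample probability
    sequences, each sequence is decoded with the greedy (best-path) decoder,
    and every decoded transcription is printed next to its ground-truth label.
    """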
    infer_results = inferer.infer(input=test_batch)
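    # All images share the same fixed shape, so every sample contributes the
    # same number of time steps and the flat output can be split evenly.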
    num_steps = len(infer_results) // len(test_batch)
    probs_split = [
        infer_results[i * num_steps:(i + 1) * num_steps]
        for i in xrange(0, len(test_batch))
    ]

    results = []
    # best path decode
    for i, probs in enumerate(probs_split):
        output_transcription = ctc_greedy_decoder(
            probs_seq=probs, vocabulary=AsciiDic().id2word())
        results.append(output_transcription)

    for result, label in zip(results, labels):
        print("\nOutput Transcription: %s\nTarget Transcription: %s" %
              (result, label))


def infer(model_path, image_shape, batch_size, infer_file_list):
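    """Restore a trained model and decode every image in infer_file_list.

    The file list is turned into an image generator, model parameters are
    loaded from the gzipped tar archive at model_path, and the images are
    decoded in batches of batch_size.
    """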
    image_shape = tuple(map(int, image_shape.split(',')))
    infer_generator = get_file_list(infer_file_list)

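    # Only the inference generator is needed, so the train/test slots are None.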
    dataset = ImageDataset(None, None, infer_generator, image_shape, True)

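    # Initialize Paddle, restore the trained parameters, and build an
    # inference network that outputs log probabilities for the CTC decoder.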
    paddle.init(use_gpu=True, trainer_count=4)
    parameters = paddle.parameters.Parameters.from_tar(gzip.open(model_path))
    model = Model(AsciiDic().size(), image_shape, is_infer=True)
    inferer = paddle.inference.Inference(
        output_layer=model.log_probs, parameters=parameters)

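    # Collect images into batches of batch_size, decode each full batch, and
    # flush whatever is left over after the loop.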
    test_batch = []
    labels = []
    for i, (image, label) in enumerate(dataset.infer()):
        test_batch.append([image])
        labels.append(label)
        if len(test_batch) == batch_size:
            infer_batch(inferer, test_batch, labels)
            test_batch = []
            labels = []
    if test_batch:
        infer_batch(inferer, test_batch, labels)


if __name__ == "__main__":
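    # Hard-coded inference settings; adjust the model path, image shape,
    # batch size, and ground-truth file list to your own setup.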
    model_path = "model.ctc-pass-9-batch-150-test.tar.gz"
    image_shape = "173,46"
    batch_size = 50
    infer_file_list = 'data/test_data/Challenge2_Test_Task3_GT.txt'

    infer(model_path, image_shape, batch_size, infer_file_list)