infer.py
import argparse
import distutils.util
import gzip
from itertools import groupby

import paddle.v2 as paddle

import audio_data_utils
from model import deep_speech2

parser = argparse.ArgumentParser(
    description='Simplified version of DeepSpeech2 inference.')
parser.add_argument(
    "--num_samples", default=10, type=int, help="Number of inference samples.")
parser.add_argument(
    "--num_conv_layers", default=2, type=int, help="Convolution layer number.")
parser.add_argument(
    "--num_rnn_layers", default=3, type=int, help="RNN layer number.")
parser.add_argument(
    "--rnn_layer_size", default=512, type=int, help="RNN layer cell number.")
parser.add_argument(
    "--use_gpu",
    default=True,
    type=distutils.util.strtobool,
    help="Use GPU or not.")
args = parser.parse_args()


def remove_duplicate_and_blank(id_list, blank_id):
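    """CTC greedy (best path) decoding postprocessing: first collapse
    consecutive repeated tokens, then remove all blank tokens.
    E.g. with blank '-', the path "aa-b-bb" decodes to "abb".
    """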
    # remove consecutive duplicate tokens
    id_list = [x[0] for x in groupby(id_list)]
    # remove blank
    return [id for id in id_list if id != blank_id]


def max_infer():
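    """Build the DeepSpeech2 network, load trained parameters and run
    max-id (greedy) decoding over one batch of test samples.
    """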
    # create network config
    _, vocab_list = audio_data_utils.get_vocabulary()
    dict_size = len(vocab_list)
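    # audio input: a 161 x 1000 spectrogram (161 frequency bins, up to
    # 1000 time steps), flattened into a dense vector of size 161 * 1000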
    audio_data = paddle.layer.data(
        name="audio_spectrogram",
        height=161,
        width=1000,
        type=paddle.data_type.dense_vector(161000))
    text_data = paddle.layer.data(
        name="transcript_text",
        type=paddle.data_type.integer_value_sequence(dict_size))
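    # the model returns (ctc_cost, max_id); max_id holds the argmax token
    # id of the output distribution at each time step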
    _, max_id = deep_speech2(
        audio_data=audio_data,
        text_data=text_data,
        dict_size=dict_size,
        num_conv_layers=args.num_conv_layers,
        num_rnn_layers=args.num_rnn_layers,
        rnn_size=args.rnn_layer_size)

    # load parameters
    parameters = paddle.parameters.Parameters.from_tar(
        gzip.open("params.tar.gz"))
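    # (params.tar.gz is assumed to be produced by the corresponding
    # training script for this model)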

    # prepare infer data
    feeding = {
        "audio_spectrogram": 0,
        "transcript_text": 1,
    }
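    # feeding maps each data layer's name to its position in an instance tuple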
    test_batch_reader = audio_data_utils.padding_batch_reader(
        paddle.batch(
            audio_data_utils.reader_creator(
                manifest_path="./libri.manifest.test", sort_by_duration=False),
            batch_size=args.num_samples),
        padding=[-1, 1000])
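    # fetch a single batch of num_samples instances for inference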
    infer_data = test_batch_reader().next()

    # run inference
    max_id_results = paddle.infer(
        output_layer=max_id,
        parameters=parameters,
        input=infer_data,
        feeding=feeding,
        field=['id'])

    # postprocess: the max-id outputs of all samples come back flattened
    # into one array; split it into equal-length per-sample slices (all
    # samples are padded to the same number of time steps)
    instance_length = len(max_id_results) / args.num_samples
    instance_list = [
        max_id_results[i * instance_length:(i + 1) * instance_length]
        for i in xrange(args.num_samples)
    ]
    for i, instance in enumerate(instance_list):
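        # blank token id equals dict_size (the blank is appended right
        # after the vocabulary)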
        id_list = remove_duplicate_and_blank(instance, dict_size)
        output_transcript = ''.join([vocab_list[id] for id in id_list])
        target_transcript = ''.join([vocab_list[id] for id in infer_data[i][1]])
        print("Target Transcript: %s \nOutput Transcript: %s \n" %
              (target_transcript, output_transcript))


def main():
    paddle.init(use_gpu=args.use_gpu, trainer_count=1)
    max_infer()


if __name__ == '__main__':
    main()