"""
Inference for a simplified version of Baidu DeepSpeech2 model.
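
Usage:
    python infer.py --num_samples 10

The script expects trained parameters in params.tar.gz and a test manifest
at ./libri.manifest.test (both paths are hard-coded below).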
"""

import argparse
import distutils.util
import gzip
from itertools import groupby

import paddle.v2 as paddle

import audio_data_utils
from model import deep_speech2

parser = argparse.ArgumentParser(
    description='Simplified version of DeepSpeech2 inference.')
parser.add_argument(
    "--num_samples",
    default=10,
    type=int,
    help="Number of samples for inference.")
parser.add_argument(
    "--num_conv_layers", default=2, type=int, help="Number of conv layers.")
parser.add_argument(
    "--num_rnn_layers", default=3, type=int, help="Number of RNN layers.")
parser.add_argument(
    "--rnn_layer_size", default=512, type=int, help="Cells per RNN layer.")
parser.add_argument(
    "--use_gpu",
    default=True,
    # bool("False") is truthy, so parse boolean flags with strtobool instead
    type=distutils.util.strtobool,
    help="Use gpu or not.")
args = parser.parse_args()


def remove_duplicate_and_blank(id_list, blank_id):
    """
    Postprocessing for max-ctc-decoder.
    - remove consecutive duplicate tokens.
    - remove blanks.
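
    E.g. with blank_id = 0: [1, 1, 0, 0, 2, 2, 0, 3] -> [1, 2, 3].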
    """
    # remove consecutive duplicate tokens
    id_list = [x[0] for x in groupby(id_list)]
    # remove blanks
    return [token_id for token_id in id_list if token_id != blank_id]


def max_infer():
    """
    Max-ctc-decoding for DeepSpeech2.
    """
    # create network config
    _, vocab_list = audio_data_utils.get_vocabulary()
    dict_size = len(vocab_list)
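    # each instance carries a 161 (frequency bins) x 1000 (time frames)
    # spectrogram, flattened into a dense vector of 161 * 1000 = 161000 values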
    audio_data = paddle.layer.data(
        name="audio_spectrogram",
        height=161,
        width=1000,
        type=paddle.data_type.dense_vector(161000))
    text_data = paddle.layer.data(
        name="transcript_text",
        type=paddle.data_type.integer_value_sequence(dict_size))
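    # the second output of deep_speech2, max_id, emits the most probable
    # token id at every time step; it drives the best-path (max) CTC
    # decoding below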
    _, max_id = deep_speech2(
        audio_data=audio_data,
        text_data=text_data,
        dict_size=dict_size,
        num_conv_layers=args.num_conv_layers,
        num_rnn_layers=args.num_rnn_layers,
        rnn_size=args.rnn_layer_size)

    # load parameters
    parameters = paddle.parameters.Parameters.from_tar(
        gzip.open("params.tar.gz"))

    # prepare infer data
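    # feeding maps each data layer name to its field index within an instance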
    feeding = {
        "audio_spectrogram": 0,
        "transcript_text": 1,
    }
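    # pad each spectrogram to the fixed 1000-frame width expected by the
    # input layer (the -1 presumably leaves the frequency dimension as-is)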
    test_batch_reader = audio_data_utils.padding_batch_reader(
        paddle.batch(
            audio_data_utils.reader_creator(
                manifest_path="./libri.manifest.test", sort_by_duration=False),
            batch_size=args.num_samples),
        padding=[-1, 1000])
    infer_data = test_batch_reader().next()

    # run max-ctc-decoding
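    # field=['id'] asks paddle.infer for the argmax ids; the ids of all
    # instances come back concatenated in one flat list, split evenly below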
    max_id_results = paddle.infer(
        output_layer=max_id,
        parameters=parameters,
        input=infer_data,
        feeding=feeding,
        field=['id'])

    # postprocess
    instance_length = len(max_id_results) // args.num_samples
    instance_list = [
        max_id_results[i * instance_length:(i + 1) * instance_length]
        for i in xrange(0, args.num_samples)
    ]
    for i, instance in enumerate(instance_list):
        id_list = remove_duplicate_and_blank(instance, dict_size)
        output_transcript = ''.join(
            [vocab_list[token_id] for token_id in id_list])
        target_transcript = ''.join(
            [vocab_list[token_id] for token_id in infer_data[i][1]])
        print("Target Transcript: %s\nOutput Transcript: %s\n" %
              (target_transcript, output_transcript))


def main():
    paddle.init(use_gpu=args.use_gpu, trainer_count=1)
    max_infer()


if __name__ == '__main__':
    main()