"""Inferer for DeepSpeech2 model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import distutils.util
import multiprocessing
import paddle.v2 as paddle
from data_utils.data import DataGenerator
from model import DeepSpeech2Model
from error_rate import wer, cer
import utils

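# Command-line arguments for data paths, model topology, and decoding configuration.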
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
    "--num_samples",
    default=10,
    type=int,
    help="Number of samples for inference. (default: %(default)s)")
parser.add_argument(
    "--num_conv_layers",
    default=2,
    type=int,
    help="Convolution layer number. (default: %(default)s)")
parser.add_argument(
    "--num_rnn_layers",
    default=3,
    type=int,
    help="RNN layer number. (default: %(default)s)")
parser.add_argument(
    "--rnn_layer_size",
    default=2048,
    type=int,
    help="RNN layer cell number. (default: %(default)s)")
parser.add_argument(
    "--share_rnn_weights",
    default=True,
    type=distutils.util.strtobool,
    help="Whether to share input-hidden weights between forword and backward "
    "directional simple RNNs. Only available when use_gru=False. "
    "(default: %(default)s)")
parser.add_argument(
    "--use_gru",
    default=False,
    type=distutils.util.strtobool,
    help="Use GRU or simple RNN. (default: %(default)s)")
parser.add_argument(
    "--use_gpu",
    default=True,
    type=distutils.util.strtobool,
    help="Use gpu or not. (default: %(default)s)")
parser.add_argument(
    "--num_threads_data",
    default=1,
    type=int,
    help="Number of cpu threads for preprocessing data. (default: %(default)s)")
parser.add_argument(
    "--num_processes_beam_search",
    default=multiprocessing.cpu_count() // 2,
    type=int,
    help="Number of cpu processes for beam search. (default: %(default)s)")
parser.add_argument(
    "--specgram_type",
    default='linear',
    type=str,
    help="Feature type of audio data: 'linear' (power spectrum)"
    " or 'mfcc'. (default: %(default)s)")
parser.add_argument(
    "--trainer_count",
    default=8,
    type=int,
    help="Trainer number. (default: %(default)s)")
parser.add_argument(
    "--mean_std_filepath",
    default='mean_std.npz',
    type=str,
    help="Filepath of the normalizer's mean & std statistics. (default: %(default)s)")
parser.add_argument(
    "--decode_manifest_path",
    default='datasets/manifest.test',
    type=str,
    help="Manifest path for decoding. (default: %(default)s)")
parser.add_argument(
    "--model_filepath",
    default='checkpoints/params.latest.tar.gz',
    type=str,
    help="Model filepath. (default: %(default)s)")
parser.add_argument(
    "--vocab_filepath",
    default='datasets/vocab/eng_vocab.txt',
    type=str,
    help="Vocabulary filepath. (default: %(default)s)")
parser.add_argument(
    "--decode_method",
    default='beam_search',
    type=str,
    help="Method for ctc decoding: best_path or beam_search. "
    "(default: %(default)s)")
parser.add_argument(
    "--beam_size",
    default=500,
    type=int,
    help="Width for beam search decoding. (default: %(default)d)")
parser.add_argument(
    "--language_model_path",
    default="lm/data/common_crawl_00.prune01111.trie.klm",
    type=str,
    help="Path for language model. (default: %(default)s)")
parser.add_argument(
    "--alpha",
    default=0.36,
    type=float,
    help="Parameter associated with language model. (default: %(default)f)")
parser.add_argument(
    "--beta",
    default=0.25,
    type=float,
    help="Parameter associated with word count. (default: %(default)f)")
parser.add_argument(
    "--cutoff_prob",
    default=0.99,
    type=float,
    help="The cutoff probability of pruning"
    "in beam search. (default: %(default)f)")
parser.add_argument(
    "--error_rate_type",
    default='wer',
    choices=['wer', 'cer'],
    type=str,
    help="Error rate type for evaluation. 'wer' for word error rate and 'cer' "
    "for character error rate. "
    "(default: %(default)s)")
args = parser.parse_args()


def infer():
    """Inference for DeepSpeech2."""
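    # Prepare a data generator for feature extraction (no augmentation at inference time).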
    data_generator = DataGenerator(
        vocab_filepath=args.vocab_filepath,
        mean_std_filepath=args.mean_std_filepath,
        augmentation_config='{}',
        specgram_type=args.specgram_type,
        num_threads=args.num_threads_data)
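    # Read one batch of --num_samples utterances from the decoding manifest.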
    batch_reader = data_generator.batch_reader_creator(
        manifest_path=args.decode_manifest_path,
        batch_size=args.num_samples,
        min_batch_size=1,
        sortagrad=False,
        shuffle_method=None)
    infer_data = batch_reader().next()

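    # Build the DeepSpeech2 network and load the pretrained parameters.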
    ds2_model = DeepSpeech2Model(
        vocab_size=data_generator.vocab_size,
        num_conv_layers=args.num_conv_layers,
        num_rnn_layers=args.num_rnn_layers,
        rnn_layer_size=args.rnn_layer_size,
        use_gru=args.use_gru,
        pretrained_model_path=args.model_filepath,
        share_rnn_weights=args.share_rnn_weights)
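    # Decode the whole batch, using CTC best-path or beam-search decoding as requested.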
    result_transcripts = ds2_model.infer_batch(
        infer_data=infer_data,
        decode_method=args.decode_method,
        beam_alpha=args.alpha,
        beam_beta=args.beta,
        beam_size=args.beam_size,
        cutoff_prob=args.cutoff_prob,
        vocab_list=data_generator.vocab_list,
        language_model_path=args.language_model_path,
        num_processes=args.num_processes_beam_search)

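    # Score the decoded transcripts against the reference transcripts with WER or CER.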
    error_rate_func = cer if args.error_rate_type == 'cer' else wer
    target_transcripts = [
        ''.join([data_generator.vocab_list[token] for token in transcript])
        for _, transcript in infer_data
    ]
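    # Print each target/output transcription pair together with its error rate.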
    for target, result in zip(target_transcripts, result_transcripts):
        print("\nTarget Transcription: %s\nOutput Transcription: %s" %
              (target, result))
        print("Current error rate [%s] = %f" %
              (args.error_rate_type, error_rate_func(target, result)))


def main():
    utils.print_arguments(args)
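    # Initialize PaddlePaddle (CPU/GPU and trainer count) before running inference.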
    paddle.init(use_gpu=args.use_gpu, trainer_count=args.trainer_count)
    infer()


if __name__ == '__main__':
    main()