# evaluate.py
# Last committed by Yibing Liu.
"""Evaluation for DeepSpeech2 model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import distutils.util
import gzip
import multiprocessing

import paddle.v2 as paddle

from data_utils.data import DataGenerator
from model import deep_speech2
from decoder import *
from lm.lm_scorer import LmScorer
from error_rate import wer

# Command-line options for the evaluation script. The parsed values are kept
# in the module-level ``args`` namespace, which evaluate() and main() read.
parser = argparse.ArgumentParser(description=__doc__)

# --- evaluation batch and network shape ---
parser.add_argument(
    "--batch_size",
    default=100,
    type=int,
    help="Minibatch size for evaluation. (default: %(default)s)")
parser.add_argument(
    "--num_conv_layers",
    default=2,
    type=int,
    help="Convolution layer number. (default: %(default)s)")
parser.add_argument(
    "--num_rnn_layers",
    default=3,
    type=int,
    help="RNN layer number. (default: %(default)s)")
parser.add_argument(
    "--rnn_layer_size",
    default=512,
    type=int,
    help="RNN layer cell number. (default: %(default)s)")

# --- hardware and parallelism ---
# strtobool lets the flag be written as true/false, yes/no, 1/0, ...
parser.add_argument(
    "--use_gpu",
    default=True,
    type=distutils.util.strtobool,
    help="Use gpu or not. (default: %(default)s)")
parser.add_argument(
    "--num_threads_data",
    default=multiprocessing.cpu_count(),
    type=int,
    help="Number of cpu threads for preprocessing data. (default: %(default)s)")
parser.add_argument(
    "--num_processes_beam_search",
    default=multiprocessing.cpu_count(),
    type=int,
    help="Number of cpu processes for beam search. (default: %(default)s)")

# --- feature normalization ---
parser.add_argument(
    "--mean_std_filepath",
    default='mean_std.npz',
    type=str,
    help="Manifest path for normalizer. (default: %(default)s)")

# --- decoding ---
parser.add_argument(
    "--decode_method",
    default='beam_search',
    type=str,
    help="Method for ctc decoding, best_path or beam_search. (default: %(default)s)"
)
parser.add_argument(
    "--language_model_path",
    default="lm/data/1Billion.klm",
    type=str,
    help="Path for language model. (default: %(default)s)")
parser.add_argument(
    "--alpha",
    default=0.26,
    type=float,
    help="Parameter associated with language model. (default: %(default)f)")
parser.add_argument(
    "--beta",
    default=0.1,
    type=float,
    help="Parameter associated with word count. (default: %(default)f)")
# The two adjacent literals are concatenated by the parser, so the first one
# must end with a space; otherwise the help reads "pruningin beam search".
parser.add_argument(
    "--cutoff_prob",
    default=0.99,
    type=float,
    help="The cutoff probability of pruning "
    "in beam search. (default: %(default)f)")
parser.add_argument(
    "--beam_size",
    default=500,
    type=int,
    help="Width for beam search decoding. (default: %(default)d)")

# --- input files ---
parser.add_argument(
    "--decode_manifest_path",
    default='datasets/manifest.test',
    type=str,
    help="Manifest path for decoding. (default: %(default)s)")
parser.add_argument(
    "--model_filepath",
    default='checkpoints/params.latest.tar.gz',
    type=str,
    help="Model filepath. (default: %(default)s)")
parser.add_argument(
    "--vocab_filepath",
    default='datasets/vocab/eng_vocab.txt',
    type=str,
    help="Vocabulary filepath. (default: %(default)s)")

# Parsed at import time: the functions below read their settings from ``args``.
args = parser.parse_args()


def evaluate():
    """Evaluate on whole test data for DeepSpeech2.

    All settings are read from the module-level ``args`` namespace. Every
    batch produced by the data generator's batch reader is run through the
    network, decoded with the selected CTC decoder and scored with ``wer``
    against its reference transcription. The WER averaged over all decoded
    utterances is printed at the end.

    Raises:
        ValueError: If ``args.decode_method`` is neither ``best_path`` nor
            ``beam_search``, or if the batch reader yields no utterance so
            that no average can be computed.
    """
    # Reject an unsupported decoder up front, before the model is loaded and
    # any inference has been run.
    if args.decode_method not in ("best_path", "beam_search"):
        raise ValueError("Decoding method [%s] is not supported." %
                         args.decode_method)

    # initialize data generator
    data_generator = DataGenerator(
        vocab_filepath=args.vocab_filepath,
        mean_std_filepath=args.mean_std_filepath,
        augmentation_config='{}',
        num_threads=args.num_threads_data)
    vocab_list = data_generator.vocab_list

    # create network config
    # paddle.data_type.dense_array is used for variable batch input.
    # The size 161 * 161 is only a placeholder value; the real shape of the
    # input batch data is taken from the data that is actually fed in.
    audio_data = paddle.layer.data(
        name="audio_spectrogram", type=paddle.data_type.dense_array(161 * 161))
    text_data = paddle.layer.data(
        name="transcript_text",
        type=paddle.data_type.integer_value_sequence(data_generator.vocab_size))
    output_probs = deep_speech2(
        audio_data=audio_data,
        text_data=text_data,
        dict_size=data_generator.vocab_size,
        num_conv_layers=args.num_conv_layers,
        num_rnn_layers=args.num_rnn_layers,
        rnn_size=args.rnn_layer_size,
        is_inference=True)

    # load parameters; the context manager closes the archive once it is read
    with gzip.open(args.model_filepath) as model_file:
        parameters = paddle.parameters.Parameters.from_tar(model_file)

    # prepare infer data
    batch_reader = data_generator.batch_reader_creator(
        manifest_path=args.decode_manifest_path,
        batch_size=args.batch_size,
        sortagrad=False,
        shuffle_method=None)

    # define inferer
    inferer = paddle.inference.Inference(
        output_layer=output_probs, parameters=parameters)

    # initialize external scorer for beam search decoding
    if args.decode_method == 'beam_search':
        ext_scorer = LmScorer(args.alpha, args.beta, args.language_model_path)

    wer_counter, wer_sum = 0, 0.0
    for infer_data in batch_reader():
        # run inference
        infer_results = inferer.infer(input=infer_data)
        # The flat result is cut into equally sized slices, one per utterance
        # of the batch.
        num_steps = len(infer_results) // len(infer_data)
        probs_split = [
            infer_results[i * num_steps:(i + 1) * num_steps]
            for i in range(len(infer_data))
        ]
        # target transcription: element 1 of each sample holds the token
        # indices, which are mapped back to text through the vocabulary
        target_transcription = [
            ''.join(vocab_list[index] for index in sample[1])
            for sample in infer_data
        ]

        if args.decode_method == "best_path":
            # best path decode
            hypotheses = [
                ctc_best_path_decoder(probs_seq=probs, vocabulary=vocab_list)
                for probs in probs_split
            ]
        else:
            # beam search using multiple processes; the blank label is given
            # the index right after the last vocabulary entry
            beam_search_results = ctc_beam_search_decoder_batch(
                probs_split=probs_split,
                vocabulary=vocab_list,
                beam_size=args.beam_size,
                blank_id=len(vocab_list),
                num_processes=args.num_processes_beam_search,
                ext_scoring_func=ext_scorer,
                cutoff_prob=args.cutoff_prob, )
            # NOTE(review): presumably each result is a best-first list of
            # (score, transcription) pairs -- confirm against the decoder.
            hypotheses = [result[0][1] for result in beam_search_results]

        for target, hypothesis in zip(target_transcription, hypotheses):
            wer_sum += wer(target, hypothesis)
            wer_counter += 1

    # An empty manifest would otherwise surface as a bare ZeroDivisionError.
    if wer_counter == 0:
        raise ValueError("No utterance was decoded from manifest [%s]; "
                         "cannot compute WER." % args.decode_manifest_path)

    print("Final WER = %f" % (wer_sum / wer_counter))


def main():
    """Initialize PaddlePaddle from ``args`` and run the evaluation."""
    init_options = {"use_gpu": args.use_gpu, "trainer_count": 1}
    paddle.init(**init_options)
    evaluate()


# Run the evaluation only when this file is executed as a script.
if __name__ == "__main__":
    main()