"""Beam search parameters tuning for DeepSpeech2 model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import sys
import numpy as np
import argparse
import functools
import paddle.v2 as paddle
import _init_paths
from data_utils.data import DataGenerator
from model_utils.model import DeepSpeech2Model
from utils.error_rate import wer, cer
from utils.utility import add_arguments, print_arguments

parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('num_batches',      int,    -1,     "# of batches to tune on. "
                                            "Default -1, i.e. the whole dev set.")
add_arg('batch_size',       int,    256,    "# of samples per batch.")
add_arg('trainer_count',    int,    8,      "# of Trainers (CPUs or GPUs).")
add_arg('beam_size',        int,    500,    "Beam search width.")
add_arg('num_proc_bsearch', int,    12,     "# of CPUs for beam search.")
add_arg('num_conv_layers',  int,    2,      "# of convolution layers.")
add_arg('num_rnn_layers',   int,    3,      "# of recurrent layers.")
add_arg('rnn_layer_size',   int,    2048,   "# of recurrent cells per layer.")
add_arg('num_alphas',       int,    45,     "# of alpha candidates for tuning.")
add_arg('num_betas',        int,    8,      "# of beta candidates for tuning.")
add_arg('alpha_from',       float,  1.0,    "Where alpha starts tuning from.")
add_arg('alpha_to',         float,  3.2,    "Where alpha ends tuning at.")
add_arg('beta_from',        float,  0.1,    "Where beta starts tuning from.")
add_arg('beta_to',          float,  0.45,   "Where beta ends tuning at.")
add_arg('cutoff_prob',      float,  1.0,    "Cutoff probability for pruning.")
add_arg('cutoff_top_n',     int,    40,     "Cutoff number for pruning.")
add_arg('use_gru',          bool,   False,  "Use GRUs instead of simple RNNs.")
add_arg('use_gpu',          bool,   True,   "Use GPU or not.")
add_arg('share_rnn_weights',bool,   True,   "Share input-hidden weights across "
                                            "bi-directional RNNs. Not for GRU.")
add_arg('tune_manifest',    str,
        'data/librispeech/manifest.dev-clean',
        "Filepath of the data manifest to tune on.")
add_arg('mean_std_path',    str,
        'data/librispeech/mean_std.npz',
        "Filepath of normalizer's mean & std.")
add_arg('vocab_path',       str,
        'data/librispeech/vocab.txt',
        "Filepath of vocabulary.")
add_arg('lang_model_path',  str,
        'models/lm/common_crawl_00.prune01111.trie.klm',
        "Filepath for language model.")
add_arg('model_path',       str,
        './checkpoints/libri/params.latest.tar.gz',
        "If None, the model is initialized from scratch; "
        "otherwise, tuning starts from this pre-trained model.")
add_arg('error_rate_type',  str,
        'wer',
        "Error rate type for evaluation.",
        choices=['wer', 'cer'])
add_arg('specgram_type',    str,
        'linear',
        "Audio feature type. Options: linear, mfcc.",
        choices=['linear', 'mfcc'])
# yapf: enable
args = parser.parse_args()


def tune():
    """Tune parameters alpha and beta incrementally."""
    if not args.num_alphas >= 0:
        raise ValueError("num_alphas must be non-negative!")
    if not args.num_betas >= 0:
        raise ValueError("num_betas must be non-negative!")

    data_generator = DataGenerator(
        vocab_filepath=args.vocab_path,
        mean_std_filepath=args.mean_std_path,
        augmentation_config='{}',
        specgram_type=args.specgram_type,
        num_threads=1)
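    # Read the tuning manifest in a fixed order (no sortagrad, no shuffling),
    # so every (alpha, beta) candidate is scored on exactly the same batches.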
    batch_reader = data_generator.batch_reader_creator(
        manifest_path=args.tune_manifest,
        batch_size=args.batch_size,
        sortagrad=False,
        shuffle_method=None)

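    # Restore the pre-trained DeepSpeech2 network; its weights stay fixed
    # during tuning, only the decoder parameters alpha and beta are searched.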
    ds2_model = DeepSpeech2Model(
        vocab_size=data_generator.vocab_size,
        num_conv_layers=args.num_conv_layers,
        num_rnn_layers=args.num_rnn_layers,
        rnn_layer_size=args.rnn_layer_size,
        use_gru=args.use_gru,
        pretrained_model_path=args.model_path,
        share_rnn_weights=args.share_rnn_weights)

    # decoders only accept strings encoded in utf-8
    vocab_list = [chars.encode("utf-8") for chars in data_generator.vocab_list]

    error_rate_func = cer if args.error_rate_type == 'cer' else wer
    # create grid for search
    cand_alphas = np.linspace(args.alpha_from, args.alpha_to, args.num_alphas)
    cand_betas = np.linspace(args.beta_from, args.beta_to, args.num_betas)
    params_grid = [(alpha, beta) for alpha in cand_alphas
                   for beta in cand_betas]

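    # err_sum[i] accumulates the per-utterance error rates for grid point i
    # over all batches seen so far; err_ave[i] is the corresponding running
    # average used to report the current optimum after every batch.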
    err_sum = [0.0 for i in xrange(len(params_grid))]
    err_ave = [0.0 for i in xrange(len(params_grid))]
    num_ins, cur_batch = 0, 0
    # tune parameters incrementally over multiple batches
    for infer_data in batch_reader():
        if (args.num_batches >= 0) and (cur_batch >= args.num_batches):
            break

        target_transcripts = [
            ''.join([data_generator.vocab_list[token] for token in transcript])
            for _, transcript in infer_data
        ]

        num_ins += len(target_transcripts)
        # grid search: decode this batch once for every (alpha, beta) candidate
        for index, (alpha, beta) in enumerate(params_grid):
            result_transcripts = ds2_model.infer_batch(
                infer_data=infer_data,
                decoding_method='ctc_beam_search',
                beam_alpha=alpha,
                beam_beta=beta,
                beam_size=args.beam_size,
                cutoff_prob=args.cutoff_prob,
                cutoff_top_n=args.cutoff_top_n,
                vocab_list=vocab_list,
                language_model_path=args.lang_model_path,
                num_processes=args.num_proc_bsearch)

            for target, result in zip(target_transcripts, result_transcripts):
                err_sum[index] += error_rate_func(target, result)
            err_ave[index] = err_sum[index] / num_ins
            if index % 2 == 0:
                sys.stdout.write('.')
                sys.stdout.flush()

        # output the on-line tuning result at the end of the current batch
        err_ave_min = min(err_ave)
        min_index = err_ave.index(err_ave_min)
        print("\nBatch %d [%d/?], current opt (alpha, beta) = (%s, %s), "
              "min [%s] = %f"
              % (cur_batch, num_ins,
                 "%.3f" % params_grid[min_index][0],
                 "%.3f" % params_grid[min_index][1],
                 args.error_rate_type, err_ave_min))
        cur_batch += 1

    # output WER/CER at every (alpha, beta)
    print("\nFinal %s:\n" % args.error_rate_type)
    for index in xrange(len(params_grid)):
        print("(alpha, beta) = (%s, %s), [%s] = %f"
              % ("%.3f" % params_grid[index][0],
                 "%.3f" % params_grid[index][1],
                 args.error_rate_type, err_ave[index]))

    err_ave_min = min(err_ave)
    min_index = err_ave.index(err_ave_min)
    print("\nFinished tuning on %d batches, final opt (alpha, beta) = (%s, %s)"
          % (cur_batch, "%.3f" % params_grid[min_index][0],
             "%.3f" % params_grid[min_index][1]))

    ds2_model.logger.info("finish tuning")


def main():
    print_arguments(args)
    paddle.init(use_gpu=args.use_gpu, trainer_count=args.trainer_count)
    tune()


if __name__ == '__main__':
    main()