"""Beam search parameters tuning for DeepSpeech2 model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import sys
import numpy as np
import argparse
import functools
import paddle.v2 as paddle
import _init_paths
from data_utils.data import DataGenerator
from model_utils.model import DeepSpeech2Model
from utils.error_rate import wer, cer
from utils.utility import add_arguments, print_arguments

parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('num_batches',      int,    -1,    "# of batches to tune on. "
                                           "Default -1 tunes on the whole dev set.")
add_arg('batch_size',       int,    256,   "# of samples per batch.")
add_arg('trainer_count',    int,    8,     "# of Trainers (CPUs or GPUs).")
add_arg('beam_size',        int,    500,   "Beam search width.")
add_arg('num_proc_bsearch', int,    12,    "# of CPUs for beam search.")
add_arg('num_conv_layers',  int,    2,     "# of convolution layers.")
add_arg('num_rnn_layers',   int,    3,     "# of recurrent layers.")
add_arg('rnn_layer_size',   int,    2048,  "# of recurrent cells per layer.")
add_arg('num_alphas',       int,    45,    "# of alpha candidates for tuning.")
add_arg('num_betas',        int,    8,     "# of beta candidates for tuning.")
add_arg('alpha_from',       float,  1.0,   "Value alpha tuning starts from.")
add_arg('alpha_to',         float,  3.2,   "Value alpha tuning ends at.")
add_arg('beta_from',        float,  0.1,   "Value beta tuning starts from.")
add_arg('beta_to',          float,  0.45,  "Value beta tuning ends at.")
add_arg('cutoff_prob',      float,  1.0,   "Cutoff probability for pruning.")
add_arg('cutoff_top_n',     int,    40,    "Cutoff number for pruning.")
add_arg('use_gru',          bool,   False, "Use GRUs instead of simple RNNs.")
add_arg('use_gpu',          bool,   True,  "Use GPU or not.")
add_arg('share_rnn_weights',bool,   True,  "Share input-hidden weights across "
                                           "bi-directional RNNs. Not for GRU.")
add_arg('tune_manifest',    str,
        'data/librispeech/manifest.dev-clean',
        "Filepath of manifest to tune.")
add_arg('mean_std_path',    str,
        'data/librispeech/mean_std.npz',
        "Filepath of normalizer's mean & std.")
add_arg('vocab_path',       str,
        'data/librispeech/vocab.txt',
        "Filepath of vocabulary.")
add_arg('lang_model_path',  str,
        'models/lm/common_crawl_00.prune01111.trie.klm',
        "Filepath for language model.")
add_arg('model_path',       str,
        './checkpoints/libri/params.latest.tar.gz',
        "If None, the training starts from scratch, "
        "otherwise, it resumes from the pre-trained model.")
add_arg('error_rate_type',  str,
        'wer',
        "Error rate type for evaluation.",
        choices=['wer', 'cer'])
add_arg('specgram_type',    str,
        'linear',
        "Audio feature type. Options: linear, mfcc.",
        choices=['linear', 'mfcc'])
# yapf: enable
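# Example invocation (paths and GPU count are illustrative; adjust to your
# setup):
#   CUDA_VISIBLE_DEVICES=0,1 python tune.py \
#       --trainer_count 2 \
#       --num_batches 10 \
#       --tune_manifest data/librispeech/manifest.dev-clean \
#       --model_path ./checkpoints/libri/params.latest.tar.gz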
args = parser.parse_args()


def tune():
    """Tune parameters alpha and beta incrementally."""
    if args.num_alphas < 0:
        raise ValueError("num_alphas must be non-negative!")
    if args.num_betas < 0:
        raise ValueError("num_betas must be non-negative!")

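    # build the dev-set data generator; the empty augmentation config ('{}')
    # disables audio augmentation so tuning runs on clean features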
    data_generator = DataGenerator(
        vocab_filepath=args.vocab_path,
        mean_std_filepath=args.mean_std_path,
        augmentation_config='{}',
        specgram_type=args.specgram_type,
        num_threads=1)
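    # read dev batches in a fixed order (no sortagrad, no shuffling) so every
    # (alpha, beta) candidate is scored on identical data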
    batch_reader = data_generator.batch_reader_creator(
        manifest_path=args.tune_manifest,
        batch_size=args.batch_size,
        sortagrad=False,
        shuffle_method=None)

    ds2_model = DeepSpeech2Model(
        vocab_size=data_generator.vocab_size,
        num_conv_layers=args.num_conv_layers,
        num_rnn_layers=args.num_rnn_layers,
        rnn_layer_size=args.rnn_layer_size,
        use_gru=args.use_gru,
        pretrained_model_path=args.model_path,
        share_rnn_weights=args.share_rnn_weights)

    # decoders only accept strings encoded in utf-8
    vocab_list = [chars.encode("utf-8") for chars in data_generator.vocab_list]

    error_rate_func = cer if args.error_rate_type == 'cer' else wer
    # create the (alpha, beta) search grid; alpha weights the language model
    # score and beta is the word insertion bonus in the beam search scoring
    cand_alphas = np.linspace(args.alpha_from, args.alpha_to, args.num_alphas)
    cand_betas = np.linspace(args.beta_from, args.beta_to, args.num_betas)
    params_grid = [(alpha, beta) for alpha in cand_alphas
                   for beta in cand_betas]
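    # with the defaults this is 45 * 8 = 360 candidate pairs, i.e. 360 beam
    # search decodes per batch, so runtime grows linearly with the grid size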

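    # keep a running error sum per (alpha, beta) candidate so the average
    # error rate can be updated incrementally after each batch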
    err_sum = [0.0] * len(params_grid)
    err_ave = [0.0] * len(params_grid)
    num_ins, cur_batch = 0, 0
    # incrementally tune parameters over multiple batches
    for infer_data in batch_reader():
        if (args.num_batches >= 0) and (cur_batch >= args.num_batches):
            break

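        # convert reference token ids in this batch back to transcript strings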
        target_transcripts = [
            ''.join([data_generator.vocab_list[token] for token in transcript])
            for _, transcript in infer_data
        ]

        num_ins += len(target_transcripts)
        # grid search
        for index, (alpha, beta) in enumerate(params_grid):
            result_transcripts = ds2_model.infer_batch(
                infer_data=infer_data,
                decoding_method='ctc_beam_search',
                beam_alpha=alpha,
                beam_beta=beta,
                beam_size=args.beam_size,
                cutoff_prob=args.cutoff_prob,
                cutoff_top_n=args.cutoff_top_n,
                vocab_list=vocab_list,
                language_model_path=args.lang_model_path,
                num_processes=args.num_proc_bsearch)

            for target, result in zip(target_transcripts, result_transcripts):
                err_sum[index] += error_rate_func(target, result)
            err_ave[index] = err_sum[index] / num_ins
            # print("alpha = %f, beta = %f, WER = %f" %
            #      (alpha, beta, err_ave[index]))
            if index % 2 == 0:
                sys.stdout.write('.')
                sys.stdout.flush()

        # output online tuning result at the end of the current batch
        err_ave_min = min(err_ave)
        min_index = err_ave.index(err_ave_min)
        print("\nBatch %d [%d/?], current opt (alpha, beta) = (%s, %s), "
              " min [%s] = %f" %(cur_batch, num_ins,
              "%.3f" % params_grid[min_index][0],
              "%.3f" % params_grid[min_index][1],
              args.error_rate_type, err_ave_min))
        cur_batch += 1

    # output the averaged WER/CER at every (alpha, beta) point
    print("\nFinal %s:\n" % args.error_rate_type)
    for index in xrange(len(params_grid)):
        print("(alpha, beta) = (%s, %s), [%s] = %f"
             % ("%.3f" % params_grid[index][0], "%.3f" % params_grid[index][1],
             args.error_rate_type, err_ave[index]))

    err_ave_min = min(err_ave)
    min_index = err_ave.index(err_ave_min)
    print("\nFinish tuning on %d batches, final opt (alpha, beta) = (%s, %s)"
            % (args.num_batches, "%.3f" % params_grid[min_index][0],
              "%.3f" % params_grid[min_index][1]))

    ds2_model.logger.info("finish tuning")


def main():
    print_arguments(args)
    paddle.init(use_gpu=args.use_gpu, trainer_count=args.trainer_count)
    tune()


if __name__ == '__main__':
    main()