"""Beam search parameters tuning for DeepSpeech2 model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import sys
import numpy as np
import argparse
import functools
import paddle.v2 as paddle
import _init_paths
from data_utils.data import DataGenerator
from model_utils.model import DeepSpeech2Model
from utils.error_rate import wer, cer
from utils.utility import add_arguments, print_arguments

parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
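# (yapf disabled so the manually aligned argument table below is kept as-is)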
add_arg('num_batches',      int,    -1,    "# of batches to tune on. "
                                           "Default -1: use the whole dev set.")
add_arg('batch_size',       int,    256,   "# of samples per batch.")
add_arg('trainer_count',    int,    8,     "# of Trainers (CPUs or GPUs).")
add_arg('beam_size',        int,    500,   "Beam search width.")
add_arg('num_proc_bsearch', int,    12,    "# of CPUs for beam search.")
add_arg('num_conv_layers',  int,    2,     "# of convolution layers.")
add_arg('num_rnn_layers',   int,    3,     "# of recurrent layers.")
add_arg('rnn_layer_size',   int,    2048,  "# of recurrent cells per layer.")
add_arg('num_alphas',       int,    45,    "# of alpha candidates for tuning.")
add_arg('num_betas',        int,    8,     "# of beta candidates for tuning.")
add_arg('alpha_from',       float,  1.0,   "Where alpha starts tuning from.")
add_arg('alpha_to',         float,  3.2,   "Where alpha ends tuning at.")
add_arg('beta_from',        float,  0.1,   "Where beta starts tuning from.")
add_arg('beta_to',          float,  0.45,  "Where beta ends tuning at.")
add_arg('cutoff_prob',      float,  1.0,   "Cutoff probability for pruning.")
add_arg('cutoff_top_n',     int,    40,    "Cutoff number for pruning.")
add_arg('output_fig',       bool,   True,  "Output error rate figure or not.")
add_arg('use_gru',          bool,   False, "Use GRUs instead of simple RNNs.")
add_arg('use_gpu',          bool,   True,  "Use GPU or not.")
add_arg('share_rnn_weights',bool,   True,  "Share input-hidden weights across "
                                           "bi-directional RNNs. Not for GRU.")
add_arg('tune_manifest',    str,
        'data/librispeech/manifest.dev-clean',
        "Filepath of manifest to tune.")
add_arg('mean_std_path',    str,
        'data/librispeech/mean_std.npz',
        "Filepath of normalizer's mean & std.")
add_arg('vocab_path',       str,
        'data/librispeech/vocab.txt',
        "Filepath of vocabulary.")
add_arg('lang_model_path',  str,
        'models/lm/common_crawl_00.prune01111.trie.klm',
        "Filepath for language model.")
add_arg('model_path',       str,
        './checkpoints/libri/params.latest.tar.gz',
        "If None, the training starts from scratch, "
        "otherwise, it resumes from the pre-trained model.")
add_arg('error_rate_type',  str,
        'wer',
        "Error rate type for evaluation.",
        choices=['wer', 'cer'])
add_arg('specgram_type',    str,
        'linear',
        "Audio feature type. Options: linear, mfcc.",
        choices=['linear', 'mfcc'])
# yapf: enable
args = parser.parse_args()

def plot_error_surface(params_grid, err_ave, fig_name):
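    """Plot the error rate surface over the (alpha, beta) grid and save it."""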
    import matplotlib.pyplot as plt
    from mpl_toolkits.mplot3d import Axes3D
    fig = plt.figure()
    ax = Axes3D(fig)
    alphas = [param[0] for param in params_grid]
    betas = [param[1] for param in params_grid]
    ALPHAS = np.reshape(alphas, (args.num_alphas, args.num_betas))
    BETAS = np.reshape(betas, (args.num_alphas, args.num_betas))
    ERR_AVE = np.reshape(err_ave, (args.num_alphas, args.num_betas))
    ax.plot_surface(ALPHAS, BETAS, ERR_AVE,
                    rstride=1, cstride=1, alpha=0.8, cmap='rainbow')
    ax.set_xlabel('alpha')
    ax.set_ylabel('beta')
    z_label = 'WER' if args.error_rate_type == 'wer' else 'CER'
    ax.set_zlabel(z_label)
    plt.savefig(fig_name)

def tune():
    """Tune parameters alpha and beta on one minibatch."""
    if not args.num_alphas >= 0:
        raise ValueError("num_alphas must be non-negative!")
    if not args.num_betas >= 0:
        raise ValueError("num_betas must be non-negative!")

    data_generator = DataGenerator(
        vocab_filepath=args.vocab_path,
        mean_std_filepath=args.mean_std_path,
        augmentation_config='{}',
        specgram_type=args.specgram_type,
        num_threads=1)
    batch_reader = data_generator.batch_reader_creator(
        manifest_path=args.tune_manifest,
        batch_size=args.batch_size,
        sortagrad=False,
        shuffle_method=None)

    ds2_model = DeepSpeech2Model(
        vocab_size=data_generator.vocab_size,
        num_conv_layers=args.num_conv_layers,
        num_rnn_layers=args.num_rnn_layers,
        rnn_layer_size=args.rnn_layer_size,
        use_gru=args.use_gru,
        pretrained_model_path=args.model_path,
        share_rnn_weights=args.share_rnn_weights)

    # decoders only accept strings encoded in utf-8
    vocab_list = [chars.encode("utf-8") for chars in data_generator.vocab_list]

    error_rate_func = cer if args.error_rate_type == 'cer' else wer
    # create grid for search
    cand_alphas = np.linspace(args.alpha_from, args.alpha_to, args.num_alphas)
    cand_betas = np.linspace(args.beta_from, args.beta_to, args.num_betas)
    params_grid = [(alpha, beta) for alpha in cand_alphas
                   for beta in cand_betas]

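    # err_sum[i]: total error over all utterances seen so far for the i-th
    # (alpha, beta) candidate; err_ave[i]: its running average.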
    err_sum = [0.0 for i in xrange(len(params_grid))]
    err_ave = [0.0 for i in xrange(len(params_grid))]
    num_ins, cur_batch = 0, 0
    # tune parameters incrementally over multiple batches
    for infer_data in batch_reader():
        if (args.num_batches >= 0) and (cur_batch >= args.num_batches):
            break

        target_transcripts = [
            ''.join([data_generator.vocab_list[token] for token in transcript])
            for _, transcript in infer_data
        ]

        num_ins += len(target_transcripts)
        # grid search
        for index, (alpha, beta) in enumerate(params_grid):
            result_transcripts = ds2_model.infer_batch(
                infer_data=infer_data,
                decoding_method='ctc_beam_search',
                beam_alpha=alpha,
                beam_beta=beta,
                beam_size=args.beam_size,
                cutoff_prob=args.cutoff_prob,
                cutoff_top_n=args.cutoff_top_n,
                vocab_list=vocab_list,
                language_model_path=args.lang_model_path,
                num_processes=args.num_proc_bsearch)

            for target, result in zip(target_transcripts, result_transcripts):
                err_sum[index] += error_rate_func(target, result)
            err_ave[index] = err_sum[index] / num_ins
            # print("alpha = %f, beta = %f, WER = %f" %
            #      (alpha, beta, err_ave[index]))
            if index % 10 == 0:
                sys.stdout.write('.')
                sys.stdout.flush()

        # output online tuning result at the end of the current batch
        err_ave_min = min(err_ave)
        min_index = err_ave.index(err_ave_min)
        print("\nBatch %d, opt.(alpha, beta) = (%f, %f), min. error_rate = %f"
                %(cur_batch, params_grid[min_index][0],
               params_grid[min_index][1], err_ave_min))
        cur_batch += 1

    # output WER/CER at every point
    print("\nerror rate at each point:\n")
    for index in xrange(len(params_grid)):
        print("(%f, %f), error_rate = %f"
              % (params_grid[index][0], params_grid[index][1], err_ave[index]))

    err_ave_min = min(err_ave)
    min_index = err_ave.index(err_ave_min)
    print("\nTuning on %d batches, opt. (alpha, beta) = (%f, %f)"
            % (args.num_batches, params_grid[min_index][0],
              params_grid[min_index][1]))

    if args.output_fig:
        fig_name = ("error_surface_alphas_%d_betas_%d" %
                   (args.num_alphas, args.num_betas))
        plot_error_surface(params_grid, err_ave, fig_name)
        ds2_model.logger.info("output figure %s" % fig_name)

    ds2_model.logger.info("finish inference")

def main():
    print_arguments(args)
    paddle.init(use_gpu=args.use_gpu, trainer_count=args.trainer_count)
    tune()


if __name__ == '__main__':
    main()