diff --git a/deep_speech_2/tune.py b/deep_speech_2/tune.py
index 328d67a1197634e5f02ad0689056196a8904fc06..5dc44a86c72637124348b05b9e0bceb4801f5270 100644
--- a/deep_speech_2/tune.py
+++ b/deep_speech_2/tune.py
@@ -15,10 +15,10 @@ import utils
 
 parser = argparse.ArgumentParser(description=__doc__)
 parser.add_argument(
-    "--num_samples",
-    default=100,
+    "--batch_size",
+    default=128,
     type=int,
-    help="Number of samples for parameters tuning. (default: %(default)s)")
+    help="Minibatch size for parameter tuning. (default: %(default)s)")
 parser.add_argument(
     "--num_conv_layers",
     default=2,
@@ -51,7 +51,7 @@ parser.add_argument(
     help="Number of cpu threads for preprocessing data. (default: %(default)s)")
 parser.add_argument(
     "--num_processes_beam_search",
-    default=multiprocessing.cpu_count() // 2,
+    default=multiprocessing.cpu_count(),
     type=int,
     help="Number of cpu processes for beam search. (default: %(default)s)")
 parser.add_argument(
@@ -130,7 +130,12 @@ args = parser.parse_args()
 
 
 def tune():
-    """Tune parameters alpha and beta on one minibatch."""
+    """Tune parameters alpha and beta for the CTC beam search decoder
+    incrementally: the optimal parameters found so far are printed at the
+    end of each minibatch, until all of the development data has been taken
+    into account. The tuning process can therefore be terminated at any
+    time, as soon as the two parameters become stable.
+    """
     if not args.num_alphas >= 0:
         raise ValueError("num_alphas must be non-negative!")
     if not args.num_betas >= 0:
@@ -144,14 +149,9 @@ def tune():
         num_threads=args.num_threads_data)
     batch_reader = data_generator.batch_reader_creator(
         manifest_path=args.tune_manifest_path,
-        batch_size=args.num_samples,
+        batch_size=args.batch_size,
         sortagrad=False,
         shuffle_method=None)
-    tune_data = batch_reader().next()
-    target_transcripts = [
-        ''.join([data_generator.vocab_list[token] for token in transcript])
-        for _, transcript in tune_data
-    ]
 
     ds2_model = DeepSpeech2Model(
         vocab_size=data_generator.vocab_size,
@@ -166,24 +166,44 @@ def tune():
     params_grid = [(alpha, beta) for alpha in cand_alphas
                    for beta in cand_betas]
 
-    ## tune parameters in loop
-    for alpha, beta in params_grid:
-        result_transcripts = ds2_model.infer_batch(
-            infer_data=tune_data,
-            decode_method='beam_search',
-            beam_alpha=alpha,
-            beam_beta=beta,
-            beam_size=args.beam_size,
-            cutoff_prob=args.cutoff_prob,
-            vocab_list=data_generator.vocab_list,
-            language_model_path=args.language_model_path,
-            num_processes=args.num_processes_beam_search)
-        wer_sum, num_ins = 0.0, 0
-        for target, result in zip(target_transcripts, result_transcripts):
-            wer_sum += wer(target, result)
-            num_ins += 1
-        print("alpha = %f\tbeta = %f\tWER = %f" %
-              (alpha, beta, wer_sum / num_ins))
+    wer_sum = [0.0 for i in xrange(len(params_grid))]
+    ave_wer = [0.0 for i in xrange(len(params_grid))]
+    num_ins = 0
+    num_batches = 0
+    ## incrementally tune parameters over multiple batches
+    for infer_data in batch_reader():
+        target_transcripts = [
+            ''.join([data_generator.vocab_list[token] for token in transcript])
+            for _, transcript in infer_data
+        ]
+
+        num_ins += len(target_transcripts)
+        # grid search
+        for index, (alpha, beta) in enumerate(params_grid):
+            result_transcripts = ds2_model.infer_batch(
+                infer_data=infer_data,
+                decode_method='beam_search',
+                beam_alpha=alpha,
+                beam_beta=beta,
+                beam_size=args.beam_size,
+                cutoff_prob=args.cutoff_prob,
+                vocab_list=data_generator.vocab_list,
+                language_model_path=args.language_model_path,
+                num_processes=args.num_processes_beam_search)
+
+            for target, result in zip(target_transcripts, result_transcripts):
+                wer_sum[index] += wer(target, result)
+            ave_wer[index] = wer_sum[index] / num_ins
+            print("alpha = %f, beta = %f, WER = %f" %
+                  (alpha, beta, ave_wer[index]))
+
+        # output the on-line tuning result at the end of the current batch
+        ave_wer_min = min(ave_wer)
+        min_index = ave_wer.index(ave_wer_min)
+        print("Finish batch %d, optimal (alpha, beta, WER) = (%f, %f, %f)\n" %
+              (num_batches, params_grid[min_index][0],
+               params_grid[min_index][1], ave_wer_min))
+        num_batches += 1
 
 
 def main():
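Usage sketch: after this patch the development set no longer has to fit into a single batch; tuning streams over minibatches of --batch_size utterances and prints the running optimum after each one. The flag values below are illustrative and the manifest path is a placeholder; only flags that appear in the patch or as args.* references above are used.

    python tune.py \
        --tune_manifest_path data/manifest.dev \
        --batch_size 128 \
        --num_alphas 14 \
        --num_betas 20 \
        --num_processes_beam_search 8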