diff --git a/deep_speech_2/decoder.py b/deep_speech_2/decoder.py index 8f2e0508de79fea30ebc30230e948b15923bdf24..61ead25c8d46f8a362b8d72d88dd80aac5824088 100644 --- a/deep_speech_2/decoder.py +++ b/deep_speech_2/decoder.py @@ -9,8 +9,9 @@ from math import log import multiprocessing -def ctc_best_path_decoder(probs_seq, vocabulary): - """Best path decoder, also called argmax decoder or greedy decoder. +def ctc_greedy_decoder(probs_seq, vocabulary): + """CTC greedy (best path) decoder. + Path consisting of the most probable tokens are further post-processed to remove consecutive repetitions and all blanks. @@ -45,10 +46,12 @@ def ctc_beam_search_decoder(probs_seq, cutoff_prob=1.0, ext_scoring_func=None, nproc=False): - """Beam search decoder for CTC-trained network. It utilizes beam search - to approximately select top best decoding labels and returning results - in the descending order. The implementation is based on Prefix - Beam Search (https://arxiv.org/abs/1408.2873), and the unclear part is + """CTC Beam search decoder. + + It utilizes beam search to approximately select top best decoding + labels and returning results in the descending order. + The implementation is based on Prefix Beam Search + (https://arxiv.org/abs/1408.2873), and the unclear part is redesigned. 
Two important modifications: 1) in the iterative computation of probabilities, the assignment operation is changed to accumulation for one prefix may comes from different paths; 2) the if condition "if l^+ not diff --git a/deep_speech_2/demo_server.py b/deep_speech_2/demo_server.py index b000e35e91c20ec925fe1cd52a3901ed7ee9519f..7cbee1fd44f517cc4d6e0602eda01163737dd93f 100644 --- a/deep_speech_2/demo_server.py +++ b/deep_speech_2/demo_server.py @@ -3,123 +3,63 @@ import os import time import random import argparse -import distutils.util +import functools from time import gmtime, strftime import SocketServer import struct import wave import paddle.v2 as paddle -from utils import print_arguments from data_utils.data import DataGenerator from model import DeepSpeech2Model from data_utils.utils import read_manifest +from utils import add_arguments, print_arguments parser = argparse.ArgumentParser(description=__doc__) -parser.add_argument( - "--host_ip", - default="localhost", - type=str, - help="Server IP address. (default: %(default)s)") -parser.add_argument( - "--host_port", - default=8086, - type=int, - help="Server Port. (default: %(default)s)") -parser.add_argument( - "--speech_save_dir", - default="demo_cache", - type=str, - help="Directory for saving demo speech. (default: %(default)s)") -parser.add_argument( - "--vocab_filepath", - default='datasets/vocab/eng_vocab.txt', - type=str, - help="Vocabulary filepath. (default: %(default)s)") -parser.add_argument( - "--mean_std_filepath", - default='mean_std.npz', - type=str, - help="Manifest path for normalizer. (default: %(default)s)") -parser.add_argument( - "--warmup_manifest_path", - default='datasets/manifest.test', - type=str, - help="Manifest path for warmup test. (default: %(default)s)") -parser.add_argument( - "--specgram_type", - default='linear', - type=str, - help="Feature type of audio data: 'linear' (power spectrum)" - " or 'mfcc'. 
(default: %(default)s)") -parser.add_argument( - "--num_conv_layers", - default=2, - type=int, - help="Convolution layer number. (default: %(default)s)") -parser.add_argument( - "--num_rnn_layers", - default=3, - type=int, - help="RNN layer number. (default: %(default)s)") -parser.add_argument( - "--rnn_layer_size", - default=2048, - type=int, - help="RNN layer cell number. (default: %(default)s)") -parser.add_argument( - "--share_rnn_weights", - default=True, - type=distutils.util.strtobool, - help="Whether to share input-hidden weights between forword and backward " - "directional simple RNNs. Only available when use_gru=False. " - "(default: %(default)s)") -parser.add_argument( - "--use_gru", - default=False, - type=distutils.util.strtobool, - help="Use GRU or simple RNN. (default: %(default)s)") -parser.add_argument( - "--use_gpu", - default=True, - type=distutils.util.strtobool, - help="Use gpu or not. (default: %(default)s)") -parser.add_argument( - "--model_filepath", - default='checkpoints/params.latest.tar.gz', - type=str, - help="Model filepath. (default: %(default)s)") -parser.add_argument( - "--decode_method", - default='beam_search', - type=str, - help="Method for ctc decoding: best_path or beam_search. " - "(default: %(default)s)") -parser.add_argument( - "--beam_size", - default=100, - type=int, - help="Width for beam search decoding. (default: %(default)d)") -parser.add_argument( - "--language_model_path", - default="lm/data/common_crawl_00.prune01111.trie.klm", - type=str, - help="Path for language model. (default: %(default)s)") -parser.add_argument( - "--alpha", - default=0.36, - type=float, - help="Parameter associated with language model. (default: %(default)f)") -parser.add_argument( - "--beta", - default=0.25, - type=float, - help="Parameter associated with word count. (default: %(default)f)") -parser.add_argument( - "--cutoff_prob", - default=0.99, - type=float, - help="The cutoff probability of pruning" - "in beam search. 
(default: %(default)f)") +add_arg = functools.partial(add_arguments, argparser=parser) +# yapf: disable +add_arg('host_port', int, 8086, "Server's IP port.") +add_arg('beam_size', int, 500, "Beam search width.") +add_arg('num_conv_layers', int, 2, "# of convolution layers.") +add_arg('num_rnn_layers', int, 3, "# of recurrent layers.") +add_arg('rnn_layer_size', int, 2048, "# of recurrent cells per layer.") +add_arg('alpha', float, 0.36, "Coef of LM for beam search.") +add_arg('beta', float, 0.25, "Coef of WC for beam search.") +add_arg('cutoff_prob', float, 0.99, "Cutoff probability for pruning.") +add_arg('use_gru', bool, False, "Use GRUs instead of simple RNNs.") +add_arg('use_gpu', bool, True, "Use GPU or not.") +add_arg('share_rnn_weights',bool, True, "Share input-hidden weights across " + "bi-directional RNNs. Not for GRU.") +add_arg('host_ip', str, + 'localhost', + "Server's IP address.") +add_arg('speech_save_dir', str, + 'demo_cache', + "Directory to save demo audios.") +add_arg('warmup_manifest', str, + 'datasets/manifest.test', + "Filepath of manifest to warm up.") +add_arg('mean_std_path', str, + 'mean_std.npz', + "Filepath of normalizer's mean & std.") +add_arg('vocab_path', str, + 'datasets/vocab/eng_vocab.txt', + "Filepath of vocabulary.") +add_arg('model_path', str, + './checkpoints/params.latest.tar.gz', + "If None, the training starts from scratch, " + "otherwise, it resumes from the pre-trained model.") +add_arg('lang_model_path', str, + 'lm/data/common_crawl_00.prune01111.trie.klm', + "Filepath for language model.") +add_arg('decoding_method', str, + 'ctc_beam_search', + "Decoding method. Options: ctc_beam_search, ctc_greedy", + choices = ['ctc_beam_search', 'ctc_greedy']) +add_arg('specgram_type', str, + 'linear', + "Audio feature type. 
Options: linear, mfcc.", + choices=['linear', 'mfcc']) +# yapf: disable args = parser.parse_args() @@ -200,8 +140,8 @@ def start_server(): """Start the ASR server""" # prepare data generator data_generator = DataGenerator( - vocab_filepath=args.vocab_filepath, - mean_std_filepath=args.mean_std_filepath, + vocab_filepath=args.vocab_path, + mean_std_filepath=args.mean_std_path, augmentation_config='{}', specgram_type=args.specgram_type, num_threads=1) @@ -212,7 +152,7 @@ def start_server(): num_rnn_layers=args.num_rnn_layers, rnn_layer_size=args.rnn_layer_size, use_gru=args.use_gru, - pretrained_model_path=args.model_filepath, + pretrained_model_path=args.model_path, share_rnn_weights=args.share_rnn_weights) # prepare ASR inference handler @@ -220,13 +160,13 @@ def start_server(): feature = data_generator.process_utterance(filename, "") result_transcript = ds2_model.infer_batch( infer_data=[feature], - decode_method=args.decode_method, + decoding_method=args.decoding_method, beam_alpha=args.alpha, beam_beta=args.beta, beam_size=args.beam_size, cutoff_prob=args.cutoff_prob, vocab_list=data_generator.vocab_list, - language_model_path=args.language_model_path, + language_model_path=args.lang_model_path, num_processes=1) return result_transcript[0] @@ -235,7 +175,7 @@ def start_server(): print('Warming up ...') warm_up_test( audio_process_handler=file_to_transcript, - manifest_path=args.warmup_manifest_path, + manifest_path=args.warmup_manifest, num_test_cases=3) print('-----------------------------------------------------------') diff --git a/deep_speech_2/evaluate.py b/deep_speech_2/evaluate.py index 8dd169b6c2a41a1ad749324e6cba60bff98d951b..1cc307dad3e611fe73cd7786976bfaca6a7c8227 100644 --- a/deep_speech_2/evaluate.py +++ b/deep_speech_2/evaluate.py @@ -3,147 +3,74 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import distutils.util import argparse -import multiprocessing +import functools import 
paddle.v2 as paddle from data_utils.data import DataGenerator from model import DeepSpeech2Model from error_rate import wer, cer -import utils +from utils import add_arguments, print_arguments parser = argparse.ArgumentParser(description=__doc__) -parser.add_argument( - "--batch_size", - default=128, - type=int, - help="Minibatch size for evaluation. (default: %(default)s)") -parser.add_argument( - "--trainer_count", - default=8, - type=int, - help="Trainer number. (default: %(default)s)") -parser.add_argument( - "--num_conv_layers", - default=2, - type=int, - help="Convolution layer number. (default: %(default)s)") -parser.add_argument( - "--num_rnn_layers", - default=3, - type=int, - help="RNN layer number. (default: %(default)s)") -parser.add_argument( - "--rnn_layer_size", - default=2048, - type=int, - help="RNN layer cell number. (default: %(default)s)") -parser.add_argument( - "--share_rnn_weights", - default=True, - type=distutils.util.strtobool, - help="Whether to share input-hidden weights between forword and backward " - "directional simple RNNs. Only available when use_gru=False. " - "(default: %(default)s)") -parser.add_argument( - "--use_gru", - default=False, - type=distutils.util.strtobool, - help="Use GRU or simple RNN. (default: %(default)s)") -parser.add_argument( - "--use_gpu", - default=True, - type=distutils.util.strtobool, - help="Use gpu or not. (default: %(default)s)") -parser.add_argument( - "--num_threads_data", - default=multiprocessing.cpu_count() // 2, - type=int, - help="Number of cpu threads for preprocessing data. (default: %(default)s)") -parser.add_argument( - "--num_processes_beam_search", - default=multiprocessing.cpu_count() // 2, - type=int, - help="Number of cpu processes for beam search. (default: %(default)s)") -parser.add_argument( - "--mean_std_filepath", - default='mean_std.npz', - type=str, - help="Manifest path for normalizer. 
(default: %(default)s)") -parser.add_argument( - "--decode_method", - default='beam_search', - type=str, - help="Method for ctc decoding, best_path or beam_search. " - "(default: %(default)s)") -parser.add_argument( - "--language_model_path", - default="lm/data/common_crawl_00.prune01111.trie.klm", - type=str, - help="Path for language model. (default: %(default)s)") -parser.add_argument( - "--alpha", - default=0.36, - type=float, - help="Parameter associated with language model. (default: %(default)f)") -parser.add_argument( - "--beta", - default=0.25, - type=float, - help="Parameter associated with word count. (default: %(default)f)") -parser.add_argument( - "--cutoff_prob", - default=0.99, - type=float, - help="The cutoff probability of pruning" - "in beam search. (default: %(default)f)") -parser.add_argument( - "--beam_size", - default=500, - type=int, - help="Width for beam search decoding. (default: %(default)d)") -parser.add_argument( - "--specgram_type", - default='linear', - type=str, - help="Feature type of audio data: 'linear' (power spectrum)" - " or 'mfcc'. (default: %(default)s)") -parser.add_argument( - "--decode_manifest_path", - default='datasets/manifest.test', - type=str, - help="Manifest path for decoding. (default: %(default)s)") -parser.add_argument( - "--model_filepath", - default='checkpoints/params.latest.tar.gz', - type=str, - help="Model filepath. (default: %(default)s)") -parser.add_argument( - "--vocab_filepath", - default='datasets/vocab/eng_vocab.txt', - type=str, - help="Vocabulary filepath. (default: %(default)s)") -parser.add_argument( - "--error_rate_type", - default='wer', - choices=['wer', 'cer'], - type=str, - help="Error rate type for evaluation. 'wer' for word error rate and 'cer' " - "for character error rate. 
" - "(default: %(default)s)") +add_arg = functools.partial(add_arguments, argparser=parser) +# yapf: disable +add_arg('batch_size', int, 128, "Minibatch size.") +add_arg('trainer_count', int, 8, "# of Trainers (CPUs or GPUs).") +add_arg('beam_size', int, 500, "Beam search width.") +add_arg('num_proc_bsearch', int, 12, "# of CPUs for beam search.") +add_arg('num_proc_data', int, 12, "# of CPUs for data preprocessing.") +add_arg('num_conv_layers', int, 2, "# of convolution layers.") +add_arg('num_rnn_layers', int, 3, "# of recurrent layers.") +add_arg('rnn_layer_size', int, 2048, "# of recurrent cells per layer.") +add_arg('alpha', float, 0.36, "Coef of LM for beam search.") +add_arg('beta', float, 0.25, "Coef of WC for beam search.") +add_arg('cutoff_prob', float, 0.99, "Cutoff probability for pruning.") +add_arg('use_gru', bool, False, "Use GRUs instead of simple RNNs.") +add_arg('use_gpu', bool, True, "Use GPU or not.") +add_arg('share_rnn_weights',bool, True, "Share input-hidden weights across " + "bi-directional RNNs. Not for GRU.") +add_arg('test_manifest', str, + 'datasets/manifest.test', + "Filepath of manifest to evaluate.") +add_arg('mean_std_path', str, + 'mean_std.npz', + "Filepath of normalizer's mean & std.") +add_arg('vocab_path', str, + 'datasets/vocab/eng_vocab.txt', + "Filepath of vocabulary.") +add_arg('model_path', str, + './checkpoints/params.latest.tar.gz', + "If None, the training starts from scratch, " + "otherwise, it resumes from the pre-trained model.") +add_arg('lang_model_path', str, + 'lm/data/common_crawl_00.prune01111.trie.klm', + "Filepath for language model.") +add_arg('decoding_method', str, + 'ctc_beam_search', + "Decoding method. Options: ctc_beam_search, ctc_greedy", + choices = ['ctc_beam_search', 'ctc_greedy']) +add_arg('error_rate_type', str, + 'wer', + "Error rate type for evaluation.", + choices=['wer', 'cer']) +add_arg('specgram_type', str, + 'linear', + "Audio feature type. 
Options: linear, mfcc.", + choices=['linear', 'mfcc']) +# yapf: disable args = parser.parse_args() def evaluate(): """Evaluate on whole test data for DeepSpeech2.""" data_generator = DataGenerator( - vocab_filepath=args.vocab_filepath, - mean_std_filepath=args.mean_std_filepath, + vocab_filepath=args.vocab_path, + mean_std_filepath=args.mean_std_path, augmentation_config='{}', specgram_type=args.specgram_type, - num_threads=args.num_threads_data) + num_threads=args.num_proc_data) batch_reader = data_generator.batch_reader_creator( - manifest_path=args.decode_manifest_path, + manifest_path=args.test_manifest, batch_size=args.batch_size, min_batch_size=1, sortagrad=False, @@ -155,7 +82,7 @@ def evaluate(): num_rnn_layers=args.num_rnn_layers, rnn_layer_size=args.rnn_layer_size, use_gru=args.use_gru, - pretrained_model_path=args.model_filepath, + pretrained_model_path=args.model_path, share_rnn_weights=args.share_rnn_weights) error_rate_func = cer if args.error_rate_type == 'cer' else wer @@ -163,14 +90,14 @@ def evaluate(): for infer_data in batch_reader(): result_transcripts = ds2_model.infer_batch( infer_data=infer_data, - decode_method=args.decode_method, + decoding_method=args.decoding_method, beam_alpha=args.alpha, beam_beta=args.beta, beam_size=args.beam_size, cutoff_prob=args.cutoff_prob, vocab_list=data_generator.vocab_list, - language_model_path=args.language_model_path, - num_processes=args.num_processes_beam_search) + language_model_path=args.lang_model_path, + num_processes=args.num_proc_bsearch) target_transcripts = [ ''.join([data_generator.vocab_list[token] for token in transcript]) for _, transcript in infer_data @@ -185,7 +112,7 @@ def evaluate(): def main(): - utils.print_arguments(args) + print_arguments(args) paddle.init(use_gpu=args.use_gpu, trainer_count=args.trainer_count) evaluate() diff --git a/deep_speech_2/infer.py b/deep_speech_2/infer.py index 0c52ffc831b3349dacc5453bc21dc9a13e6471c8..3fd835b467f0d838efa05410be898c0a75aac24d 100644 --- 
a/deep_speech_2/infer.py +++ b/deep_speech_2/infer.py @@ -4,146 +4,72 @@ from __future__ import division from __future__ import print_function import argparse -import distutils.util -import multiprocessing +import functools import paddle.v2 as paddle from data_utils.data import DataGenerator from model import DeepSpeech2Model from error_rate import wer, cer -import utils +from utils import add_arguments, print_arguments parser = argparse.ArgumentParser(description=__doc__) -parser.add_argument( - "--num_samples", - default=10, - type=int, - help="Number of samples for inference. (default: %(default)s)") -parser.add_argument( - "--num_conv_layers", - default=2, - type=int, - help="Convolution layer number. (default: %(default)s)") -parser.add_argument( - "--num_rnn_layers", - default=3, - type=int, - help="RNN layer number. (default: %(default)s)") -parser.add_argument( - "--rnn_layer_size", - default=2048, - type=int, - help="RNN layer cell number. (default: %(default)s)") -parser.add_argument( - "--share_rnn_weights", - default=True, - type=distutils.util.strtobool, - help="Whether to share input-hidden weights between forword and backward " - "directional simple RNNs. Only available when use_gru=False. " - "(default: %(default)s)") -parser.add_argument( - "--use_gru", - default=False, - type=distutils.util.strtobool, - help="Use GRU or simple RNN. (default: %(default)s)") -parser.add_argument( - "--use_gpu", - default=True, - type=distutils.util.strtobool, - help="Use gpu or not. (default: %(default)s)") -parser.add_argument( - "--num_threads_data", - default=1, - type=int, - help="Number of cpu threads for preprocessing data. (default: %(default)s)") -parser.add_argument( - "--num_processes_beam_search", - default=multiprocessing.cpu_count() // 2, - type=int, - help="Number of cpu processes for beam search. 
(default: %(default)s)") -parser.add_argument( - "--specgram_type", - default='linear', - type=str, - help="Feature type of audio data: 'linear' (power spectrum)" - " or 'mfcc'. (default: %(default)s)") -parser.add_argument( - "--trainer_count", - default=8, - type=int, - help="Trainer number. (default: %(default)s)") -parser.add_argument( - "--mean_std_filepath", - default='mean_std.npz', - type=str, - help="Manifest path for normalizer. (default: %(default)s)") -parser.add_argument( - "--decode_manifest_path", - default='datasets/manifest.test', - type=str, - help="Manifest path for decoding. (default: %(default)s)") -parser.add_argument( - "--model_filepath", - default='checkpoints/params.latest.tar.gz', - type=str, - help="Model filepath. (default: %(default)s)") -parser.add_argument( - "--vocab_filepath", - default='datasets/vocab/eng_vocab.txt', - type=str, - help="Vocabulary filepath. (default: %(default)s)") -parser.add_argument( - "--decode_method", - default='beam_search', - type=str, - help="Method for ctc decoding: best_path or beam_search. " - "(default: %(default)s)") -parser.add_argument( - "--beam_size", - default=500, - type=int, - help="Width for beam search decoding. (default: %(default)d)") -parser.add_argument( - "--language_model_path", - default="lm/data/common_crawl_00.prune01111.trie.klm", - type=str, - help="Path for language model. (default: %(default)s)") -parser.add_argument( - "--alpha", - default=0.36, - type=float, - help="Parameter associated with language model. (default: %(default)f)") -parser.add_argument( - "--beta", - default=0.25, - type=float, - help="Parameter associated with word count. (default: %(default)f)") -parser.add_argument( - "--cutoff_prob", - default=0.99, - type=float, - help="The cutoff probability of pruning" - "in beam search. (default: %(default)f)") -parser.add_argument( - "--error_rate_type", - default='wer', - choices=['wer', 'cer'], - type=str, - help="Error rate type for evaluation. 
'wer' for word error rate and 'cer' " - "for character error rate. " - "(default: %(default)s)") +add_arg = functools.partial(add_arguments, argparser=parser) +# yapf: disable +add_arg('num_samples', int, 10, "# of samples to infer.") +add_arg('trainer_count', int, 8, "# of Trainers (CPUs or GPUs).") +add_arg('beam_size', int, 500, "Beam search width.") +add_arg('num_proc_bsearch', int, 12, "# of CPUs for beam search.") +add_arg('num_conv_layers', int, 2, "# of convolution layers.") +add_arg('num_rnn_layers', int, 3, "# of recurrent layers.") +add_arg('rnn_layer_size', int, 2048, "# of recurrent cells per layer.") +add_arg('alpha', float, 0.36, "Coef of LM for beam search.") +add_arg('beta', float, 0.25, "Coef of WC for beam search.") +add_arg('cutoff_prob', float, 0.99, "Cutoff probability for pruning.") +add_arg('use_gru', bool, False, "Use GRUs instead of simple RNNs.") +add_arg('use_gpu', bool, True, "Use GPU or not.") +add_arg('share_rnn_weights',bool, True, "Share input-hidden weights across " + "bi-directional RNNs. Not for GRU.") +add_arg('infer_manifest', str, + 'datasets/manifest.dev', + "Filepath of manifest to infer.") +add_arg('mean_std_path', str, + 'mean_std.npz', + "Filepath of normalizer's mean & std.") +add_arg('vocab_path', str, + 'datasets/vocab/eng_vocab.txt', + "Filepath of vocabulary.") +add_arg('lang_model_path', str, + 'lm/data/common_crawl_00.prune01111.trie.klm', + "Filepath for language model.") +add_arg('model_path', str, + './checkpoints/params.latest.tar.gz', + "If None, the training starts from scratch, " + "otherwise, it resumes from the pre-trained model.") +add_arg('decoding_method', str, + 'ctc_beam_search', + "Decoding method. Options: ctc_beam_search, ctc_greedy", + choices = ['ctc_beam_search', 'ctc_greedy']) +add_arg('error_rate_type', str, + 'wer', + "Error rate type for evaluation.", + choices=['wer', 'cer']) +add_arg('specgram_type', str, + 'linear', + "Audio feature type. 
Options: linear, mfcc.", + choices=['linear', 'mfcc']) +# yapf: disable args = parser.parse_args() def infer(): """Inference for DeepSpeech2.""" data_generator = DataGenerator( - vocab_filepath=args.vocab_filepath, - mean_std_filepath=args.mean_std_filepath, + vocab_filepath=args.vocab_path, + mean_std_filepath=args.mean_std_path, augmentation_config='{}', specgram_type=args.specgram_type, - num_threads=args.num_threads_data) + num_threads=1) batch_reader = data_generator.batch_reader_creator( - manifest_path=args.decode_manifest_path, + manifest_path=args.infer_manifest, batch_size=args.num_samples, min_batch_size=1, sortagrad=False, @@ -156,18 +82,18 @@ def infer(): num_rnn_layers=args.num_rnn_layers, rnn_layer_size=args.rnn_layer_size, use_gru=args.use_gru, - pretrained_model_path=args.model_filepath, + pretrained_model_path=args.model_path, share_rnn_weights=args.share_rnn_weights) result_transcripts = ds2_model.infer_batch( infer_data=infer_data, - decode_method=args.decode_method, + decoding_method=args.decoding_method, beam_alpha=args.alpha, beam_beta=args.beta, beam_size=args.beam_size, cutoff_prob=args.cutoff_prob, vocab_list=data_generator.vocab_list, - language_model_path=args.language_model_path, - num_processes=args.num_processes_beam_search) + language_model_path=args.lang_model_path, + num_processes=args.num_proc_bsearch) error_rate_func = cer if args.error_rate_type == 'cer' else wer target_transcripts = [ @@ -182,7 +108,7 @@ def infer(): def main(): - utils.print_arguments(args) + print_arguments(args) paddle.init(use_gpu=args.use_gpu, trainer_count=args.trainer_count) infer() diff --git a/deep_speech_2/model.py b/deep_speech_2/model.py index 0234ed2d4c901f36ebfc16b317f5355cd57796e0..06f69290682226dffc601711d81f45242e23538d 100644 --- a/deep_speech_2/model.py +++ b/deep_speech_2/model.py @@ -146,7 +146,7 @@ class DeepSpeech2Model(object): # run inference return self._loss_inferer.infer(input=infer_data) - def infer_batch(self, infer_data, 
decode_method, beam_alpha, beam_beta, + def infer_batch(self, infer_data, decoding_method, beam_alpha, beam_beta, beam_size, cutoff_prob, vocab_list, language_model_path, num_processes): """Model inference. Infer the transcription for a batch of speech @@ -156,9 +156,9 @@ class DeepSpeech2Model(object): consisting of a tuple of audio features and transcription text (empty string). :type infer_data: list - :param decode_method: Decoding method name, 'best_path' or - 'beam search'. - :param decode_method: string + :param decoding_method: Decoding method name, 'ctc_greedy' or + 'ctc_beam_search'. + :param decoding_method: string :param beam_alpha: Parameter associated with language model. :type beam_alpha: float :param beam_beta: Parameter associated with word count. @@ -190,13 +190,13 @@ class DeepSpeech2Model(object): ] # run decoder results = [] - if decode_method == "best_path": + if decoding_method == "ctc_greedy": # best path decode for i, probs in enumerate(probs_split): - output_transcription = ctc_best_path_decoder( + output_transcription = ctc_greedy_decoder( probs_seq=probs, vocabulary=vocab_list) results.append(output_transcription) - elif decode_method == "beam_search": + elif decoding_method == "ctc_beam_search": # initialize external scorer if self._ext_scorer == None: self._ext_scorer = LmScorer(beam_alpha, beam_beta, @@ -205,7 +205,6 @@ class DeepSpeech2Model(object): else: self._ext_scorer.reset_params(beam_alpha, beam_beta) assert self._loaded_lm_path == language_model_path - # beam search decode beam_search_results = ctc_beam_search_decoder_batch( probs_split=probs_split, @@ -219,7 +218,7 @@ class DeepSpeech2Model(object): results = [result[0][1] for result in beam_search_results] else: raise ValueError("Decoding method [%s] is not supported." 
% - decode_method) + decoding_method) return results def _create_parameters(self, model_path=None): diff --git a/deep_speech_2/tests/test_decoders.py b/deep_speech_2/tests/test_decoders.py index 99d8a8289d93574c58ced50923716c39cfb96558..fa43879b8741c4f9f62d9a2b648e105fc5d51d37 100644 --- a/deep_speech_2/tests/test_decoders.py +++ b/deep_speech_2/tests/test_decoders.py @@ -49,16 +49,16 @@ class TestDecoders(unittest.TestCase): 0.15882358, 0.1235788, 0.23376776, 0.20510435, 0.00279306, 0.05294827, 0.22298418 ]] - self.best_path_result = ["ac'bdc", "b'da"] + self.greedy_result = ["ac'bdc", "b'da"] self.beam_search_result = ['acdc', "b'a"] - def test_best_path_decoder_1(self): - bst_result = ctc_best_path_decoder(self.probs_seq1, self.vocab_list) - self.assertEqual(bst_result, self.best_path_result[0]) + def test_greedy_decoder_1(self): + bst_result = ctc_greedy_decoder(self.probs_seq1, self.vocab_list) + self.assertEqual(bst_result, self.greedy_result[0]) - def test_best_path_decoder_2(self): - bst_result = ctc_best_path_decoder(self.probs_seq2, self.vocab_list) - self.assertEqual(bst_result, self.best_path_result[1]) + def test_greedy_decoder_2(self): + bst_result = ctc_greedy_decoder(self.probs_seq2, self.vocab_list) + self.assertEqual(bst_result, self.greedy_result[1]) def test_beam_search_decoder_1(self): beam_result = ctc_beam_search_decoder( diff --git a/deep_speech_2/tools/build_vocab.py b/deep_speech_2/tools/build_vocab.py index 618f2498537ba9d085a0ec3a60852f591bb0ff3e..ac600302679320f2fcfbee7645ad83c2442b47d5 100644 --- a/deep_speech_2/tools/build_vocab.py +++ b/deep_speech_2/tools/build_vocab.py @@ -7,32 +7,29 @@ from __future__ import division from __future__ import print_function import argparse +import functools import codecs import json from collections import Counter import os.path import _init_paths from data_utils import utils +from utils import add_arguments, print_arguments parser = argparse.ArgumentParser(description=__doc__) -parser.add_argument( 
- "--manifest_paths", - type=str, - help="Manifest paths for building vocabulary." - "You can provide multiple manifest files.", - nargs='+', - required=True) -parser.add_argument( - "--count_threshold", - default=0, - type=int, - help="Characters whose counts are below the threshold will be truncated. " - "(default: %(default)i)") -parser.add_argument( - "--vocab_path", - default='datasets/vocab/zh_vocab.txt', - type=str, - help="File path to write the vocabulary. (default: %(default)s)") +add_arg = functools.partial(add_arguments, argparser=parser) +# yapf: disable +add_arg('count_threshold', int, 0, "Truncation threshold for char counts.") +add_arg('vocab_path', str, + 'datasets/vocab/zh_vocab.txt', + "Filepath to write the vocabulary.") +add_arg('manifest_paths', str, + None, + "Filepaths of manifests for building vocabulary. " + "You can provide multiple manifest files.", + nargs='+', + required=True) +# yapf: disable args = parser.parse_args() @@ -44,6 +41,8 @@ def count_manifest(counter, manifest_path): def main(): + print_arguments(args) + counter = Counter() for manifest_path in args.manifest_paths: count_manifest(counter, manifest_path) diff --git a/deep_speech_2/tools/compute_mean_std.py b/deep_speech_2/tools/compute_mean_std.py index da49eb4c056700e6c4da5e740c2bbcee84fa3154..9f7bf06cedf532458d1d704f4099a4f23e931be5 100644 --- a/deep_speech_2/tools/compute_mean_std.py +++ b/deep_speech_2/tools/compute_mean_std.py @@ -4,48 +4,35 @@ from __future__ import division from __future__ import print_function import argparse +import functools import _init_paths from data_utils.normalizer import FeatureNormalizer from data_utils.augmentor.augmentation import AugmentationPipeline from data_utils.featurizer.audio_featurizer import AudioFeaturizer - -parser = argparse.ArgumentParser( - description='Computing mean and stddev for feature normalizer.') -parser.add_argument( - "--specgram_type", - default='linear', - type=str, - help="Feature type of audio data: 'linear' 
(power spectrum)" - " or 'mfcc'. (default: %(default)s)") -parser.add_argument( - "--manifest_path", - default='datasets/manifest.train', - type=str, - help="Manifest path for computing normalizer's mean and stddev." - "(default: %(default)s)") -parser.add_argument( - "--num_samples", - default=2000, - type=int, - help="Number of samples for computing mean and stddev. " - "(default: %(default)s)") -parser.add_argument( - "--augmentation_config", - default='{}', - type=str, - help="Augmentation configuration in json-format. " - "(default: %(default)s)") -parser.add_argument( - "--output_file", - default='mean_std.npz', - type=str, - help="Filepath to write mean and std to (.npz)." - "(default: %(default)s)") +from utils import add_arguments, print_arguments + +parser = argparse.ArgumentParser(description=__doc__) +add_arg = functools.partial(add_arguments, argparser=parser) +# yapf: disable +add_arg('num_samples', int, 2000, "# of samples to for statistics.") +add_arg('specgram_type', str, + 'linear', + "Audio feature type. 
Options: linear, mfcc.", + choices=['linear', 'mfcc']) +add_arg('manifest_path', str, + 'datasets/manifest.train', + "Filepath of manifest to compute normalizer's mean and stddev.") +add_arg('output_path', str, + 'mean_std.npz', + "Filepath of write mean and stddev to (.npz).") +# yapf: disable args = parser.parse_args() def main(): - augmentation_pipeline = AugmentationPipeline(args.augmentation_config) + print_arguments(args) + + augmentation_pipeline = AugmentationPipeline('{}') audio_featurizer = AudioFeaturizer(specgram_type=args.specgram_type) def augment_and_featurize(audio_segment): @@ -57,7 +44,7 @@ def main(): manifest_path=args.manifest_path, featurize_func=augment_and_featurize, num_samples=args.num_samples) - normalizer.write_to_file(args.output_file) + normalizer.write_to_file(args.output_path) if __name__ == '__main__': diff --git a/deep_speech_2/train.py b/deep_speech_2/train.py index d055341f10c82f3cec38867e2db36cfaaabe0a79..7cef7539b35b805030976303ea901e6d8081386e 100644 --- a/deep_speech_2/train.py +++ b/deep_speech_2/train.py @@ -4,174 +4,91 @@ from __future__ import division from __future__ import print_function import argparse -import distutils.util -import multiprocessing +import functools import paddle.v2 as paddle from model import DeepSpeech2Model from data_utils.data import DataGenerator -import utils +from utils import add_arguments, print_arguments parser = argparse.ArgumentParser(description=__doc__) -parser.add_argument( - "--batch_size", default=256, type=int, help="Minibatch size.") -parser.add_argument( - "--num_passes", - default=200, - type=int, - help="Training pass number. (default: %(default)s)") -parser.add_argument( - "--num_iterations_print", - default=100, - type=int, - help="Number of iterations for every train cost printing. " - "(default: %(default)s)") -parser.add_argument( - "--num_conv_layers", - default=2, - type=int, - help="Convolution layer number. 
(default: %(default)s)") -parser.add_argument( - "--num_rnn_layers", - default=3, - type=int, - help="RNN layer number. (default: %(default)s)") -parser.add_argument( - "--rnn_layer_size", - default=2048, - type=int, - help="RNN layer cell number. (default: %(default)s)") -parser.add_argument( - "--share_rnn_weights", - default=True, - type=distutils.util.strtobool, - help="Whether to share input-hidden weights between forword and backward " - "directional simple RNNs. Only available when use_gru=False. " - "(default: %(default)s)") -parser.add_argument( - "--use_gru", - default=False, - type=distutils.util.strtobool, - help="Use GRU or simple RNN. (default: %(default)s)") -parser.add_argument( - "--adam_learning_rate", - default=5e-4, - type=float, - help="Learning rate for ADAM Optimizer. (default: %(default)s)") -parser.add_argument( - "--use_gpu", - default=True, - type=distutils.util.strtobool, - help="Use gpu or not. (default: %(default)s)") -parser.add_argument( - "--use_sortagrad", - default=True, - type=distutils.util.strtobool, - help="Use sortagrad or not. (default: %(default)s)") -parser.add_argument( - "--specgram_type", - default='linear', - type=str, - help="Feature type of audio data: 'linear' (power spectrum)" - " or 'mfcc'. (default: %(default)s)") -parser.add_argument( - "--max_duration", - default=27.0, - type=float, - help="Audios with duration larger than this will be discarded. " - "(default: %(default)s)") -parser.add_argument( - "--min_duration", - default=0.0, - type=float, - help="Audios with duration smaller than this will be discarded. " - "(default: %(default)s)") -parser.add_argument( - "--shuffle_method", - default='batch_shuffle_clipped', - type=str, - help="Shuffle method: 'instance_shuffle', 'batch_shuffle', " - "'batch_shuffle_batch'. (default: %(default)s)") -parser.add_argument( - "--trainer_count", - default=8, - type=int, - help="Trainer number. 
(default: %(default)s)") -parser.add_argument( - "--num_threads_data", - default=multiprocessing.cpu_count() // 2, - type=int, - help="Number of cpu threads for preprocessing data. (default: %(default)s)") -parser.add_argument( - "--mean_std_filepath", - default='mean_std.npz', - type=str, - help="Manifest path for normalizer. (default: %(default)s)") -parser.add_argument( - "--train_manifest_path", - default='datasets/manifest.train', - type=str, - help="Manifest path for training. (default: %(default)s)") -parser.add_argument( - "--dev_manifest_path", - default='datasets/manifest.dev', - type=str, - help="Manifest path for validation. (default: %(default)s)") -parser.add_argument( - "--vocab_filepath", - default='datasets/vocab/eng_vocab.txt', - type=str, - help="Vocabulary filepath. (default: %(default)s)") -parser.add_argument( - "--init_model_path", - default=None, - type=str, - help="If set None, the training will start from scratch. " - "Otherwise, the training will resume from " - "the existing model of this path. (default: %(default)s)") -parser.add_argument( - "--output_model_dir", - default="./checkpoints", - type=str, - help="Directory for saving models. (default: %(default)s)") -parser.add_argument( - "--augmentation_config", - default=open('conf/augmentation.config', 'r').read(), - type=str, - help="Augmentation configuration in json-format. " - "(default: %(default)s)") -parser.add_argument( - "--is_local", - default=True, - type=distutils.util.strtobool, - help="Set to false if running with pserver in paddlecloud. 
" - "(default: %(default)s)") +add_arg = functools.partial(add_arguments, argparser=parser) +# yapf: disable +add_arg('batch_size', int, 256, "Minibatch size.") +add_arg('trainer_count', int, 8, "# of Trainers (CPUs or GPUs).") +add_arg('num_passes', int, 200, "# of training epochs.") +add_arg('num_proc_data', int, 12, "# of CPUs for data preprocessing.") +add_arg('num_conv_layers', int, 2, "# of convolution layers.") +add_arg('num_rnn_layers', int, 3, "# of recurrent layers.") +add_arg('rnn_layer_size', int, 2048, "# of recurrent cells per layer.") +add_arg('num_iter_print', int, 100, "Every # iterations for printing " + "train cost.") +add_arg('learning_rate', float, 5e-4, "Learning rate.") +add_arg('max_duration', float, 27.0, "Longest audio duration allowed.") +add_arg('min_duration', float, 0.0, "Shortest audio duration allowed.") +add_arg('use_sortagrad', bool, True, "Use SortaGrad or not.") +add_arg('use_gpu', bool, True, "Use GPU or not.") +add_arg('is_local', bool, True, "Use pserver or not.") +add_arg('use_gru', bool, False, "Use GRUs instead of simple RNNs.") +add_arg('share_rnn_weights',bool, True, "Share input-hidden weights across " + "bi-directional RNNs. 
Not for GRU.")
+add_arg('train_manifest',   str,
+        'datasets/manifest.train',
+        "Filepath of train manifest.")
+add_arg('dev_manifest',     str,
+        'datasets/manifest.dev',
+        "Filepath of validation manifest.")
+add_arg('mean_std_path',    str,
+        'mean_std.npz',
+        "Filepath of normalizer's mean & std.")
+add_arg('vocab_path',       str,
+        'datasets/vocab/eng_vocab.txt',
+        "Filepath of vocabulary.")
+add_arg('init_model_path',  str,
+        None,
+        "If None, the training starts from scratch, "
+        "otherwise, it resumes from the pre-trained model.")
+add_arg('output_model_dir', str,
+        "./checkpoints",
+        "Directory for saving checkpoints.")
+add_arg('augment_conf_path',str,
+        'conf/augmentation.config',
+        "Filepath of augmentation configuration file (json-format).")
+add_arg('specgram_type',    str,
+        'linear',
+        "Audio feature type. Options: linear, mfcc.",
+        choices=['linear', 'mfcc'])
+add_arg('shuffle_method',   str,
+        'batch_shuffle_clipped',
+        "Shuffle method.",
+        choices=['instance_shuffle', 'batch_shuffle', 'batch_shuffle_clipped'])
+# yapf: enable
 args = parser.parse_args()
 
 
 def train():
     """DeepSpeech2 training."""
     train_generator = DataGenerator(
-        vocab_filepath=args.vocab_filepath,
-        mean_std_filepath=args.mean_std_filepath,
-        augmentation_config=args.augmentation_config,
+        vocab_filepath=args.vocab_path,
+        mean_std_filepath=args.mean_std_path,
+        augmentation_config=open(args.augment_conf_path, 'r').read(),
         max_duration=args.max_duration,
         min_duration=args.min_duration,
         specgram_type=args.specgram_type,
-        num_threads=args.num_threads_data)
+        num_threads=args.num_proc_data)
     dev_generator = DataGenerator(
-        vocab_filepath=args.vocab_filepath,
-        mean_std_filepath=args.mean_std_filepath,
+        vocab_filepath=args.vocab_path,
+        mean_std_filepath=args.mean_std_path,
         augmentation_config="{}",
         specgram_type=args.specgram_type,
-        num_threads=args.num_threads_data)
+        num_threads=args.num_proc_data)
     train_batch_reader = train_generator.batch_reader_creator(
-        manifest_path=args.train_manifest_path,
+        manifest_path=args.train_manifest,
         batch_size=args.batch_size,
         min_batch_size=args.trainer_count,
         sortagrad=args.use_sortagrad if args.init_model_path is None else False,
         shuffle_method=args.shuffle_method)
     dev_batch_reader = dev_generator.batch_reader_creator(
-        manifest_path=args.dev_manifest_path,
+        manifest_path=args.dev_manifest,
         batch_size=args.batch_size,
         min_batch_size=1,  # must be 1, but will have errors.
         sortagrad=False,
@@ -184,21 +101,21 @@ def train():
         rnn_layer_size=args.rnn_layer_size,
         use_gru=args.use_gru,
         pretrained_model_path=args.init_model_path,
-        share_rnn_weights=args.share_rnn_weights)
+        share_rnn_weights=args.share_rnn_weights)
     ds2_model.train(
         train_batch_reader=train_batch_reader,
         dev_batch_reader=dev_batch_reader,
         feeding_dict=train_generator.feeding,
-        learning_rate=args.adam_learning_rate,
+        learning_rate=args.learning_rate,
         gradient_clipping=400,
         num_passes=args.num_passes,
-        num_iterations_print=args.num_iterations_print,
+        num_iterations_print=args.num_iter_print,
         output_model_dir=args.output_model_dir,
         is_local=args.is_local)
 
 
 def main():
-    utils.print_arguments(args)
+    print_arguments(args)
     paddle.init(use_gpu=args.use_gpu, trainer_count=args.trainer_count)
     train()
diff --git a/deep_speech_2/tune.py b/deep_speech_2/tune.py
index d8001339eef1f51bb221238a647b2c4857a790d2..eab00cfdb3ff54725767373df6a84ff4e4bc505e 100644
--- a/deep_speech_2/tune.py
+++ b/deep_speech_2/tune.py
@@ -1,143 +1,63 @@
-"""Parameters tuning for DeepSpeech2 model."""
+"""Beam search parameters tuning for DeepSpeech2 model."""
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
 import numpy as np
-import distutils.util
 import argparse
-import multiprocessing
+import functools
 import paddle.v2 as paddle
 from data_utils.data import DataGenerator
 from model import DeepSpeech2Model
 from error_rate import wer
-import utils
+from utils import add_arguments, print_arguments
 
 parser = 
argparse.ArgumentParser(description=__doc__) -parser.add_argument( - "--num_samples", - default=100, - type=int, - help="Number of samples for parameters tuning. (default: %(default)s)") -parser.add_argument( - "--num_conv_layers", - default=2, - type=int, - help="Convolution layer number. (default: %(default)s)") -parser.add_argument( - "--num_rnn_layers", - default=3, - type=int, - help="RNN layer number. (default: %(default)s)") -parser.add_argument( - "--rnn_layer_size", - default=2048, - type=int, - help="RNN layer cell number. (default: %(default)s)") -parser.add_argument( - "--share_rnn_weights", - default=True, - type=distutils.util.strtobool, - help="Whether to share input-hidden weights between forword and backward " - "directional simple RNNs. Only available when use_gru=False. " - "(default: %(default)s)") -parser.add_argument( - "--use_gru", - default=False, - type=distutils.util.strtobool, - help="Use GRU or simple RNN. (default: %(default)s)") -parser.add_argument( - "--use_gpu", - default=True, - type=distutils.util.strtobool, - help="Use gpu or not. (default: %(default)s)") -parser.add_argument( - "--trainer_count", - default=8, - type=int, - help="Trainer number. (default: %(default)s)") -parser.add_argument( - "--num_threads_data", - default=1, - type=int, - help="Number of cpu threads for preprocessing data. (default: %(default)s)") -parser.add_argument( - "--num_processes_beam_search", - default=multiprocessing.cpu_count() // 2, - type=int, - help="Number of cpu processes for beam search. (default: %(default)s)") -parser.add_argument( - "--specgram_type", - default='linear', - type=str, - help="Feature type of audio data: 'linear' (power spectrum)" - " or 'mfcc'. (default: %(default)s)") -parser.add_argument( - "--mean_std_filepath", - default='mean_std.npz', - type=str, - help="Manifest path for normalizer. 
(default: %(default)s)") -parser.add_argument( - "--tune_manifest_path", - default='datasets/manifest.dev', - type=str, - help="Manifest path for tuning. (default: %(default)s)") -parser.add_argument( - "--model_filepath", - default='checkpoints/params.latest.tar.gz', - type=str, - help="Model filepath. (default: %(default)s)") -parser.add_argument( - "--vocab_filepath", - default='datasets/vocab/eng_vocab.txt', - type=str, - help="Vocabulary filepath. (default: %(default)s)") -parser.add_argument( - "--beam_size", - default=500, - type=int, - help="Width for beam search decoding. (default: %(default)d)") -parser.add_argument( - "--language_model_path", - default="lm/data/common_crawl_00.prune01111.trie.klm", - type=str, - help="Path for language model. (default: %(default)s)") -parser.add_argument( - "--alpha_from", - default=0.1, - type=float, - help="Where alpha starts from. (default: %(default)f)") -parser.add_argument( - "--num_alphas", - default=14, - type=int, - help="Number of candidate alphas. (default: %(default)d)") -parser.add_argument( - "--alpha_to", - default=0.36, - type=float, - help="Where alpha ends with. (default: %(default)f)") -parser.add_argument( - "--beta_from", - default=0.05, - type=float, - help="Where beta starts from. (default: %(default)f)") -parser.add_argument( - "--num_betas", - default=20, - type=float, - help="Number of candidate betas. (default: %(default)d)") -parser.add_argument( - "--beta_to", - default=1.0, - type=float, - help="Where beta ends with. (default: %(default)f)") -parser.add_argument( - "--cutoff_prob", - default=0.99, - type=float, - help="The cutoff probability of pruning" - "in beam search. 
(default: %(default)f)")
+add_arg = functools.partial(add_arguments, argparser=parser)
+# yapf: disable
+add_arg('num_samples',      int,    100,    "# of samples to infer.")
+add_arg('trainer_count',    int,    8,      "# of Trainers (CPUs or GPUs).")
+add_arg('beam_size',        int,    500,    "Beam search width.")
+add_arg('num_proc_bsearch', int,    12,     "# of CPUs for beam search.")
+add_arg('num_conv_layers',  int,    2,      "# of convolution layers.")
+add_arg('num_rnn_layers',   int,    3,      "# of recurrent layers.")
+add_arg('rnn_layer_size',   int,    2048,   "# of recurrent cells per layer.")
+add_arg('num_alphas',       int,    14,     "# of alpha candidates for tuning.")
+add_arg('num_betas',        int,    20,     "# of beta candidates for tuning.")
+add_arg('alpha_from',       float,  0.1,    "Where alpha starts tuning from.")
+add_arg('alpha_to',         float,  0.36,   "Where alpha ends tuning with.")
+add_arg('beta_from',        float,  0.05,   "Where beta starts tuning from.")
+add_arg('beta_to',          float,  1.0,    "Where beta ends tuning with.")
+add_arg('cutoff_prob',      float,  0.99,   "Cutoff probability for pruning.")
+add_arg('use_gru',          bool,   False,  "Use GRUs instead of simple RNNs.")
+add_arg('use_gpu',          bool,   True,   "Use GPU or not.")
+add_arg('share_rnn_weights',bool,   True,   "Share input-hidden weights across "
+                                            "bi-directional RNNs. 
Not for GRU.")
+add_arg('tune_manifest',    str,
+        'datasets/manifest.test',
+        "Filepath of manifest to tune.")
+add_arg('mean_std_path',    str,
+        'mean_std.npz',
+        "Filepath of normalizer's mean & std.")
+add_arg('vocab_path',       str,
+        'datasets/vocab/eng_vocab.txt',
+        "Filepath of vocabulary.")
+add_arg('lang_model_path',  str,
+        'lm/data/common_crawl_00.prune01111.trie.klm',
+        "Filepath for language model.")
+add_arg('model_path',       str,
+        './checkpoints/params.latest.tar.gz',
+        "Filepath of the pre-trained model on which the beam "
+        "search parameters (alpha, beta) are tuned.")
+add_arg('error_rate_type',  str,
+        'wer',
+        "Error rate type for evaluation.",
+        choices=['wer', 'cer'])
+add_arg('specgram_type',    str,
+        'linear',
+        "Audio feature type. Options: linear, mfcc.",
+        choices=['linear', 'mfcc'])
+# yapf: enable
 args = parser.parse_args()
 
 
@@ -149,13 +69,13 @@ def tune():
         raise ValueError("num_betas must be non-negative!")
 
     data_generator = DataGenerator(
-        vocab_filepath=args.vocab_filepath,
-        mean_std_filepath=args.mean_std_filepath,
+        vocab_filepath=args.vocab_path,
+        mean_std_filepath=args.mean_std_path,
         augmentation_config='{}',
         specgram_type=args.specgram_type,
-        num_threads=args.num_threads_data)
+        num_threads=1)
     batch_reader = data_generator.batch_reader_creator(
-        manifest_path=args.tune_manifest_path,
+        manifest_path=args.tune_manifest,
         batch_size=args.num_samples,
         sortagrad=False,
         shuffle_method=None)
@@ -171,7 +91,7 @@ def tune():
         num_rnn_layers=args.num_rnn_layers,
         rnn_layer_size=args.rnn_layer_size,
         use_gru=args.use_gru,
-        pretrained_model_path=args.model_filepath,
+        pretrained_model_path=args.model_path,
         share_rnn_weights=args.share_rnn_weights)
 
     # create grid for search
@@ -184,14 +104,14 @@ def tune():
     for alpha, beta in params_grid:
         result_transcripts = ds2_model.infer_batch(
             infer_data=tune_data,
-            decode_method='beam_search',
+            decoding_method='ctc_beam_search',
             beam_alpha=alpha,
             beam_beta=beta,
             beam_size=args.beam_size,
             cutoff_prob=args.cutoff_prob,
vocab_list=data_generator.vocab_list,
-            language_model_path=args.language_model_path,
-            num_processes=args.num_processes_beam_search)
+            language_model_path=args.lang_model_path,
+            num_processes=args.num_proc_bsearch)
         wer_sum, num_ins = 0.0, 0
         for target, result in zip(target_transcripts, result_transcripts):
             wer_sum += wer(target, result)
@@ -201,7 +121,7 @@ def tune():
 
 
 def main():
-    utils.print_arguments(args)
+    print_arguments(args)
     paddle.init(use_gpu=args.use_gpu, trainer_count=args.trainer_count)
     tune()
diff --git a/deep_speech_2/utils.py b/deep_speech_2/utils.py
index 1d51e2042397b4d3010259a8a3174bc969968aec..2e489ade6f28fdce5c6b60b47bc919a55549f046 100644
--- a/deep_speech_2/utils.py
+++ b/deep_speech_2/utils.py
@@ -3,6 +3,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import distutils.util
+
 
 def print_arguments(args):
     """Print argparse's arguments.
@@ -19,7 +21,27 @@ def print_arguments(args):
     :param args: Input argparse.Namespace for printing.
     :type args: argparse.Namespace
     """
-    print("----- Configuration Arguments -----")
-    for arg, value in vars(args).iteritems():
+    print("----------- Configuration Arguments -----------")
+    for arg, value in sorted(vars(args).iteritems()):
         print("%s: %s" % (arg, value))
-    print("------------------------------------")
+    print("------------------------------------------------")
+
+
+def add_arguments(argname, type, default, help, argparser, **kwargs):
+    """Add argparse's argument.
+
+    Usage:
+
+    .. code-block:: python
+
+        parser = argparse.ArgumentParser()
+        add_arguments("name", str, "John", "User name.", parser)
+        args = parser.parse_args()
+    """
+    type = distutils.util.strtobool if type == bool else type
+    argparser.add_argument(
+        "--" + argname,
+        default=default,
+        type=type,
+        help=help + ' Default: %(default)s.',
+        **kwargs)