diff --git a/deep_speech_2/data_utils/featurizer/text_featurizer.py b/deep_speech_2/data_utils/featurizer/text_featurizer.py
index 95dc637e0d76cc310cc732bd058215cedf9b007c..89202163ca8d8b69f59b858db5451882d7e089b3 100644
--- a/deep_speech_2/data_utils/featurizer/text_featurizer.py
+++ b/deep_speech_2/data_utils/featurizer/text_featurizer.py
@@ -22,8 +22,6 @@ class TextFeaturizer(object):
     def __init__(self, vocab_filepath):
         self._vocab_dict, self._vocab_list = self._load_vocabulary_from_file(
             vocab_filepath)
-        # from unicode to string
-        self._vocab_list = [chars.encode("utf-8") for chars in self._vocab_list]
 
     def featurize(self, text):
         """Convert text string to a list of token indices in char-level.Note
diff --git a/deep_speech_2/decoders/decoder_deprecated.py b/deep_speech_2/decoders/decoder_deprecated.py
index ffba2731a06b49105f74ab2c47831105c4c68428..6474316329179c5402572b9255bc1f526a51030f 100644
--- a/deep_speech_2/decoders/decoder_deprecated.py
+++ b/deep_speech_2/decoders/decoder_deprecated.py
@@ -42,8 +42,8 @@ def ctc_greedy_decoder(probs_seq, vocabulary):
 def ctc_beam_search_decoder(probs_seq,
                             beam_size,
                             vocabulary,
-                            blank_id,
                             cutoff_prob=1.0,
+                            cutoff_top_n=40,
                             ext_scoring_func=None,
                             nproc=False):
     """CTC Beam search decoder.
@@ -66,8 +66,6 @@ def ctc_beam_search_decoder(probs_seq,
     :type beam_size: int
     :param vocabulary: Vocabulary list.
     :type vocabulary: list
-    :param blank_id: ID of blank.
-    :type blank_id: int
     :param cutoff_prob: Cutoff probability in pruning,
                         default 1.0, no pruning.
     :type cutoff_prob: float
@@ -87,9 +85,8 @@ def ctc_beam_search_decoder(probs_seq,
         raise ValueError("The shape of prob_seq does not match with the "
                          "shape of the vocabulary.")
 
-    # blank_id check
-    if not blank_id < len(probs_seq[0]):
-        raise ValueError("blank_id shouldn't be greater than probs dimension")
+    # blank_id assign
+    blank_id = len(vocabulary)
 
     # If the decoder called in the multiprocesses, then use the global scorer
     # instantiated in ctc_beam_search_decoder_batch().
@@ -114,7 +111,7 @@ def ctc_beam_search_decoder(probs_seq,
         prob_idx = list(enumerate(probs_seq[time_step]))
         cutoff_len = len(prob_idx)
         #If pruning is enabled
-        if cutoff_prob < 1.0:
+        if cutoff_prob < 1.0 or cutoff_top_n < cutoff_len:
             prob_idx = sorted(prob_idx, key=lambda asd: asd[1], reverse=True)
             cutoff_len, cum_prob = 0, 0.0
             for i in xrange(len(prob_idx)):
@@ -122,6 +119,7 @@ def ctc_beam_search_decoder(probs_seq,
                 cum_prob += prob_idx[i][1]
                 cutoff_len += 1
                 if cum_prob >= cutoff_prob:
                     break
+            cutoff_len = min(cutoff_len, cutoff_top_n)
             prob_idx = prob_idx[0:cutoff_len]
 
         for l in prefix_set_prev:
@@ -191,9 +189,9 @@
 
 def ctc_beam_search_decoder_batch(probs_split,
                                   beam_size,
                                   vocabulary,
-                                  blank_id,
                                   num_processes,
                                   cutoff_prob=1.0,
+                                  cutoff_top_n=40,
                                   ext_scoring_func=None):
     """CTC beam search decoder using multiple processes.
@@ -204,8 +202,6 @@ def ctc_beam_search_decoder_batch(probs_split,
     :type beam_size: int
     :param vocabulary: Vocabulary list.
     :type vocabulary: list
-    :param blank_id: ID of blank.
-    :type blank_id: int
     :param num_processes: Number of parallel processes.
     :type num_processes: int
     :param cutoff_prob: Cutoff probability in pruning,
@@ -232,8 +228,8 @@ def ctc_beam_search_decoder_batch(probs_split,
     pool = multiprocessing.Pool(processes=num_processes)
     results = []
     for i, probs_list in enumerate(probs_split):
-        args = (probs_list, beam_size, vocabulary, blank_id, cutoff_prob, None,
-                nproc)
+        args = (probs_list, beam_size, vocabulary, cutoff_prob, cutoff_top_n,
+                None, nproc)
         results.append(pool.apply_async(ctc_beam_search_decoder, args))
 
     pool.close()
diff --git a/deep_speech_2/decoders/lm_scorer_deprecated.py b/deep_speech_2/decoders/lm_scorer_deprecated.py
index 463e96d6653b29207fb6105527a1f79c41c7fb84..c6a661030d4363727e259da9c7949e59705d55c8 100644
--- a/deep_speech_2/decoders/lm_scorer_deprecated.py
+++ b/deep_speech_2/decoders/lm_scorer_deprecated.py
@@ -8,7 +8,7 @@
 import kenlm
 import numpy as np
 
-class LmScorer(object):
+class Scorer(object):
     """External scorer to evaluate a prefix or whole sentence in
        beam search decoding, including the score from n-gram language
        model and word count.
diff --git a/deep_speech_2/decoders/swig/ctc_decoders.cpp b/deep_speech_2/decoders/swig/ctc_decoders.cpp
index 86598eee6e0513d74111a3702586549e59ef1464..35425fbca3607ee9897a35c5e20300d1b95f7677 100644
--- a/deep_speech_2/decoders/swig/ctc_decoders.cpp
+++ b/deep_speech_2/decoders/swig/ctc_decoders.cpp
@@ -128,7 +128,7 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
     // pruning of vacobulary
     size_t cutoff_len = prob.size();
-    if (cutoff_prob < 1.0 || cutoff_top_n < prob.size()) {
+    if (cutoff_prob < 1.0 || cutoff_top_n < cutoff_len) {
       std::sort(
           prob_idx.begin(), prob_idx.end(), pair_comp_second_rev<int, double>);
       if (cutoff_prob < 1.0) {
diff --git a/deep_speech_2/examples/librispeech/run_infer.sh b/deep_speech_2/examples/librispeech/run_infer.sh
index fa177933af576a151b1840ac4430f3168153789a..b6f254a0bfb9d636f331249480fe1a995664e2fc 100644
--- a/deep_speech_2/examples/librispeech/run_infer.sh
+++ b/deep_speech_2/examples/librispeech/run_infer.sh
@@ -24,6 +24,7 @@ python -u infer.py \
 --alpha=2.15 \
 --beta=0.35 \
 --cutoff_prob=1.0 \
+--cutoff_top_n=40 \
 --use_gru=False \
 --use_gpu=True \
 --share_rnn_weights=True \
diff --git a/deep_speech_2/examples/librispeech/run_infer_golden.sh b/deep_speech_2/examples/librispeech/run_infer_golden.sh
index 20dfc65ee8ff10656e94fc7fefac506742cd147f..9336edebb0e6f2208c536b311d6491999a7ed705 100644
--- a/deep_speech_2/examples/librispeech/run_infer_golden.sh
+++ b/deep_speech_2/examples/librispeech/run_infer_golden.sh
@@ -33,6 +33,7 @@ python -u infer.py \
 --alpha=2.15 \
 --beta=0.35 \
 --cutoff_prob=1.0 \
+--cutoff_top_n=40 \
 --use_gru=False \
 --use_gpu=True \
 --share_rnn_weights=True \
diff --git a/deep_speech_2/examples/librispeech/run_test_golden.sh b/deep_speech_2/examples/librispeech/run_test_golden.sh
index e539bd0137251e1d81503511aae7e4b02b8d5e96..6aed4cfca1419c6dc8fcfb957db47712cad8408f 100644
--- a/deep_speech_2/examples/librispeech/run_test_golden.sh
+++ b/deep_speech_2/examples/librispeech/run_test_golden.sh
@@ -34,6 +34,7 @@ python -u test.py \
 --alpha=2.15 \
 --beta=0.35 \
 --cutoff_prob=1.0 \
+--cutoff_top_n=40 \
 --use_gru=False \
 --use_gpu=True \
 --share_rnn_weights=True \
diff --git a/deep_speech_2/infer.py b/deep_speech_2/infer.py
index 5da1db970c13c7356f0e2b8ca05efa072ac2ba8e..1064fd25a0bc3730c4a8dfa4e6668cbd1ede3ef4 100644
--- a/deep_speech_2/infer.py
+++ b/deep_speech_2/infer.py
@@ -23,7 +23,8 @@ add_arg('num_rnn_layers',   int,    3,      "# of recurrent layers.")
 add_arg('rnn_layer_size',   int,    2048,   "# of recurrent cells per layer.")
 add_arg('alpha',            float,  2.15,   "Coef of LM for beam search.")
 add_arg('beta',             float,  0.35,   "Coef of WC for beam search.")
-add_arg('cutoff_prob',      float,  1.0,    "Cutoff probability for pruning.")
+add_arg('cutoff_prob',      float,  1.0,    "Cutoff probability for pruning.")
+add_arg('cutoff_top_n',     int,    40,     "Cutoff number for pruning.")
 add_arg('use_gru',          bool,   False,  "Use GRUs instead of simple RNNs.")
 add_arg('use_gpu',          bool,   True,   "Use GPU or not.")
 add_arg('share_rnn_weights',bool,   True,   "Share input-hidden weights across "
@@ -85,6 +86,9 @@ def infer():
         pretrained_model_path=args.model_path,
         share_rnn_weights=args.share_rnn_weights)
 
+    # decoders only accept string encoded in utf-8
+    vocab_list = [chars.encode("utf-8") for chars in data_generator.vocab_list]
+
     result_transcripts = ds2_model.infer_batch(
         infer_data=infer_data,
         decoding_method=args.decoding_method,
@@ -92,7 +96,8 @@ def infer():
         beam_beta=args.beta,
         beam_size=args.beam_size,
         cutoff_prob=args.cutoff_prob,
-        vocab_list=data_generator.vocab_list,
+        cutoff_top_n=args.cutoff_top_n,
+        vocab_list=vocab_list,
         language_model_path=args.lang_model_path,
         num_processes=args.num_proc_bsearch)
 
diff --git a/deep_speech_2/model_utils/model.py b/deep_speech_2/model_utils/model.py
index 1a9910e9d9f5a0129a22c7e93cbd1c4a272eb89e..4f5021a6d5c95beeb42b64ea639ca88040069b17 100644
--- a/deep_speech_2/model_utils/model.py
+++ b/deep_speech_2/model_utils/model.py
@@ -148,8 +148,8 @@ class DeepSpeech2Model(object):
         return self._loss_inferer.infer(input=infer_data)
 
     def infer_batch(self, infer_data, decoding_method, beam_alpha, beam_beta,
-                    beam_size, cutoff_prob, vocab_list, language_model_path,
-                    num_processes):
+                    beam_size, cutoff_prob, cutoff_top_n, vocab_list,
+                    language_model_path, num_processes):
         """Model inference. Infer the transcription for a batch of speech
         utterances.
@@ -169,6 +169,10 @@ class DeepSpeech2Model(object):
         :param cutoff_prob: Cutoff probability in pruning,
                             default 1.0, no pruning.
         :type cutoff_prob: float
+        :param cutoff_top_n: Cutoff number in pruning, only top cutoff_top_n
+                             characters with highest probs in vocabulary will be
+                             used in beam search, default 40.
+        :type cutoff_top_n: int
         :param vocab_list: List of tokens in the vocabulary, for decoding.
         :type vocab_list: list
         :param language_model_path: Filepath for language model.
@@ -216,7 +220,8 @@ class DeepSpeech2Model(object):
                 beam_size=beam_size,
                 num_processes=num_processes,
                 ext_scoring_func=self._ext_scorer,
-                cutoff_prob=cutoff_prob)
+                cutoff_prob=cutoff_prob,
+                cutoff_top_n=cutoff_top_n)
             results = [result[0][1] for result in beam_search_results]
         else:
diff --git a/deep_speech_2/test.py b/deep_speech_2/test.py
index 76efb4d1e196fcfe40358b028bfe966f224fb8eb..c564bb85db653668c04bc03a69a1e909e3a67cdc 100644
--- a/deep_speech_2/test.py
+++ b/deep_speech_2/test.py
@@ -24,7 +24,8 @@ add_arg('num_rnn_layers',   int,    3,      "# of recurrent layers.")
 add_arg('rnn_layer_size',   int,    2048,   "# of recurrent cells per layer.")
 add_arg('alpha',            float,  2.15,   "Coef of LM for beam search.")
 add_arg('beta',             float,  0.35,   "Coef of WC for beam search.")
-add_arg('cutoff_prob',      float,  1.0,    "Cutoff probability for pruning.")
+add_arg('cutoff_prob',      float,  1.0,    "Cutoff probability for pruning.")
+add_arg('cutoff_top_n',     int,    40,     "Cutoff number for pruning.")
 add_arg('use_gru',          bool,   False,  "Use GRUs instead of simple RNNs.")
 add_arg('use_gpu',          bool,   True,   "Use GPU or not.")
 add_arg('share_rnn_weights',bool,   True,   "Share input-hidden weights across "
@@ -85,6 +86,9 @@ def evaluate():
         pretrained_model_path=args.model_path,
         share_rnn_weights=args.share_rnn_weights)
 
+    # decoders only accept string encoded in utf-8
+    vocab_list = [chars.encode("utf-8") for chars in data_generator.vocab_list]
+
     error_rate_func = cer if args.error_rate_type == 'cer' else wer
     error_sum, num_ins = 0.0, 0
     for infer_data in batch_reader():
@@ -95,7 +99,8 @@ def evaluate():
             beam_beta=args.beta,
             beam_size=args.beam_size,
             cutoff_prob=args.cutoff_prob,
-            vocab_list=data_generator.vocab_list,
+            cutoff_top_n=args.cutoff_top_n,
+            vocab_list=vocab_list,
             language_model_path=args.lang_model_path,
             num_processes=args.num_proc_bsearch)
         target_transcripts = [
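For readers skimming the patch, the net effect on the decoder interface is: (a) the explicit `blank_id` argument is gone, with the blank implicitly assigned the last ID, `blank_id = len(vocabulary)`; and (b) pruning gains a second knob, `cutoff_top_n`, next to `cutoff_prob`. The sketch below condenses that pruning rule for illustration only; it is not code from the patch, and the helper name `prune_step_probs` is made up.

```python
# Illustrative sketch of the per-time-step pruning rule introduced by this
# patch (hypothetical helper, not part of the repo).


def prune_step_probs(step_probs, cutoff_prob=1.0, cutoff_top_n=40):
    """Return the (char_index, prob) pairs kept for one time step:
    the shortest high-probability prefix whose cumulative probability
    reaches cutoff_prob, additionally capped at cutoff_top_n entries."""
    # sort characters by descending probability
    prob_idx = sorted(enumerate(step_probs), key=lambda x: x[1], reverse=True)
    cutoff_len = len(prob_idx)
    if cutoff_prob < 1.0:
        # keep just enough characters to cover cutoff_prob of the mass
        cutoff_len, cum_prob = 0, 0.0
        for _, prob in prob_idx:
            cum_prob += prob
            cutoff_len += 1
            if cum_prob >= cutoff_prob:
                break
    # the cap added by this patch: never keep more than cutoff_top_n
    cutoff_len = min(cutoff_len, cutoff_top_n)
    return prob_idx[:cutoff_len]


# Cumulative mass hits 0.9 after 4 entries (0.4, 0.7, 0.85, 0.95), then
# cutoff_top_n=3 caps the result at the top 3 characters:
print(prune_step_probs([0.4, 0.3, 0.15, 0.1, 0.04, 0.01],
                       cutoff_prob=0.9, cutoff_top_n=3))
# -> [(0, 0.4), (1, 0.3), (2, 0.15)]
```

Because the cap applies even when `cutoff_prob=1.0` disables probability-mass pruning, `cutoff_top_n` bounds the per-step search width on its own, which is why the example shell scripts and the `infer.py`/`test.py` CLIs all gain a `--cutoff_top_n=40` default.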