diff --git a/deepspeech/decoders/swig/ctc_beam_search_decoder.cpp b/deepspeech/decoders/swig/ctc_beam_search_decoder.cpp index 4dcc7c899934e25b13bce6ca2b03c6623cc05e7d..8469a194dfe461192b2280a43c7b2c6e3bf64cdd 100644 --- a/deepspeech/decoders/swig/ctc_beam_search_decoder.cpp +++ b/deepspeech/decoders/swig/ctc_beam_search_decoder.cpp @@ -35,7 +35,8 @@ std::vector> ctc_beam_search_decoder( size_t beam_size, double cutoff_prob, size_t cutoff_top_n, - Scorer *ext_scorer) { + Scorer *ext_scorer, + size_t blank_id) { // dimension check size_t num_time_steps = probs_seq.size(); for (size_t i = 0; i < num_time_steps; ++i) { @@ -45,19 +46,13 @@ std::vector> ctc_beam_search_decoder( "The shape of probs_seq does not match with " "the shape of the vocabulary"); } - - // assign blank id - // size_t blank_id = vocabulary.size(); - size_t blank_id = 0; - // assign space id - auto it = std::find(vocabulary.begin(), vocabulary.end(), " "); + auto it = std::find(vocabulary.begin(), vocabulary.end(), kSPACE); int space_id = it - vocabulary.begin(); // if no space in vocabulary if ((size_t)space_id >= vocabulary.size()) { space_id = -2; } - // init prefixes' root PathTrie root; root.score = root.log_prob_b_prev = 0.0; @@ -218,7 +213,8 @@ ctc_beam_search_decoder_batch( size_t num_processes, double cutoff_prob, size_t cutoff_top_n, - Scorer *ext_scorer) { + Scorer *ext_scorer, + size_t blank_id) { VALID_CHECK_GT(num_processes, 0, "num_processes must be nonnegative!"); // thread pool ThreadPool pool(num_processes); @@ -234,7 +230,8 @@ ctc_beam_search_decoder_batch( beam_size, cutoff_prob, cutoff_top_n, - ext_scorer)); + ext_scorer, + blank_id)); } // get decoding results diff --git a/deepspeech/decoders/swig/ctc_beam_search_decoder.h b/deepspeech/decoders/swig/ctc_beam_search_decoder.h index c31510da34dc5444b0eab05ab682291902cc0bda..eaba9da8c8a05ae10391d81f30c368f125524aac 100644 --- a/deepspeech/decoders/swig/ctc_beam_search_decoder.h +++ b/deepspeech/decoders/swig/ctc_beam_search_decoder.h @@ -43,7 +43,8 @@ std::vector> ctc_beam_search_decoder( size_t beam_size, double cutoff_prob = 1.0, size_t cutoff_top_n = 40, - Scorer *ext_scorer = nullptr); + Scorer *ext_scorer = nullptr, + size_t blank_id = 0); /* CTC Beam Search Decoder for batch data @@ -70,6 +71,7 @@ ctc_beam_search_decoder_batch( size_t num_processes, double cutoff_prob = 1.0, size_t cutoff_top_n = 40, - Scorer *ext_scorer = nullptr); + Scorer *ext_scorer = nullptr, + size_t blank_id = 0); #endif // CTC_BEAM_SEARCH_DECODER_H_ diff --git a/deepspeech/decoders/swig/ctc_greedy_decoder.cpp b/deepspeech/decoders/swig/ctc_greedy_decoder.cpp index 1c735c424bec1cc2bfb33f9d848ff63f03faaf14..18008cced8be1929c3c7e044b4e28b79bb8eeb15 100644 --- a/deepspeech/decoders/swig/ctc_greedy_decoder.cpp +++ b/deepspeech/decoders/swig/ctc_greedy_decoder.cpp @@ -17,17 +17,18 @@ std::string ctc_greedy_decoder( const std::vector> &probs_seq, - const std::vector &vocabulary) { + const std::vector &vocabulary, + size_t blank_id) { // dimension check size_t num_time_steps = probs_seq.size(); for (size_t i = 0; i < num_time_steps; ++i) { VALID_CHECK_EQ(probs_seq[i].size(), - vocabulary.size() + 1, + vocabulary.size(), "The shape of probs_seq does not match with " "the shape of the vocabulary"); } - size_t blank_id = vocabulary.size(); + // size_t blank_id = vocabulary.size(); std::vector max_idx_vec(num_time_steps, 0); std::vector idx_vec; diff --git a/deepspeech/decoders/swig/ctc_greedy_decoder.h b/deepspeech/decoders/swig/ctc_greedy_decoder.h index 5e8c5c251fd76c8618504e96361cb692cfc6ed43..dd1b333153510766ed64875200825d145a374186 100644 --- a/deepspeech/decoders/swig/ctc_greedy_decoder.h +++ b/deepspeech/decoders/swig/ctc_greedy_decoder.h @@ -29,6 +29,7 @@ */ std::string ctc_greedy_decoder( const std::vector>& probs_seq, - const std::vector& vocabulary); + const std::vector& vocabulary, + size_t blank_id); #endif // CTC_GREEDY_DECODER_H diff --git a/deepspeech/decoders/swig/decoder_utils.h b/deepspeech/decoders/swig/decoder_utils.h index a874e439f7f298c51732668aa4870d6a9601148d..96399c77820342b551e6797b694bb37256b81b1c 100644 --- a/deepspeech/decoders/swig/decoder_utils.h +++ b/deepspeech/decoders/swig/decoder_utils.h @@ -15,10 +15,12 @@ #ifndef DECODER_UTILS_H_ #define DECODER_UTILS_H_ +#include #include #include "fst/log.h" #include "path_trie.h" +const std::string kSPACE = ""; const float NUM_FLT_INF = std::numeric_limits::max(); const float NUM_FLT_MIN = std::numeric_limits::min(); diff --git a/deepspeech/decoders/swig/scorer.cpp b/deepspeech/decoders/swig/scorer.cpp index a25382b15cd73e2743d28b8e3e93d167d976fe2e..7bd6542dfef1c003327161c6fb886e094cef0c2c 100644 --- a/deepspeech/decoders/swig/scorer.cpp +++ b/deepspeech/decoders/swig/scorer.cpp @@ -165,7 +165,7 @@ void Scorer::set_char_map(const std::vector& char_list) { // Set the char map for the FST for spelling correction for (size_t i = 0; i < char_list_.size(); i++) { - if (char_list_[i] == " ") { + if (char_list_[i] == kSPACE) { SPACE_ID_ = i; } // The initial state of FST is state 0, hence the index of chars in diff --git a/deepspeech/decoders/swig/setup.py b/deepspeech/decoders/swig/setup.py index 86af475afc8831460874be7a13b37e3a5fe99be5..c089f96cd75f41ac336a20a4b67955b4928d13f4 100644 --- a/deepspeech/decoders/swig/setup.py +++ b/deepspeech/decoders/swig/setup.py @@ -83,10 +83,12 @@ FILES = glob.glob('kenlm/util/*.cc') \ FILES += glob.glob('openfst-1.6.3/src/lib/*.cc') +# yapf: disable FILES = [ fn for fn in FILES if not (fn.endswith('main.cc') or fn.endswith('test.cc') or fn.endswith('unittest.cc')) ] +# yapf: enable LIBS = ['stdc++'] if platform.system() != 'Darwin': diff --git a/deepspeech/decoders/swig_wrapper.py b/deepspeech/decoders/swig_wrapper.py index 3ffdb9c74d73f1863e9ba730a0ac03b2eafc7fce..d883d430cfa2db91c7d15474290df60bfc46e778 100644 --- a/deepspeech/decoders/swig_wrapper.py +++ b/deepspeech/decoders/swig_wrapper.py @@ -32,7 +32,7 @@ class Scorer(swig_decoders.Scorer): swig_decoders.Scorer.__init__(self, alpha, beta, model_path, vocabulary) -def ctc_greedy_decoder(probs_seq, vocabulary): +def ctc_greedy_decoder(probs_seq, vocabulary, blank_id): """Wrapper for ctc best path decoder in swig. :param probs_seq: 2-D list of probability distributions over each time @@ -44,7 +44,8 @@ def ctc_greedy_decoder(probs_seq, vocabulary): :return: Decoding result string. :rtype: str """ - result = swig_decoders.ctc_greedy_decoder(probs_seq.tolist(), vocabulary) + result = swig_decoders.ctc_greedy_decoder(probs_seq.tolist(), vocabulary, + blank_id) return result @@ -53,7 +54,8 @@ def ctc_beam_search_decoder(probs_seq, beam_size, cutoff_prob=1.0, cutoff_top_n=40, - ext_scoring_func=None): + ext_scoring_func=None, + blank_id=0): """Wrapper for the CTC Beam Search Decoder. :param probs_seq: 2-D list of probability distributions over each time @@ -81,7 +83,7 @@ def ctc_beam_search_decoder(probs_seq, """ beam_results = swig_decoders.ctc_beam_search_decoder( probs_seq.tolist(), vocabulary, beam_size, cutoff_prob, cutoff_top_n, - ext_scoring_func) + ext_scoring_func, blank_id) beam_results = [(res[0], res[1].decode('utf-8')) for res in beam_results] return beam_results @@ -92,7 +94,8 @@ def ctc_beam_search_decoder_batch(probs_split, num_processes, cutoff_prob=1.0, cutoff_top_n=40, - ext_scoring_func=None): + ext_scoring_func=None, + blank_id=0): """Wrapper for the batched CTC beam search decoder. :param probs_seq: 3-D list with each element as an instance of 2-D list @@ -125,7 +128,7 @@ def ctc_beam_search_decoder_batch(probs_split, batch_beam_results = swig_decoders.ctc_beam_search_decoder_batch( probs_split, vocabulary, beam_size, num_processes, cutoff_prob, - cutoff_top_n, ext_scoring_func) + cutoff_top_n, ext_scoring_func, blank_id) batch_beam_results = [[(res[0], res[1]) for res in beam_results] for beam_results in batch_beam_results] return batch_beam_results