diff --git a/decoders/swig/ctc_beam_search_decoder.cpp b/decoders/swig/ctc_beam_search_decoder.cpp index 5c8373beaa1a269e4970bfb86b0df0d4167d014e..624784b05e215782f2264cc6ae4db7eed5b28cae 100644 --- a/decoders/swig/ctc_beam_search_decoder.cpp +++ b/decoders/swig/ctc_beam_search_decoder.cpp @@ -9,7 +9,6 @@ #include "ThreadPool.h" #include "fst/fstlib.h" -#include "fst/log.h" #include "decoder_utils.h" #include "path_trie.h" @@ -130,7 +129,7 @@ std::vector> ctc_beam_search_decoder( log_sum_exp(prefix_new->log_prob_nb_cur, log_p); } } // end of loop over prefix - } // end of loop over chars + } // end of loop over vocabulary prefixes.clear(); // update log probs diff --git a/decoders/swig/ctc_greedy_decoder.cpp b/decoders/swig/ctc_greedy_decoder.cpp index c4c94539ee15811e6d7bac496339ff02b4f239f8..03449d7391514bd267b396bab31da2e498425b47 100644 --- a/decoders/swig/ctc_greedy_decoder.cpp +++ b/decoders/swig/ctc_greedy_decoder.cpp @@ -27,7 +27,7 @@ std::string ctc_greedy_decoder( max_prob = probs_step[j]; } } - // id with maximum probability in current step + // id with maximum probability in current time step max_idx_vec[i] = max_idx; // deduplicate if ((i == 0) || ((i > 0) && max_idx_vec[i] != max_idx_vec[i - 1])) { diff --git a/decoders/swig/ctc_greedy_decoder.h b/decoders/swig/ctc_greedy_decoder.h index 043742f26e82c97746c7add49621a740164703d0..5e64f692e5062c4f8f1d4aa0c8c7b75eaa0a668a 100644 --- a/decoders/swig/ctc_greedy_decoder.h +++ b/decoders/swig/ctc_greedy_decoder.h @@ -14,7 +14,7 @@ * The decoding result in string */ std::string ctc_greedy_decoder( - const std::vector> &probs_seq, - const std::vector &vocabulary); + const std::vector>& probs_seq, + const std::vector& vocabulary); #endif // CTC_GREEDY_DECODER_H diff --git a/decoders/swig/decoder_utils.cpp b/decoders/swig/decoder_utils.cpp index 665fcc22f8fa15beb0dabb7b40eadd363cfd02aa..70a1592889bc24e0af344fc123101e1b19ff6c15 100644 --- a/decoders/swig/decoder_utils.cpp +++ b/decoders/swig/decoder_utils.cpp @@ -23,10 +23,9 @@ std::vector> get_pruned_log_probs( for (size_t i = 0; i < prob_idx.size(); ++i) { cum_prob += prob_idx[i].second; cutoff_len += 1; - if (cum_prob >= cutoff_prob) break; + if (cum_prob >= cutoff_prob || cutoff_len >= cutoff_top_n) break; } } - cutoff_len = std::min(cutoff_len, cutoff_top_n); prob_idx = std::vector>( prob_idx.begin(), prob_idx.begin() + cutoff_len); } diff --git a/decoders/swig/decoder_utils.h b/decoders/swig/decoder_utils.h index 932ffb12f793c228c7ac6c8b249283c787ab31a1..72821c187fff9567645de82b1c45fc0350787173 100644 --- a/decoders/swig/decoder_utils.h +++ b/decoders/swig/decoder_utils.h @@ -2,8 +2,8 @@ #define DECODER_UTILS_H_ #include -#include "path_trie.h" #include "fst/log.h" +#include "path_trie.h" const float NUM_FLT_INF = std::numeric_limits::max(); const float NUM_FLT_MIN = std::numeric_limits::min(); diff --git a/decoders/swig/path_trie.cpp b/decoders/swig/path_trie.cpp index fdff32861190e5f589c0041c62d6487a574c702d..40d9097055686eb9de0834fa2ab9478988bbfa96 100644 --- a/decoders/swig/path_trie.cpp +++ b/decoders/swig/path_trie.cpp @@ -19,9 +19,11 @@ PathTrie::PathTrie() { character = ROOT_; exists_ = true; parent = nullptr; + dictionary_ = nullptr; dictionary_state_ = 0; has_dictionary_ = false; + matcher_ = nullptr; } diff --git a/decoders/swig/scorer.cpp b/decoders/swig/scorer.cpp index 27c31fa7147a4ab1d648644ac874b77281a079df..686c67c77e1d7d1267df963308582bf194859310 100644 --- a/decoders/swig/scorer.cpp +++ b/decoders/swig/scorer.cpp @@ -19,9 +19,11 @@ Scorer::Scorer(double alpha, const std::vector& vocab_list) { this->alpha = alpha; this->beta = beta; + + dictionary = nullptr; is_character_based_ = true; language_model_ = nullptr; - dictionary = nullptr; + max_order_ = 0; dict_size_ = 0; SPACE_ID_ = -1;