diff --git a/deploy/ctc_beam_search_decoder.cpp b/deploy/ctc_beam_search_decoder.cpp index 68d1a8457e6f214060efc7da9e637e9ff6f59a34..a684b30a61e3230a0bc8fcdba91146ab9d6f0db3 100644 --- a/deploy/ctc_beam_search_decoder.cpp +++ b/deploy/ctc_beam_search_decoder.cpp @@ -6,35 +6,47 @@ #include "ctc_beam_search_decoder.h" template -bool pair_comp_first_rev(const std::pair a, const std::pair b) { +bool pair_comp_first_rev(const std::pair a, const std::pair b) +{ return a.first > b.first; } template -bool pair_comp_second_rev(const std::pair a, const std::pair b) { +bool pair_comp_second_rev(const std::pair a, const std::pair b) +{ return a.second > b.second; } -/* CTC beam search decoder in C++, the interface is consistent with the original - decoder in Python version. -*/ std::vector > - ctc_beam_search_decoder(std::vector > probs_seq, - int beam_size, - std::vector vocabulary, - int blank_id, - double cutoff_prob, - Scorer *ext_scorer, - bool nproc - ) -{ + ctc_beam_search_decoder(std::vector > probs_seq, + int beam_size, + std::vector vocabulary, + int blank_id, + double cutoff_prob, + Scorer *ext_scorer, + bool nproc) { + // dimension check int num_time_steps = probs_seq.size(); + for (int i=0; i vocabulary.size()) { + std::cout<<"Invalid blank_id!"<::iterator it = std::find(vocabulary.begin(), vocabulary.end(), " "); - int space_id = it-vocabulary.begin(); + std::vector::iterator it = std::find(vocabulary.begin(), + vocabulary.end(), " "); + int space_id = it - vocabulary.begin(); if(space_id >= vocabulary.size()) { - std::cout<<"The character space is not in the vocabulary!"; + std::cout<<"The character space is not in the vocabulary!"< > } // pruning of vacobulary if (cutoff_prob < 1.0) { - std::sort(prob_idx.begin(), prob_idx.end(), pair_comp_second_rev); + std::sort(prob_idx.begin(), prob_idx.end(), + pair_comp_second_rev); float cum_prob = 0.0; int cutoff_len = 0; for (int i=0; i > cutoff_len += 1; if (cum_prob >= cutoff_prob) break; } - prob_idx = std::vector >(prob_idx.begin(), prob_idx.begin()+cutoff_len); + prob_idx = std::vector >( prob_idx.begin(), + prob_idx.begin() + cutoff_len); } // extend prefix for (std::map::iterator it = prefix_set_prev.begin(); @@ -82,11 +96,11 @@ std::vector > int c = prob_idx[index].first; double prob_c = prob_idx[index].second; if (c == blank_id) { - probs_b_cur[l] += prob_c*(probs_b_prev[l]+probs_nb_prev[l]); + probs_b_cur[l] += prob_c * (probs_b_prev[l] + probs_nb_prev[l]); } else { std::string last_char = l.substr(l.size()-1, 1); std::string new_char = vocabulary[c]; - std::string l_plus = l+new_char; + std::string l_plus = l + new_char; if( prefix_set_next.find(l_plus) == prefix_set_next.end()) { probs_b_cur[l_plus] = probs_nb_cur[l_plus] = 0.0; @@ -105,19 +119,22 @@ std::vector > probs_nb_cur[l_plus] += prob_c * ( probs_b_prev[l] + probs_nb_prev[l]); } - prefix_set_next[l_plus] = probs_nb_cur[l_plus]+probs_b_cur[l_plus]; + prefix_set_next[l_plus] = probs_nb_cur[l_plus] + probs_b_cur[l_plus]; } } - prefix_set_next[l] = probs_b_cur[l]+probs_nb_cur[l]; + prefix_set_next[l] = probs_b_cur[l] + probs_nb_cur[l]; } probs_b_prev = probs_b_cur; probs_nb_prev = probs_nb_cur; std::vector > - prefix_vec_next(prefix_set_next.begin(), prefix_set_next.end()); - std::sort(prefix_vec_next.begin(), prefix_vec_next.end(), pair_comp_second_rev); - int k = beam_size); + int k = beam_size (prefix_vec_next.begin(), prefix_vec_next.begin()+k); } @@ -138,6 +155,7 @@ std::vector > } } // sort the result and return - std::sort(beam_result.begin(), beam_result.end(), pair_comp_first_rev); + std::sort(beam_result.begin(), beam_result.end(), + pair_comp_first_rev); return beam_result; } diff --git a/deploy/ctc_beam_search_decoder.h b/deploy/ctc_beam_search_decoder.h index d23252aceb3302d46ad9a468061fb736d0dc89f4..a4bb6aa741940077c530f12419dc59abee74fd01 100644 --- a/deploy/ctc_beam_search_decoder.h +++ b/deploy/ctc_beam_search_decoder.h @@ -6,14 +6,30 @@ #include #include "scorer.h" -std::vector > - ctc_beam_search_decoder(std::vector > probs_seq, - int beam_size, - std::vector vocabulary, - int blank_id=0, - double cutoff_prob=1.0, - Scorer *ext_scorer=NULL, - bool nproc=false - ); +/* CTC Beam Search Decoder, the interface is consistent with the + * original decoder in Python version. + + * Parameters: + * probs_seq: 2-D vector that each element is a vector of probabilities + * over vocabulary of one time step. + * beam_size: The width of beam search. + * vocabulary: A vector of vocabulary. + * blank_id: ID of blank. + * cutoff_prob: Cutoff probability of pruning + * ext_scorer: External scorer to evaluate a prefix. + * nproc: Whether this function used in multiprocessing. + * Return: + * A vector that each element is a pair of score and decoding result, + * in desending order. +*/ +std::vector > + ctc_beam_search_decoder(std::vector > probs_seq, + int beam_size, + std::vector vocabulary, + int blank_id, + double cutoff_prob=1.0, + Scorer *ext_scorer=NULL, + bool nproc=false + ); #endif // CTC_BEAM_SEARCH_DECODER_H_ diff --git a/deploy/decoder_setup.py b/deploy/decoder_setup.py index 5201172b1fcb0c385692b9341f6e70d23f3da118..4ed603b25239f1a19611e9ecb347c80fc58568a8 100644 --- a/deploy/decoder_setup.py +++ b/deploy/decoder_setup.py @@ -10,8 +10,8 @@ def compile_test(header, library): return os.system(command) == 0 -FILES = glob.glob('util/*.cc') + glob.glob('lm/*.cc') + glob.glob( - 'util/double-conversion/*.cc') +FILES = glob.glob('kenlm/util/*.cc') + glob.glob('kenlm/lm/*.cc') + glob.glob( + 'kenlm/util/double-conversion/*.cc') FILES = [ fn for fn in FILES if not (fn.endswith('main.cc') or fn.endswith('test.cc')) ] @@ -44,7 +44,7 @@ ctc_beam_search_decoder_module = [ 'ctc_beam_search_decoder.cpp' ], language='C++', - include_dirs=['.'], + include_dirs=['.', './kenlm'], libraries=LIBS, extra_compile_args=ARGS) ] @@ -52,7 +52,6 @@ ctc_beam_search_decoder_module = [ setup( name='swig_ctc_beam_search_decoder', version='0.1', - author='Yibing Liu', description="""CTC beam search decoder""", ext_modules=ctc_beam_search_decoder_module, py_modules=['swig_ctc_beam_search_decoder'], ) diff --git a/deploy/scorer.cpp b/deploy/scorer.cpp index d7f68d71f8d3b8b326295ac77c3017ca63ce2606..1b843402bf6f59b1cadf21f2fdfa1f8b1e2d1b04 100644 --- a/deploy/scorer.cpp +++ b/deploy/scorer.cpp @@ -1,5 +1,4 @@ #include - #include "scorer.h" #include "lm/model.hh" #include "util/tokenize_piece.hh" @@ -17,6 +16,13 @@ Scorer::~Scorer(){ delete (Model *)this->_language_model; } +/* Strip a input sentence + * Parameters: + * str: A reference to the objective string + * ch: The character to prune + * Return: + * void + */ inline void strip(std::string &str, char ch=' ') { if (str.size() == 0) return; int start = 0; @@ -69,10 +75,14 @@ double Scorer::language_model_score(std::string sentence) { } //log10 prob double log_prob = ret.prob; - return log_prob; } +void Scorer::reset_params(float alpha, float beta) { + this->_alpha = alpha; + this->_beta = beta; +} + double Scorer::get_score(std::string sentence) { double lm_score = language_model_score(sentence); int word_cnt = word_count(sentence); diff --git a/deploy/scorer.h b/deploy/scorer.h index 47261bb519faca41295054dfd2c619ed64e0b3dc..7b305772c536dbc073950b247b12b7a0a30075b9 100644 --- a/deploy/scorer.h +++ b/deploy/scorer.h @@ -3,20 +3,34 @@ #include +/* External scorer to evaluate a prefix or a complete sentence + * when a new word appended during decoding, consisting of word + * count and language model scoring. + * Example: + * Scorer ext_scorer(alpha, beta, "path_to_language_model.klm"); + * double score = ext_scorer.get_score("sentence_to_score"); + */ class Scorer{ private: float _alpha; float _beta; void *_language_model; + // word insertion term + int word_count(std::string); + // n-gram language model scoring + double language_model_score(std::string); + public: Scorer(){} Scorer(float alpha, float beta, std::string lm_model_path); ~Scorer(); - int word_count(std::string); - double language_model_score(std::string); + + // reset params alpha & beta + void reset_params(float alpha, float beta); + // get the final score double get_score(std::string); }; -#endif +#endif //SCORER_H_ diff --git a/deploy/scorer_setup.py b/deploy/scorer_setup.py index c0006e07175eeadb6b9d704c8e559f4b27923885..3bb582724a23822856ceb50a2f736556216af53f 100644 --- a/deploy/scorer_setup.py +++ b/deploy/scorer_setup.py @@ -10,8 +10,8 @@ def compile_test(header, library): return os.system(command) == 0 -FILES = glob.glob('util/*.cc') + glob.glob('lm/*.cc') + glob.glob( - 'util/double-conversion/*.cc') +FILES = glob.glob('kenlm/util/*.cc') + glob.glob('kenlm/lm/*.cc') + glob.glob( + 'kenlm/util/double-conversion/*.cc') FILES = [ fn for fn in FILES if not (fn.endswith('main.cc') or fn.endswith('test.cc')) ] @@ -41,7 +41,7 @@ ext_modules = [ name='_swig_scorer', sources=FILES + ['scorer_wrap.cxx', 'scorer.cpp'], language='C++', - include_dirs=['.'], + include_dirs=['.', './kenlm'], libraries=LIBS, extra_compile_args=ARGS) ]