From a840f85423ffb51f8360496fd7d12e92dd737dbe Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Thu, 27 Jul 2017 10:02:54 +0800 Subject: [PATCH] change probs' computation into log scale & add best path decoder --- deploy/__init__.py | 0 deploy/ctc_beam_search_decoder.cpp | 189 ++++++++++++++++++++++------- deploy/ctc_beam_search_decoder.h | 4 + deploy/scorer.cpp | 9 +- deploy/scorer.h | 2 +- deploy/swig_decoder.py | 22 ++++ 6 files changed, 180 insertions(+), 46 deletions(-) create mode 100644 deploy/__init__.py create mode 100644 deploy/swig_decoder.py diff --git a/deploy/__init__.py b/deploy/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/deploy/ctc_beam_search_decoder.cpp b/deploy/ctc_beam_search_decoder.cpp index a684b30a..af6414a9 100644 --- a/deploy/ctc_beam_search_decoder.cpp +++ b/deploy/ctc_beam_search_decoder.cpp @@ -3,8 +3,11 @@ #include #include #include +#include #include "ctc_beam_search_decoder.h" +typedef float log_prob_type; + template bool pair_comp_first_rev(const std::pair a, const std::pair b) { @@ -17,6 +20,65 @@ bool pair_comp_second_rev(const std::pair a, const std::pair b) return a.second > b.second; } +template +T log_sum_exp(T x, T y) +{ + static T num_min = -std::numeric_limits::max(); + if (x <= -num_min) return y; + if (y <= -num_min) return x; + T xmax = std::max(x, y); + return std::log(std::exp(x-xmax) + std::exp(y-xmax)) + xmax; +} + +std::string ctc_best_path_decoder(std::vector > probs_seq, + std::vector vocabulary) { + // dimension check + int num_time_steps = probs_seq.size(); + for (int i=0; i max_idx_vec; + double max_prob = 0.0; + int max_idx = 0; + for (int i=0; i idx_vec; + for (int i=0; i0) && max_idx_vec[i]!=max_idx_vec[i-1])) { + std::cout< > ctc_beam_search_decoder(std::vector > probs_seq, int beam_size, @@ -52,106 +114,147 @@ std::vector > // initialize // two sets containing selected and candidate prefixes respectively - std::map prefix_set_prev, prefix_set_next; + std::map prefix_set_prev, prefix_set_next; // probability of prefixes ending with blank and non-blank - std::map probs_b_prev, probs_nb_prev; - std::map probs_b_cur, probs_nb_cur; - prefix_set_prev["\t"] = 1.0; - probs_b_prev["\t"] = 1.0; - probs_nb_prev["\t"] = 0.0; + std::map log_probs_b_prev, log_probs_nb_prev; + std::map log_probs_b_cur, log_probs_nb_cur; + + static log_prob_type NUM_MAX = std::numeric_limits::max(); + prefix_set_prev["\t"] = 0.0; + log_probs_b_prev["\t"] = 0.0; + log_probs_nb_prev["\t"] = -NUM_MAX; for (int time_step=0; time_step prob = probs_seq[time_step]; std::vector > prob_idx; for (int i=0; i(i, prob[i])); } + // pruning of vacobulary + int cutoff_len = prob.size(); if (cutoff_prob < 1.0) { - std::sort(prob_idx.begin(), prob_idx.end(), + std::sort(prob_idx.begin(), + prob_idx.end(), pair_comp_second_rev); - float cum_prob = 0.0; - int cutoff_len = 0; + double cum_prob = 0.0; + cutoff_len = 0; for (int i=0; i= cutoff_prob) break; } prob_idx = std::vector >( prob_idx.begin(), - prob_idx.begin() + cutoff_len); + prob_idx.begin() + cutoff_len); } + + std::vector > log_prob_idx; + for (int i=0; i + (prob_idx[i].first, log(prob_idx[i].second))); + } + // extend prefix - for (std::map::iterator it = prefix_set_prev.begin(); + for (std::map::iterator + it = prefix_set_prev.begin(); it != prefix_set_prev.end(); it++) { std::string l = it->first; if( prefix_set_next.find(l) == prefix_set_next.end()) { - probs_b_cur[l] = probs_nb_cur[l] = 0.0; + log_probs_b_cur[l] = log_probs_nb_cur[l] = -NUM_MAX; } - for (int index=0; index 1) { - score = ext_scorer->get_score(l.substr(1)); + score = ext_scorer->get_score(l.substr(1), true); } - probs_nb_cur[l_plus] += score * prob_c * ( - probs_b_prev[l] + probs_nb_prev[l]); + log_probs_prev = log_sum_exp(log_probs_b_prev[l], + log_probs_nb_prev[l]); + log_probs_nb_cur[l_plus] = log_sum_exp( + log_probs_nb_cur[l_plus], + score + log_prob_c + log_probs_prev + ); } else { - probs_nb_cur[l_plus] += prob_c * ( - probs_b_prev[l] + probs_nb_prev[l]); + log_probs_prev = log_sum_exp(log_probs_b_prev[l], + log_probs_nb_prev[l]); + log_probs_nb_cur[l_plus] = log_sum_exp( + log_probs_nb_cur[l_plus], + log_prob_c+log_probs_prev + ); } - prefix_set_next[l_plus] = probs_nb_cur[l_plus] + probs_b_cur[l_plus]; + prefix_set_next[l_plus] = log_sum_exp( + log_probs_nb_cur[l_plus], + log_probs_b_cur[l_plus] + ); } } - prefix_set_next[l] = probs_b_cur[l] + probs_nb_cur[l]; + prefix_set_next[l] = log_sum_exp(log_probs_b_cur[l], + log_probs_nb_cur[l]); } - probs_b_prev = probs_b_cur; - probs_nb_prev = probs_nb_cur; - std::vector > + log_probs_b_prev = log_probs_b_cur; + log_probs_nb_prev = log_probs_nb_cur; + std::vector > prefix_vec_next(prefix_set_next.begin(), prefix_set_next.end()); std::sort(prefix_vec_next.begin(), prefix_vec_next.end(), - pair_comp_second_rev); - int k = beam_size - (prefix_vec_next.begin(), prefix_vec_next.begin()+k); + pair_comp_second_rev); + int num_prefixes_next = prefix_vec_next.size(); + int k = beam_size ( + prefix_vec_next.begin(), + prefix_vec_next.begin() + k + ); } // post processing std::vector > beam_result; - for (std::map::iterator it = prefix_set_prev.begin(); - it != prefix_set_prev.end(); it++) { - if (it->second > 0.0 && it->first.size() > 1) { - double prob = it->second; + for (std::map::iterator + it = prefix_set_prev.begin(); it != prefix_set_prev.end(); it++) { + if (it->second > -NUM_MAX && it->first.size() > 1) { + log_prob_type log_prob = it->second; std::string sentence = it->first.substr(1); // scoring the last word if (ext_scorer != NULL && sentence[sentence.size()-1] != ' ') { - prob = prob * ext_scorer->get_score(sentence); + log_prob = log_prob + ext_scorer->get_score(sentence, true); + } + if (log_prob > -NUM_MAX) { + std::pair cur_result(log_prob, sentence); + beam_result.push_back(cur_result); } - double log_prob = log(prob); - beam_result.push_back(std::pair(log_prob, sentence)); } } // sort the result and return diff --git a/deploy/ctc_beam_search_decoder.h b/deploy/ctc_beam_search_decoder.h index a4bb6aa7..de7e7791 100644 --- a/deploy/ctc_beam_search_decoder.h +++ b/deploy/ctc_beam_search_decoder.h @@ -31,5 +31,9 @@ std::vector > Scorer *ext_scorer=NULL, bool nproc=false ); +/* CTC Best Path Decoder + */ +std::string ctc_best_path_decoder(std::vector > probs_seq, + std::vector vocabulary); #endif // CTC_BEAM_SEARCH_DECODER_H_ diff --git a/deploy/scorer.cpp b/deploy/scorer.cpp index d438ec1b..e9a74b98 100644 --- a/deploy/scorer.cpp +++ b/deploy/scorer.cpp @@ -89,10 +89,15 @@ void Scorer::reset_params(float alpha, float beta) { this->_beta = beta; } -double Scorer::get_score(std::string sentence) { +double Scorer::get_score(std::string sentence, bool log) { double lm_score = language_model_score(sentence); int word_cnt = word_count(sentence); - double final_score = pow(10, _alpha*lm_score) * pow(word_cnt, _beta); + double final_score = 0.0; + if (log == false) { + final_score = pow(10, _alpha*lm_score) * pow(word_cnt, _beta); + } else { + final_score = _alpha*lm_score*std::log(10) + _beta*std::log(word_cnt); + } return final_score; } diff --git a/deploy/scorer.h b/deploy/scorer.h index 7b305772..a18e119b 100644 --- a/deploy/scorer.h +++ b/deploy/scorer.h @@ -30,7 +30,7 @@ public: // reset params alpha & beta void reset_params(float alpha, float beta); // get the final score - double get_score(std::string); + double get_score(std::string, bool log=false); }; #endif //SCORER_H_ diff --git a/deploy/swig_decoder.py b/deploy/swig_decoder.py new file mode 100644 index 00000000..fed23c9e --- /dev/null +++ b/deploy/swig_decoder.py @@ -0,0 +1,22 @@ +"""Contains various CTC decoders in SWIG.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from swig_ctc_beam_search_decoder import ctc_beam_search_decoder as beam_search_decoder +from swig_ctc_beam_search_decoder import ctc_best_path_decoder as best_path__decoder + + +def ctc_best_path_decoder(probs_seq, vocabulary): + best_path__decoder(probs_seq.to_list(), vocabulary) + + +def ctc_beam_search_decoder( + probs_seq, + beam_size, + vocabulary, + blank_id, + cutoff_prob=1.0, + ext_scoring_func=None, ): + beam_search_decoder(probs_seq.to_list(), beam_size, vocabulary, blank_id, + cutoff_prob, ext_scoring_func) -- GitLab