#include #include #include #include #include #include #include "ctc_decoders.h" #include "decoder_utils.h" typedef double log_prob_type; template bool pair_comp_first_rev(const std::pair a, const std::pair b) { return a.first > b.first; } template bool pair_comp_second_rev(const std::pair a, const std::pair b) { return a.second > b.second; } template T log_sum_exp(T x, T y) { static T num_min = -std::numeric_limits::max(); if (x <= num_min) return y; if (y <= num_min) return x; T xmax = std::max(x, y); return std::log(std::exp(x-xmax) + std::exp(y-xmax)) + xmax; } std::string ctc_best_path_decoder(std::vector > probs_seq, std::vector vocabulary) { // dimension check int num_time_steps = probs_seq.size(); for (int i=0; i max_idx_vec; double max_prob = 0.0; int max_idx = 0; for (int i=0; i idx_vec; for (int i=0; i0) && max_idx_vec[i]!=max_idx_vec[i-1])) { idx_vec.push_back(max_idx_vec[i]); } } std::string best_path_result; for (int i=0; i > ctc_beam_search_decoder(std::vector > probs_seq, int beam_size, std::vector vocabulary, int blank_id, double cutoff_prob, Scorer *ext_scorer, bool nproc) { // dimension check int num_time_steps = probs_seq.size(); for (int i=0; i vocabulary.size()) { std::cout << " Invalid blank_id! " << std::endl; exit(1); } // assign space ID std::vector::iterator it = std::find(vocabulary.begin(), vocabulary.end(), " "); int space_id = it - vocabulary.begin(); if(space_id >= vocabulary.size()) { std::cout << " The character space is not in the vocabulary!"< prefix_set_prev, prefix_set_next; // probability of prefixes ending with blank and non-blank std::map log_probs_b_prev, log_probs_nb_prev; std::map log_probs_b_cur, log_probs_nb_cur; static log_prob_type NUM_MAX = std::numeric_limits::max(); prefix_set_prev["\t"] = 0.0; log_probs_b_prev["\t"] = 0.0; log_probs_nb_prev["\t"] = -NUM_MAX; for (int time_step=0; time_step prob = probs_seq[time_step]; std::vector > prob_idx; for (int i=0; i(i, prob[i])); } // pruning of vacobulary int cutoff_len = prob.size(); if (cutoff_prob < 1.0) { std::sort(prob_idx.begin(), prob_idx.end(), pair_comp_second_rev); double cum_prob = 0.0; cutoff_len = 0; for (int i=0; i= cutoff_prob) break; } prob_idx = std::vector >( prob_idx.begin(), prob_idx.begin() + cutoff_len); } std::vector > log_prob_idx; for (int i=0; i (prob_idx[i].first, log(prob_idx[i].second))); } // extend prefix for (std::map::iterator it = prefix_set_prev.begin(); it != prefix_set_prev.end(); it++) { std::string l = it->first; if( prefix_set_next.find(l) == prefix_set_next.end()) { log_probs_b_cur[l] = log_probs_nb_cur[l] = -NUM_MAX; } for (int index=0; index 1) { score = ext_scorer->get_score(l.substr(1), true); } log_probs_prev = log_sum_exp(log_probs_b_prev[l], log_probs_nb_prev[l]); log_probs_nb_cur[l_plus] = log_sum_exp( log_probs_nb_cur[l_plus], score + log_prob_c + log_probs_prev ); } else { log_probs_prev = log_sum_exp(log_probs_b_prev[l], log_probs_nb_prev[l]); log_probs_nb_cur[l_plus] = log_sum_exp( log_probs_nb_cur[l_plus], log_prob_c+log_probs_prev ); } prefix_set_next[l_plus] = log_sum_exp( log_probs_nb_cur[l_plus], log_probs_b_cur[l_plus] ); } } prefix_set_next[l] = log_sum_exp(log_probs_b_cur[l], log_probs_nb_cur[l]); } log_probs_b_prev = log_probs_b_cur; log_probs_nb_prev = log_probs_nb_cur; std::vector > prefix_vec_next(prefix_set_next.begin(), prefix_set_next.end()); std::sort(prefix_vec_next.begin(), prefix_vec_next.end(), pair_comp_second_rev); int num_prefixes_next = prefix_vec_next.size(); int k = beam_size ( prefix_vec_next.begin(), prefix_vec_next.begin() + k ); } // post processing std::vector > beam_result; for (std::map::iterator it = prefix_set_prev.begin(); it != prefix_set_prev.end(); it++) { if (it->second > -NUM_MAX && it->first.size() > 1) { log_prob_type log_prob = it->second; std::string sentence = it->first.substr(1); // scoring the last word if (ext_scorer != NULL && sentence[sentence.size()-1] != ' ') { log_prob = log_prob + ext_scorer->get_score(sentence, true); } if (log_prob > -NUM_MAX) { std::pair cur_result(log_prob, sentence); beam_result.push_back(cur_result); } } } // sort the result and return std::sort(beam_result.begin(), beam_result.end(), pair_comp_first_rev); return beam_result; }