#include #include #include #include #include #include "ctc_beam_search_decoder.h" template bool pair_comp_first_rev(const std::pair a, const std::pair b) { return a.first > b.first; } template bool pair_comp_second_rev(const std::pair a, const std::pair b) { return a.second > b.second; } /* CTC beam search decoder in C++, the interface is consistent with the original decoder in Python version. */ std::vector > ctc_beam_search_decoder(std::vector > probs_seq, int beam_size, std::vector vocabulary, int blank_id, double cutoff_prob, Scorer *ext_scorer, bool nproc ) { int num_time_steps = probs_seq.size(); // assign space ID std::vector::iterator it = std::find(vocabulary.begin(), vocabulary.end(), " "); int space_id = it-vocabulary.begin(); if(space_id >= vocabulary.size()) { std::cout<<"The character space is not in the vocabulary!"; exit(1); } // initialize // two sets containing selected and candidate prefixes respectively std::map prefix_set_prev, prefix_set_next; // probability of prefixes ending with blank and non-blank std::map probs_b_prev, probs_nb_prev; std::map probs_b_cur, probs_nb_cur; prefix_set_prev["\t"] = 1.0; probs_b_prev["\t"] = 1.0; probs_nb_prev["\t"] = 0.0; for (int time_step=0; time_step prob = probs_seq[time_step]; std::vector > prob_idx; for (int i=0; i(i, prob[i])); } // pruning of vacobulary if (cutoff_prob < 1.0) { std::sort(prob_idx.begin(), prob_idx.end(), pair_comp_second_rev); float cum_prob = 0.0; int cutoff_len = 0; for (int i=0; i= cutoff_prob) break; } prob_idx = std::vector >(prob_idx.begin(), prob_idx.begin()+cutoff_len); } // extend prefix for (std::map::iterator it = prefix_set_prev.begin(); it != prefix_set_prev.end(); it++) { std::string l = it->first; if( prefix_set_next.find(l) == prefix_set_next.end()) { probs_b_cur[l] = probs_nb_cur[l] = 0.0; } for (int index=0; index 1) { score = ext_scorer->get_score(l.substr(1)); } probs_nb_cur[l_plus] += score * prob_c * ( probs_b_prev[l] + probs_nb_prev[l]); } else { probs_nb_cur[l_plus] += prob_c * ( probs_b_prev[l] + probs_nb_prev[l]); } prefix_set_next[l_plus] = probs_nb_cur[l_plus]+probs_b_cur[l_plus]; } } prefix_set_next[l] = probs_b_cur[l]+probs_nb_cur[l]; } probs_b_prev = probs_b_cur; probs_nb_prev = probs_nb_cur; std::vector > prefix_vec_next(prefix_set_next.begin(), prefix_set_next.end()); std::sort(prefix_vec_next.begin(), prefix_vec_next.end(), pair_comp_second_rev); int k = beam_size (prefix_vec_next.begin(), prefix_vec_next.begin()+k); } // post processing std::vector > beam_result; for (std::map::iterator it = prefix_set_prev.begin(); it != prefix_set_prev.end(); it++) { if (it->second > 0.0 && it->first.size() > 1) { double prob = it->second; std::string sentence = it->first.substr(1); // scoring the last word if (ext_scorer != NULL && sentence[sentence.size()-1] != ' ') { prob = prob * ext_scorer->get_score(sentence); } double log_prob = log(prob); beam_result.push_back(std::pair(log_prob, sentence)); } } // sort the result and return std::sort(beam_result.begin(), beam_result.end(), pair_comp_first_rev); return beam_result; }