diff --git a/deepspeech/decoders/swig/ctc_beam_search_decoder.cpp b/deepspeech/decoders/swig/ctc_beam_search_decoder.cpp
index f4dd797dc6ac3015eff51e9baa4b0b406126721d..4dcc7c899934e25b13bce6ca2b03c6623cc05e7d 100644
--- a/deepspeech/decoders/swig/ctc_beam_search_decoder.cpp
+++ b/deepspeech/decoders/swig/ctc_beam_search_decoder.cpp
@@ -36,169 +36,177 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
     double cutoff_prob,
     size_t cutoff_top_n,
     Scorer *ext_scorer) {
-  // dimension check
-  size_t num_time_steps = probs_seq.size();
-  for (size_t i = 0; i < num_time_steps; ++i) {
-    VALID_CHECK_EQ(probs_seq[i].size(),
-                   // vocabulary.size() + 1,
-                   vocabulary.size(),
-                   "The shape of probs_seq does not match with "
-                   "the shape of the vocabulary");
-  }
-
-  // assign blank id
-  //size_t blank_id = vocabulary.size();
-  size_t blank_id = 0;
-
-  // assign space id
-  auto it = std::find(vocabulary.begin(), vocabulary.end(), " ");
-  int space_id = it - vocabulary.begin();
-  // if no space in vocabulary
-  if ((size_t)space_id >= vocabulary.size()) {
-    space_id = -2;
-  }
-
-  // init prefixes' root
-  PathTrie root;
-  root.score = root.log_prob_b_prev = 0.0;
-  std::vector<PathTrie *> prefixes;
-  prefixes.push_back(&root);
-
-  if (ext_scorer != nullptr && !ext_scorer->is_character_based()) {
-    auto fst_dict = static_cast<fst::StdVectorFst *>(ext_scorer->dictionary);
-    fst::StdVectorFst *dict_ptr = fst_dict->Copy(true);
-    root.set_dictionary(dict_ptr);
-    auto matcher = std::make_shared<FSTMATCH>(*dict_ptr, fst::MATCH_INPUT);
-    root.set_matcher(matcher);
-  }
-
-  // prefix search over time
-  for (size_t time_step = 0; time_step < num_time_steps; ++time_step) {
-    auto &prob = probs_seq[time_step];
-
-    float min_cutoff = -NUM_FLT_INF;
-    bool full_beam = false;
-    if (ext_scorer != nullptr) {
-      size_t num_prefixes = std::min(prefixes.size(), beam_size);
-      std::sort(
-          prefixes.begin(), prefixes.begin() + num_prefixes, prefix_compare);
-      min_cutoff = prefixes[num_prefixes - 1]->score +
-                   std::log(prob[blank_id]) - std::max(0.0, ext_scorer->beta);
-      full_beam = (num_prefixes == beam_size);
+    // dimension check
+    size_t num_time_steps = probs_seq.size();
+    for (size_t i = 0; i < num_time_steps; ++i) {
+        VALID_CHECK_EQ(probs_seq[i].size(),
+                       // vocabulary.size() + 1,
+                       vocabulary.size(),
+                       "The shape of probs_seq does not match with "
+                       "the shape of the vocabulary");
     }
 
-    std::vector<std::pair<size_t, float>> log_prob_idx =
-        get_pruned_log_probs(prob, cutoff_prob, cutoff_top_n);
-    // loop over chars
-    for (size_t index = 0; index < log_prob_idx.size(); index++) {
-      auto c = log_prob_idx[index].first;
-      auto log_prob_c = log_prob_idx[index].second;
-
-      for (size_t i = 0; i < prefixes.size() && i < beam_size; ++i) {
-        auto prefix = prefixes[i];
-        if (full_beam && log_prob_c + prefix->score < min_cutoff) {
-          break;
-        }
-        // blank
-        if (c == blank_id) {
-          prefix->log_prob_b_cur =
-              log_sum_exp(prefix->log_prob_b_cur, log_prob_c + prefix->score);
-          continue;
+    // assign blank id
+    // size_t blank_id = vocabulary.size();
+    size_t blank_id = 0;
+
+    // assign space id
+    auto it = std::find(vocabulary.begin(), vocabulary.end(), " ");
+    int space_id = it - vocabulary.begin();
+    // if no space in vocabulary
+    if ((size_t)space_id >= vocabulary.size()) {
+        space_id = -2;
+    }
+
+    // init prefixes' root
+    PathTrie root;
+    root.score = root.log_prob_b_prev = 0.0;
+    std::vector<PathTrie *> prefixes;
+    prefixes.push_back(&root);
+
+    if (ext_scorer != nullptr && !ext_scorer->is_character_based()) {
+        auto fst_dict =
+            static_cast<fst::StdVectorFst *>(ext_scorer->dictionary);
+        fst::StdVectorFst *dict_ptr = fst_dict->Copy(true);
+        root.set_dictionary(dict_ptr);
+        auto matcher = std::make_shared<FSTMATCH>(*dict_ptr, fst::MATCH_INPUT);
+        root.set_matcher(matcher);
+    }
+
+    // prefix search over time
+    for (size_t time_step = 0; time_step < num_time_steps; ++time_step) {
+        auto &prob = probs_seq[time_step];
+
+        float min_cutoff = -NUM_FLT_INF;
+        bool full_beam = false;
+        if (ext_scorer != nullptr) {
+            size_t num_prefixes = std::min(prefixes.size(), beam_size);
+            std::sort(prefixes.begin(),
+                      prefixes.begin() + num_prefixes,
+                      prefix_compare);
+            min_cutoff = prefixes[num_prefixes - 1]->score +
+                         std::log(prob[blank_id]) -
+                         std::max(0.0, ext_scorer->beta);
+            full_beam = (num_prefixes == beam_size);
         }
-        // repeated character
-        if (c == prefix->character) {
-          prefix->log_prob_nb_cur = log_sum_exp(
-              prefix->log_prob_nb_cur, log_prob_c + prefix->log_prob_nb_prev);
+
+        std::vector<std::pair<size_t, float>> log_prob_idx =
+            get_pruned_log_probs(prob, cutoff_prob, cutoff_top_n);
+        // loop over chars
+        for (size_t index = 0; index < log_prob_idx.size(); index++) {
+            auto c = log_prob_idx[index].first;
+            auto log_prob_c = log_prob_idx[index].second;
+
+            for (size_t i = 0; i < prefixes.size() && i < beam_size; ++i) {
+                auto prefix = prefixes[i];
+                if (full_beam && log_prob_c + prefix->score < min_cutoff) {
+                    break;
+                }
+                // blank
+                if (c == blank_id) {
+                    prefix->log_prob_b_cur = log_sum_exp(
+                        prefix->log_prob_b_cur, log_prob_c + prefix->score);
+                    continue;
+                }
+                // repeated character
+                if (c == prefix->character) {
+                    prefix->log_prob_nb_cur =
+                        log_sum_exp(prefix->log_prob_nb_cur,
+                                    log_prob_c + prefix->log_prob_nb_prev);
+                }
+                // get new prefix
+                auto prefix_new = prefix->get_path_trie(c);
+
+                if (prefix_new != nullptr) {
+                    float log_p = -NUM_FLT_INF;
+
+                    if (c == prefix->character &&
+                        prefix->log_prob_b_prev > -NUM_FLT_INF) {
+                        log_p = log_prob_c + prefix->log_prob_b_prev;
+                    } else if (c != prefix->character) {
+                        log_p = log_prob_c + prefix->score;
+                    }
+
+                    // language model scoring
+                    if (ext_scorer != nullptr &&
+                        (c == space_id || ext_scorer->is_character_based())) {
+                        PathTrie *prefix_to_score = nullptr;
+                        // skip scoring the space
+                        if (ext_scorer->is_character_based()) {
+                            prefix_to_score = prefix_new;
+                        } else {
+                            prefix_to_score = prefix;
+                        }
+
+                        float score = 0.0;
+                        std::vector<std::string> ngram;
+                        ngram = ext_scorer->make_ngram(prefix_to_score);
+                        score = ext_scorer->get_log_cond_prob(ngram) *
+                                ext_scorer->alpha;
+                        log_p += score;
+                        log_p += ext_scorer->beta;
+                    }
+                    prefix_new->log_prob_nb_cur =
+                        log_sum_exp(prefix_new->log_prob_nb_cur, log_p);
+                }
+            }  // end of loop over prefix
+        }      // end of loop over vocabulary
+
+
+        prefixes.clear();
+        // update log probs
+        root.iterate_to_vec(prefixes);
+
+        // only preserve top beam_size prefixes
+        if (prefixes.size() >= beam_size) {
+            std::nth_element(prefixes.begin(),
+                             prefixes.begin() + beam_size,
+                             prefixes.end(),
+                             prefix_compare);
+            for (size_t i = beam_size; i < prefixes.size(); ++i) {
+                prefixes[i]->remove();
+            }
         }
-        // get new prefix
-        auto prefix_new = prefix->get_path_trie(c);
-
-        if (prefix_new != nullptr) {
-          float log_p = -NUM_FLT_INF;
-
-          if (c == prefix->character &&
-              prefix->log_prob_b_prev > -NUM_FLT_INF) {
-            log_p = log_prob_c + prefix->log_prob_b_prev;
-          } else if (c != prefix->character) {
-            log_p = log_prob_c + prefix->score;
-          }
-
-          // language model scoring
-          if (ext_scorer != nullptr &&
-              (c == space_id || ext_scorer->is_character_based())) {
-            PathTrie *prefix_to_score = nullptr;
-            // skip scoring the space
-            if (ext_scorer->is_character_based()) {
-              prefix_to_score = prefix_new;
-            } else {
-              prefix_to_score = prefix;
+    }  // end of loop over time
+
+    // score the last word of each prefix that doesn't end with space
+    if (ext_scorer != nullptr && !ext_scorer->is_character_based()) {
+        for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
+            auto prefix = prefixes[i];
+            if (!prefix->is_empty() && prefix->character != space_id) {
+                float score = 0.0;
+                std::vector<std::string> ngram = ext_scorer->make_ngram(prefix);
+                score =
+                    ext_scorer->get_log_cond_prob(ngram) * ext_scorer->alpha;
+                score += ext_scorer->beta;
+                prefix->score += score;
             }
-
-            float score = 0.0;
-            std::vector<std::string> ngram;
-            ngram = ext_scorer->make_ngram(prefix_to_score);
-            score = ext_scorer->get_log_cond_prob(ngram) * ext_scorer->alpha;
-            log_p += score;
-            log_p += ext_scorer->beta;
-          }
-          prefix_new->log_prob_nb_cur =
-              log_sum_exp(prefix_new->log_prob_nb_cur, log_p);
         }
-      }  // end of loop over prefix
-    }    // end of loop over vocabulary
-
-
-    prefixes.clear();
-    // update log probs
-    root.iterate_to_vec(prefixes);
-
-    // only preserve top beam_size prefixes
-    if (prefixes.size() >= beam_size) {
-      std::nth_element(prefixes.begin(),
-                       prefixes.begin() + beam_size,
-                       prefixes.end(),
-                       prefix_compare);
-      for (size_t i = beam_size; i < prefixes.size(); ++i) {
-        prefixes[i]->remove();
-      }
     }
-  }  // end of loop over time
-
-  // score the last word of each prefix that doesn't end with space
-  if (ext_scorer != nullptr && !ext_scorer->is_character_based()) {
+    size_t num_prefixes = std::min(prefixes.size(), beam_size);
+    std::sort(
+        prefixes.begin(), prefixes.begin() + num_prefixes, prefix_compare);
+
+    // compute approximate ctc score as the return score, without affecting the
+    // return order of decoding result. To delete when decoder gets stable.
     for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
-      auto prefix = prefixes[i];
-      if (!prefix->is_empty() && prefix->character != space_id) {
-        float score = 0.0;
-        std::vector<std::string> ngram = ext_scorer->make_ngram(prefix);
-        score = ext_scorer->get_log_cond_prob(ngram) * ext_scorer->alpha;
-        score += ext_scorer->beta;
-        prefix->score += score;
-      }
-    }
-  }
-
-  size_t num_prefixes = std::min(prefixes.size(), beam_size);
-  std::sort(prefixes.begin(), prefixes.begin() + num_prefixes, prefix_compare);
-
-  // compute aproximate ctc score as the return score, without affecting the
-  // return order of decoding result. To delete when decoder gets stable.
-  for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
-    double approx_ctc = prefixes[i]->score;
-    if (ext_scorer != nullptr) {
-      std::vector<int> output;
-      prefixes[i]->get_path_vec(output);
-      auto prefix_length = output.size();
-      auto words = ext_scorer->split_labels(output);
-      // remove word insert
-      approx_ctc = approx_ctc - prefix_length * ext_scorer->beta;
-      // remove language model weight:
-      approx_ctc -= (ext_scorer->get_sent_log_prob(words)) * ext_scorer->alpha;
+        double approx_ctc = prefixes[i]->score;
+        if (ext_scorer != nullptr) {
+            std::vector<int> output;
+            prefixes[i]->get_path_vec(output);
+            auto prefix_length = output.size();
+            auto words = ext_scorer->split_labels(output);
+            // remove word insert
+            approx_ctc = approx_ctc - prefix_length * ext_scorer->beta;
+            // remove language model weight:
+            approx_ctc -=
+                (ext_scorer->get_sent_log_prob(words)) * ext_scorer->alpha;
+        }
+        prefixes[i]->approx_ctc = approx_ctc;
     }
-    prefixes[i]->approx_ctc = approx_ctc;
-  }
 
-  return get_beam_search_result(prefixes, vocabulary, beam_size);
+    return get_beam_search_result(prefixes, vocabulary, beam_size);
 }
@@ -211,28 +219,28 @@ ctc_beam_search_decoder_batch(
     double cutoff_prob,
     size_t cutoff_top_n,
     Scorer *ext_scorer) {
-  VALID_CHECK_GT(num_processes, 0, "num_processes must be nonnegative!");
-  // thread pool
-  ThreadPool pool(num_processes);
-  // number of samples
-  size_t batch_size = probs_split.size();
-
-  // enqueue the tasks of decoding
-  std::vector<std::future<std::vector<std::pair<double, std::string>>>> res;
-  for (size_t i = 0; i < batch_size; ++i) {
-    res.emplace_back(pool.enqueue(ctc_beam_search_decoder,
-                                  probs_split[i],
-                                  vocabulary,
-                                  beam_size,
-                                  cutoff_prob,
-                                  cutoff_top_n,
-                                  ext_scorer));
-  }
-
-  // get decoding results
-  std::vector<std::vector<std::pair<double, std::string>>> batch_results;
-  for (size_t i = 0; i < batch_size; ++i) {
-    batch_results.emplace_back(res[i].get());
-  }
-  return batch_results;
+    VALID_CHECK_GT(num_processes, 0, "num_processes must be positive!");
+    // thread pool
+    ThreadPool pool(num_processes);
+    // number of samples
+    size_t batch_size = probs_split.size();
+
+    // enqueue the tasks of decoding
+    std::vector<std::future<std::vector<std::pair<double, std::string>>>> res;
+    for (size_t i = 0; i < batch_size; ++i) {
+        res.emplace_back(pool.enqueue(ctc_beam_search_decoder,
+                                      probs_split[i],
+                                      vocabulary,
+                                      beam_size,
+                                      cutoff_prob,
+                                      cutoff_top_n,
+                                      ext_scorer));
+    }
+
+    // get decoding results
+    std::vector<std::vector<std::pair<double, std::string>>> batch_results;
+    for (size_t i = 0; i < batch_size; ++i) {
+        batch_results.emplace_back(res[i].get());
+    }
+    return batch_results;
 }
diff --git a/deepspeech/decoders/swig/ctc_greedy_decoder.cpp b/deepspeech/decoders/swig/ctc_greedy_decoder.cpp
index da028bf838c6387b7619551768098adc19fdf2f5..1c735c424bec1cc2bfb33f9d848ff63f03faaf14 100644
--- a/deepspeech/decoders/swig/ctc_greedy_decoder.cpp
+++ b/deepspeech/decoders/swig/ctc_greedy_decoder.cpp
@@ -18,42 +18,42 @@ std::string ctc_greedy_decoder(
     const std::vector<std::vector<double>> &probs_seq,
     const std::vector<std::string> &vocabulary) {
-  // dimension check
-  size_t num_time_steps = probs_seq.size();
-  for (size_t i = 0; i < num_time_steps; ++i) {
-    VALID_CHECK_EQ(probs_seq[i].size(),
-                   vocabulary.size() + 1,
-                   "The shape of probs_seq does not match with "
-                   "the shape of the vocabulary");
-  }
+    // dimension check
+    size_t num_time_steps = probs_seq.size();
+    for (size_t i = 0; i < num_time_steps; ++i) {
+        VALID_CHECK_EQ(probs_seq[i].size(),
+                       vocabulary.size() + 1,
+                       "The shape of probs_seq does not match with "
+                       "the shape of the vocabulary");
+    }
 
-  size_t blank_id = vocabulary.size();
+    size_t blank_id = vocabulary.size();
 
-  std::vector<size_t> max_idx_vec(num_time_steps, 0);
-  std::vector<size_t> idx_vec;
-  for (size_t i = 0; i < num_time_steps; ++i) {
-    double max_prob = 0.0;
-    size_t max_idx = 0;
-    const std::vector<double> &probs_step = probs_seq[i];
-    for (size_t j = 0; j < probs_step.size(); ++j) {
-      if (max_prob < probs_step[j]) {
-        max_idx = j;
-        max_prob = probs_step[j];
-      }
-    }
-    // id with maximum probability in current time step
-    max_idx_vec[i] = max_idx;
-    // deduplicate
-    if ((i == 0) || ((i > 0) && max_idx_vec[i] != max_idx_vec[i - 1])) {
-      idx_vec.push_back(max_idx_vec[i]);
+    std::vector<size_t> max_idx_vec(num_time_steps, 0);
+    std::vector<size_t> idx_vec;
+    for (size_t i = 0; i < num_time_steps; ++i) {
+        double max_prob = 0.0;
+        size_t max_idx = 0;
+        const std::vector<double> &probs_step = probs_seq[i];
+        for (size_t j = 0; j < probs_step.size(); ++j) {
+            if (max_prob < probs_step[j]) {
+                max_idx = j;
+                max_prob = probs_step[j];
+            }
+        }
+        // id with maximum probability in current time step
+        max_idx_vec[i] = max_idx;
+        // deduplicate
+        if ((i == 0) || ((i > 0) && max_idx_vec[i] != max_idx_vec[i - 1])) {
+            idx_vec.push_back(max_idx_vec[i]);
+        }
     }
-  }
 
-  std::string best_path_result;
-  for (size_t i = 0; i < idx_vec.size(); ++i) {
-    if (idx_vec[i] != blank_id) {
-      best_path_result += vocabulary[idx_vec[i]];
+    std::string best_path_result;
+    for (size_t i = 0; i < idx_vec.size(); ++i) {
+        if (idx_vec[i] != blank_id) {
+            best_path_result += vocabulary[idx_vec[i]];
+        }
     }
-  }
 
-  return best_path_result;
+    return best_path_result;
 }
diff --git a/deepspeech/decoders/swig/decoder_utils.cpp b/deepspeech/decoders/swig/decoder_utils.cpp
index a10e07f0cb363c21db62b43903390f1be4bd5b13..43f62e7533a19e834ba94cf60487056ffb42fa3d 100644
--- a/deepspeech/decoders/swig/decoder_utils.cpp
+++ b/deepspeech/decoders/swig/decoder_utils.cpp
@@ -22,33 +22,35 @@ std::vector<std::pair<size_t, float>> get_pruned_log_probs(
     const std::vector<double> &prob_step,
     double cutoff_prob,
     size_t cutoff_top_n) {
-  std::vector<std::pair<int, double>> prob_idx;
-  for (size_t i = 0; i < prob_step.size(); ++i) {
-    prob_idx.push_back(std::pair<int, double>(i, prob_step[i]));
-  }
-  // pruning of vacobulary
-  size_t cutoff_len = prob_step.size();
-  if (cutoff_prob < 1.0 || cutoff_top_n < cutoff_len) {
-    std::sort(
-        prob_idx.begin(), prob_idx.end(), pair_comp_second_rev<int, double>);
-    if (cutoff_prob < 1.0) {
-      double cum_prob = 0.0;
-      cutoff_len = 0;
-      for (size_t i = 0; i < prob_idx.size(); ++i) {
-        cum_prob += prob_idx[i].second;
-        cutoff_len += 1;
-        if (cum_prob >= cutoff_prob || cutoff_len >= cutoff_top_n) break;
-      }
+    std::vector<std::pair<int, double>> prob_idx;
+    for (size_t i = 0; i < prob_step.size(); ++i) {
+        prob_idx.push_back(std::pair<int, double>(i, prob_step[i]));
     }
-    prob_idx = std::vector<std::pair<int, double>>(
-        prob_idx.begin(), prob_idx.begin() + cutoff_len);
-  }
-  std::vector<std::pair<size_t, float>> log_prob_idx;
-  for (size_t i = 0; i < cutoff_len; ++i) {
-    log_prob_idx.push_back(std::pair<int, float>(
-        prob_idx[i].first, log(prob_idx[i].second + NUM_FLT_MIN)));
-  }
-  return log_prob_idx;
+    // pruning of vocabulary
+    size_t cutoff_len = prob_step.size();
+    if (cutoff_prob < 1.0 || cutoff_top_n < cutoff_len) {
+        std::sort(prob_idx.begin(),
+                  prob_idx.end(),
+                  pair_comp_second_rev<int, double>);
+        if (cutoff_prob < 1.0) {
+            double cum_prob = 0.0;
+            cutoff_len = 0;
+            for (size_t i = 0; i < prob_idx.size(); ++i) {
+                cum_prob += prob_idx[i].second;
+                cutoff_len += 1;
+                if (cum_prob >= cutoff_prob || cutoff_len >= cutoff_top_n)
+                    break;
+            }
+        }
+        prob_idx = std::vector<std::pair<int, double>>(
+            prob_idx.begin(), prob_idx.begin() + cutoff_len);
+    }
+    std::vector<std::pair<size_t, float>> log_prob_idx;
+    for (size_t i = 0; i < cutoff_len; ++i) {
+        log_prob_idx.push_back(std::pair<int, float>(
+            prob_idx[i].first, log(prob_idx[i].second + NUM_FLT_MIN)));
+    }
+    return log_prob_idx;
 }
 
@@ -56,106 +58,106 @@ std::vector<std::pair<double, std::string>> get_beam_search_result(
     const std::vector<PathTrie *> &prefixes,
     const std::vector<std::string> &vocabulary,
     size_t beam_size) {
-  // allow for the post processing
-  std::vector<PathTrie *> space_prefixes;
-  if (space_prefixes.empty()) {
-    for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
-      space_prefixes.push_back(prefixes[i]);
+    // allow for the post processing
+    std::vector<PathTrie *> space_prefixes;
+    if (space_prefixes.empty()) {
+        for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
+            space_prefixes.push_back(prefixes[i]);
+        }
     }
-  }
-
-  std::sort(space_prefixes.begin(), space_prefixes.end(), prefix_compare);
-  std::vector<std::pair<double, std::string>> output_vecs;
-  for (size_t i = 0; i < beam_size && i < space_prefixes.size(); ++i) {
-    std::vector<int> output;
-    space_prefixes[i]->get_path_vec(output);
-    // convert index to string
-    std::string output_str;
-    for (size_t j = 0; j < output.size(); j++) {
-      output_str += vocabulary[output[j]];
+
+    std::sort(space_prefixes.begin(), space_prefixes.end(), prefix_compare);
+    std::vector<std::pair<double, std::string>> output_vecs;
+    for (size_t i = 0; i < beam_size && i < space_prefixes.size(); ++i) {
+        std::vector<int> output;
+        space_prefixes[i]->get_path_vec(output);
+        // convert index to string
+        std::string output_str;
+        for (size_t j = 0; j < output.size(); j++) {
+            output_str += vocabulary[output[j]];
+        }
+        std::pair<double, std::string> output_pair(
+            -space_prefixes[i]->approx_ctc, output_str);
+        output_vecs.emplace_back(output_pair);
     }
-    std::pair<double, std::string> output_pair(-space_prefixes[i]->approx_ctc,
-                                               output_str);
-    output_vecs.emplace_back(output_pair);
-  }
-  return output_vecs;
+    return output_vecs;
 }
 
 size_t get_utf8_str_len(const std::string &str) {
-  size_t str_len = 0;
-  for (char c : str) {
-    str_len += ((c & 0xc0) != 0x80);
-  }
-  return str_len;
+    size_t str_len = 0;
+    for (char c : str) {
+        str_len += ((c & 0xc0) != 0x80);
+    }
+    return str_len;
 }
 
 std::vector<std::string> split_utf8_str(const std::string &str) {
-  std::vector<std::string> result;
-  std::string out_str;
-
-  for (char c : str) {
-    if ((c & 0xc0) != 0x80)  // new UTF-8 character
-    {
-      if (!out_str.empty()) {
-        result.push_back(out_str);
-        out_str.clear();
-      }
+    std::vector<std::string> result;
+    std::string out_str;
+
+    for (char c : str) {
+        if ((c & 0xc0) != 0x80)  // new UTF-8 character
+        {
+            if (!out_str.empty()) {
+                result.push_back(out_str);
+                out_str.clear();
+            }
+        }
+
+        out_str.append(1, c);
     }
-
-    out_str.append(1, c);
-  }
-  result.push_back(out_str);
-  return result;
+    result.push_back(out_str);
+    return result;
 }
 
 std::vector<std::string> split_str(const std::string &s,
                                    const std::string &delim) {
-  std::vector<std::string> result;
-  std::size_t start = 0, delim_len = delim.size();
-  while (true) {
-    std::size_t end = s.find(delim, start);
-    if (end == std::string::npos) {
-      if (start < s.size()) {
-        result.push_back(s.substr(start));
-      }
-      break;
-    }
-    if (end > start) {
-      result.push_back(s.substr(start, end - start));
+    std::vector<std::string> result;
+    std::size_t start = 0, delim_len = delim.size();
+    while (true) {
+        std::size_t end = s.find(delim, start);
+        if (end == std::string::npos) {
+            if (start < s.size()) {
+                result.push_back(s.substr(start));
+            }
+            break;
+        }
+        if (end > start) {
+            result.push_back(s.substr(start, end - start));
+        }
+        start = end + delim_len;
     }
-    start = end + delim_len;
-  }
-  return result;
+    return result;
 }
 
 bool prefix_compare(const PathTrie *x, const PathTrie *y) {
-  if (x->score == y->score) {
-    if (x->character == y->character) {
-      return false;
+    if (x->score == y->score) {
+        if (x->character == y->character) {
+            return false;
+        } else {
+            return (x->character < y->character);
+        }
     } else {
-      return (x->character < y->character);
+        return x->score > y->score;
     }
-  } else {
-    return x->score > y->score;
-  }
 }
 
 void add_word_to_fst(const std::vector<int> &word,
                      fst::StdVectorFst *dictionary) {
-  if (dictionary->NumStates() == 0) {
-    fst::StdVectorFst::StateId start = dictionary->AddState();
-    assert(start == 0);
-    dictionary->SetStart(start);
-  }
-  fst::StdVectorFst::StateId src = dictionary->Start();
-  fst::StdVectorFst::StateId dst;
-  for (auto c : word) {
-    dst = dictionary->AddState();
-    dictionary->AddArc(src, fst::StdArc(c, c, 0, dst));
-    src = dst;
-  }
-  dictionary->SetFinal(dst, fst::StdArc::Weight::One());
+    if (dictionary->NumStates() == 0) {
+        fst::StdVectorFst::StateId start = dictionary->AddState();
+        assert(start == 0);
+        dictionary->SetStart(start);
+    }
+    fst::StdVectorFst::StateId src = dictionary->Start();
+    fst::StdVectorFst::StateId dst;
+    for (auto c : word) {
+        dst = dictionary->AddState();
+        dictionary->AddArc(src, fst::StdArc(c, c, 0, dst));
+        src = dst;
+    }
+    dictionary->SetFinal(dst, fst::StdArc::Weight::One());
 }
 
 bool add_word_to_dictionary(
@@ -164,27 +166,27 @@ bool add_word_to_dictionary(
     const std::unordered_map<std::string, int> &char_map,
     bool add_space,
     int SPACE_ID,
     fst::StdVectorFst *dictionary) {
-  auto characters = split_utf8_str(word);
-
-  std::vector<int> int_word;
-
-  for (auto &c : characters) {
-    if (c == " ") {
-      int_word.push_back(SPACE_ID);
-    } else {
-      auto int_c = char_map.find(c);
-      if (int_c != char_map.end()) {
-        int_word.push_back(int_c->second);
-      } else {
-        return false;  // return without adding
-      }
+    auto characters = split_utf8_str(word);
+
+    std::vector<int> int_word;
+
+    for (auto &c : characters) {
+        if (c == " ") {
+            int_word.push_back(SPACE_ID);
+        } else {
+            auto int_c = char_map.find(c);
+            if (int_c != char_map.end()) {
+                int_word.push_back(int_c->second);
+            } else {
+                return false;  // return without adding
+            }
+        }
     }
-  }
 
-  if (add_space) {
-    int_word.push_back(SPACE_ID);
-  }
+    if (add_space) {
+        int_word.push_back(SPACE_ID);
+    }
 
-  add_word_to_fst(int_word, dictionary);
-  return true;  // return with successful adding
+    add_word_to_fst(int_word, dictionary);
+    return true;  // return with successful adding
 }
diff --git a/deepspeech/decoders/swig/decoder_utils.h b/deepspeech/decoders/swig/decoder_utils.h
index 827258178b963cd6a0d8bdd0c4884c52fbe6fc23..a874e439f7f298c51732668aa4870d6a9601148d 100644
--- a/deepspeech/decoders/swig/decoder_utils.h
+++ b/deepspeech/decoders/swig/decoder_utils.h
@@ -25,14 +25,14 @@ const float NUM_FLT_MIN = std::numeric_limits<float>::min();
 // inline function for validation check
 inline void check(
     bool x, const char *expr, const char *file, int line, const char *err) {
-  if (!x) {
-    std::cout << "[" << file << ":" << line << "] ";
-    LOG(FATAL) << "\"" << expr << "\" check failed. " << err;
-  }
+    if (!x) {
+        std::cout << "[" << file << ":" << line << "] ";
" << err; + } } #define VALID_CHECK(x, info) \ - check(static_cast(x), #x, __FILE__, __LINE__, info) + check(static_cast(x), #x, __FILE__, __LINE__, info) #define VALID_CHECK_EQ(x, y, info) VALID_CHECK((x) == (y), info) #define VALID_CHECK_GT(x, y, info) VALID_CHECK((x) > (y), info) #define VALID_CHECK_LT(x, y, info) VALID_CHECK((x) < (y), info) @@ -42,24 +42,24 @@ inline void check( template bool pair_comp_first_rev(const std::pair &a, const std::pair &b) { - return a.first > b.first; + return a.first > b.first; } // Function template for comparing two pairs template bool pair_comp_second_rev(const std::pair &a, const std::pair &b) { - return a.second > b.second; + return a.second > b.second; } // Return the sum of two probabilities in log scale template T log_sum_exp(const T &x, const T &y) { - static T num_min = -std::numeric_limits::max(); - if (x <= num_min) return y; - if (y <= num_min) return x; - T xmax = std::max(x, y); - return std::log(std::exp(x - xmax) + std::exp(y - xmax)) + xmax; + static T num_min = -std::numeric_limits::max(); + if (x <= num_min) return y; + if (y <= num_min) return x; + T xmax = std::max(x, y); + return std::log(std::exp(x - xmax) + std::exp(y - xmax)) + xmax; } // Get pruned probability vector for each time step's beam search diff --git a/deepspeech/decoders/swig/path_trie.cpp b/deepspeech/decoders/swig/path_trie.cpp index 392e7ca71a5c489456158e7d2b0c1d63ebcad69d..f52d1157352f8013cebdb1932135d5b3422332e4 100644 --- a/deepspeech/decoders/swig/path_trie.cpp +++ b/deepspeech/decoders/swig/path_trie.cpp @@ -23,140 +23,141 @@ #include "decoder_utils.h" PathTrie::PathTrie() { - log_prob_b_prev = -NUM_FLT_INF; - log_prob_nb_prev = -NUM_FLT_INF; - log_prob_b_cur = -NUM_FLT_INF; - log_prob_nb_cur = -NUM_FLT_INF; - score = -NUM_FLT_INF; - - ROOT_ = -1; - character = ROOT_; - exists_ = true; - parent = nullptr; - - dictionary_ = nullptr; - dictionary_state_ = 0; - has_dictionary_ = false; - - matcher_ = nullptr; + log_prob_b_prev = -NUM_FLT_INF; + log_prob_nb_prev = -NUM_FLT_INF; + log_prob_b_cur = -NUM_FLT_INF; + log_prob_nb_cur = -NUM_FLT_INF; + score = -NUM_FLT_INF; + + ROOT_ = -1; + character = ROOT_; + exists_ = true; + parent = nullptr; + + dictionary_ = nullptr; + dictionary_state_ = 0; + has_dictionary_ = false; + + matcher_ = nullptr; } PathTrie::~PathTrie() { - for (auto child : children_) { - delete child.second; - } + for (auto child : children_) { + delete child.second; + } } PathTrie* PathTrie::get_path_trie(int new_char, bool reset) { - auto child = children_.begin(); - for (child = children_.begin(); child != children_.end(); ++child) { - if (child->first == new_char) { - break; - } - } - if (child != children_.end()) { - if (!child->second->exists_) { - child->second->exists_ = true; - child->second->log_prob_b_prev = -NUM_FLT_INF; - child->second->log_prob_nb_prev = -NUM_FLT_INF; - child->second->log_prob_b_cur = -NUM_FLT_INF; - child->second->log_prob_nb_cur = -NUM_FLT_INF; + auto child = children_.begin(); + for (child = children_.begin(); child != children_.end(); ++child) { + if (child->first == new_char) { + break; + } } - return (child->second); - } else { - if (has_dictionary_) { - matcher_->SetState(dictionary_state_); - bool found = matcher_->Find(new_char + 1); - if (!found) { - // Adding this character causes word outside dictionary - auto FSTZERO = fst::TropicalWeight::Zero(); - auto final_weight = dictionary_->Final(dictionary_state_); - bool is_final = (final_weight != FSTZERO); - if (is_final && reset) { - dictionary_state_ = 
-          dictionary_state_ = dictionary_->Start();
+    if (child != children_.end()) {
+        if (!child->second->exists_) {
+            child->second->exists_ = true;
+            child->second->log_prob_b_prev = -NUM_FLT_INF;
+            child->second->log_prob_nb_prev = -NUM_FLT_INF;
+            child->second->log_prob_b_cur = -NUM_FLT_INF;
+            child->second->log_prob_nb_cur = -NUM_FLT_INF;
         }
-        return nullptr;
-      } else {
-        PathTrie* new_path = new PathTrie;
-        new_path->character = new_char;
-        new_path->parent = this;
-        new_path->dictionary_ = dictionary_;
-        new_path->dictionary_state_ = matcher_->Value().nextstate;
-        new_path->has_dictionary_ = true;
-        new_path->matcher_ = matcher_;
-        children_.push_back(std::make_pair(new_char, new_path));
-        return new_path;
-      }
+        return (child->second);
     } else {
-      PathTrie* new_path = new PathTrie;
-      new_path->character = new_char;
-      new_path->parent = this;
-      children_.push_back(std::make_pair(new_char, new_path));
-      return new_path;
+        if (has_dictionary_) {
+            matcher_->SetState(dictionary_state_);
+            bool found = matcher_->Find(new_char + 1);
+            if (!found) {
+                // Adding this character causes word outside dictionary
+                auto FSTZERO = fst::TropicalWeight::Zero();
+                auto final_weight = dictionary_->Final(dictionary_state_);
+                bool is_final = (final_weight != FSTZERO);
+                if (is_final && reset) {
+                    dictionary_state_ = dictionary_->Start();
+                }
+                return nullptr;
+            } else {
+                PathTrie* new_path = new PathTrie;
+                new_path->character = new_char;
+                new_path->parent = this;
+                new_path->dictionary_ = dictionary_;
+                new_path->dictionary_state_ = matcher_->Value().nextstate;
+                new_path->has_dictionary_ = true;
+                new_path->matcher_ = matcher_;
+                children_.push_back(std::make_pair(new_char, new_path));
+                return new_path;
+            }
+        } else {
+            PathTrie* new_path = new PathTrie;
+            new_path->character = new_char;
+            new_path->parent = this;
+            children_.push_back(std::make_pair(new_char, new_path));
+            return new_path;
+        }
     }
-  }
 }
 
 PathTrie* PathTrie::get_path_vec(std::vector<int>& output) {
-  return get_path_vec(output, ROOT_);
+    return get_path_vec(output, ROOT_);
 }
 
 PathTrie* PathTrie::get_path_vec(std::vector<int>& output,
                                  int stop,
                                  size_t max_steps) {
-  if (character == stop || character == ROOT_ || output.size() == max_steps) {
-    std::reverse(output.begin(), output.end());
-    return this;
-  } else {
-    output.push_back(character);
-    return parent->get_path_vec(output, stop, max_steps);
-  }
+    if (character == stop || character == ROOT_ || output.size() == max_steps) {
+        std::reverse(output.begin(), output.end());
+        return this;
+    } else {
+        output.push_back(character);
+        return parent->get_path_vec(output, stop, max_steps);
+    }
 }
 
 void PathTrie::iterate_to_vec(std::vector<PathTrie*>& output) {
-  if (exists_) {
-    log_prob_b_prev = log_prob_b_cur;
-    log_prob_nb_prev = log_prob_nb_cur;
+    if (exists_) {
+        log_prob_b_prev = log_prob_b_cur;
+        log_prob_nb_prev = log_prob_nb_cur;
 
-    log_prob_b_cur = -NUM_FLT_INF;
-    log_prob_nb_cur = -NUM_FLT_INF;
+        log_prob_b_cur = -NUM_FLT_INF;
+        log_prob_nb_cur = -NUM_FLT_INF;
 
-    score = log_sum_exp(log_prob_b_prev, log_prob_nb_prev);
-    output.push_back(this);
-  }
-  for (auto child : children_) {
-    child.second->iterate_to_vec(output);
-  }
+        score = log_sum_exp(log_prob_b_prev, log_prob_nb_prev);
+        output.push_back(this);
+    }
+    for (auto child : children_) {
+        child.second->iterate_to_vec(output);
+    }
 }
 
 void PathTrie::remove() {
-  exists_ = false;
-
-  if (children_.size() == 0) {
-    auto child = parent->children_.begin();
-    for (child = parent->children_.begin(); child != parent->children_.end();
-         ++child) {
-      if (child->first == character) {
-        parent->children_.erase(child);
-        break;
-      }
-    }
+    exists_ = false;
+
+    if (children_.size() == 0) {
+        auto child = parent->children_.begin();
+        for (child = parent->children_.begin();
+             child != parent->children_.end();
+             ++child) {
+            if (child->first == character) {
+                parent->children_.erase(child);
+                break;
+            }
+        }
 
-    if (parent->children_.size() == 0 && !parent->exists_) {
-      parent->remove();
-    }
+        if (parent->children_.size() == 0 && !parent->exists_) {
+            parent->remove();
+        }
 
-    delete this;
-  }
+        delete this;
+    }
 }
 
 void PathTrie::set_dictionary(fst::StdVectorFst* dictionary) {
-  dictionary_ = dictionary;
-  dictionary_state_ = dictionary->Start();
-  has_dictionary_ = true;
+    dictionary_ = dictionary;
+    dictionary_state_ = dictionary->Start();
+    has_dictionary_ = true;
 }
 
 using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;
 void PathTrie::set_matcher(std::shared_ptr<FSTMATCH> matcher) {
-  matcher_ = matcher;
+    matcher_ = matcher;
 }
diff --git a/deepspeech/decoders/swig/path_trie.h b/deepspeech/decoders/swig/path_trie.h
index 3a5b71b7efaa7ff4338362e2eeb11e42d117c5cc..717d4b00435affc412e6910ef728358f9907bf0b 100644
--- a/deepspeech/decoders/swig/path_trie.h
+++ b/deepspeech/decoders/swig/path_trie.h
@@ -27,55 +27,56 @@
  * finite-state transducer for spelling correction.
  */
 class PathTrie {
-public:
-  PathTrie();
-  ~PathTrie();
+  public:
+    PathTrie();
+    ~PathTrie();
 
-  // get new prefix after appending new char
-  PathTrie* get_path_trie(int new_char, bool reset = true);
+    // get new prefix after appending new char
+    PathTrie* get_path_trie(int new_char, bool reset = true);
 
-  // get the prefix in index from root to current node
-  PathTrie* get_path_vec(std::vector<int>& output);
+    // get the prefix in index from root to current node
+    PathTrie* get_path_vec(std::vector<int>& output);
 
-  // get the prefix in index from some stop node to current nodel
-  PathTrie* get_path_vec(std::vector<int>& output,
-                         int stop,
-                         size_t max_steps = std::numeric_limits<size_t>::max());
+    // get the prefix in index from some stop node to the current node
+    PathTrie* get_path_vec(
+        std::vector<int>& output,
+        int stop,
+        size_t max_steps = std::numeric_limits<size_t>::max());
 
-  // update log probs
-  void iterate_to_vec(std::vector<PathTrie*>& output);
+    // update log probs
+    void iterate_to_vec(std::vector<PathTrie*>& output);
 
-  // set dictionary for FST
-  void set_dictionary(fst::StdVectorFst* dictionary);
+    // set dictionary for FST
+    void set_dictionary(fst::StdVectorFst* dictionary);
 
-  void set_matcher(std::shared_ptr<fst::SortedMatcher<fst::StdVectorFst>>);
+    void set_matcher(std::shared_ptr<fst::SortedMatcher<fst::StdVectorFst>>);
 
-  bool is_empty() { return ROOT_ == character; }
+    bool is_empty() { return ROOT_ == character; }
 
-  // remove current path from root
-  void remove();
+    // remove current path from root
+    void remove();
 
-  float log_prob_b_prev;
-  float log_prob_nb_prev;
-  float log_prob_b_cur;
-  float log_prob_nb_cur;
-  float score;
-  float approx_ctc;
-  int character;
-  PathTrie* parent;
+    float log_prob_b_prev;
+    float log_prob_nb_prev;
+    float log_prob_b_cur;
+    float log_prob_nb_cur;
+    float score;
+    float approx_ctc;
+    int character;
+    PathTrie* parent;
 
-private:
-  int ROOT_;
-  bool exists_;
-  bool has_dictionary_;
+  private:
+    int ROOT_;
+    bool exists_;
+    bool has_dictionary_;
 
-  std::vector<std::pair<int, PathTrie*>> children_;
+    std::vector<std::pair<int, PathTrie*>> children_;
 
-  // pointer to dictionary of FST
-  fst::StdVectorFst* dictionary_;
-  fst::StdVectorFst::StateId dictionary_state_;
-  // true if finding ars in FST
-  std::shared_ptr<fst::SortedMatcher<fst::StdVectorFst>> matcher_;
+    // pointer to dictionary of FST
+    fst::StdVectorFst* dictionary_;
+    fst::StdVectorFst::StateId dictionary_state_;
+    // true if finding arcs in FST
+    std::shared_ptr<fst::SortedMatcher<fst::StdVectorFst>> matcher_;
 };
 
 #endif  // PATH_TRIE_H
diff --git a/deepspeech/decoders/swig/scorer.cpp b/deepspeech/decoders/swig/scorer.cpp
index 497a289c22ad0e018c6055d696bcd5b46ec343a4..a25382b15cd73e2743d28b8e3e93d167d976fe2e 100644
--- a/deepspeech/decoders/swig/scorer.cpp
+++ b/deepspeech/decoders/swig/scorer.cpp
@@ -31,214 +31,214 @@ Scorer::Scorer(double alpha,
                double beta,
                const std::string& lm_path,
                const std::vector<std::string>& vocab_list) {
-  this->alpha = alpha;
-  this->beta = beta;
+    this->alpha = alpha;
+    this->beta = beta;
 
-  dictionary = nullptr;
-  is_character_based_ = true;
-  language_model_ = nullptr;
+    dictionary = nullptr;
+    is_character_based_ = true;
+    language_model_ = nullptr;
 
-  max_order_ = 0;
-  dict_size_ = 0;
-  SPACE_ID_ = -1;
+    max_order_ = 0;
+    dict_size_ = 0;
+    SPACE_ID_ = -1;
 
-  setup(lm_path, vocab_list);
+    setup(lm_path, vocab_list);
 }
 
 Scorer::~Scorer() {
-  if (language_model_ != nullptr) {
-    delete static_cast<lm::base::Model*>(language_model_);
-  }
-  if (dictionary != nullptr) {
-    delete static_cast<fst::StdVectorFst*>(dictionary);
-  }
+    if (language_model_ != nullptr) {
+        delete static_cast<lm::base::Model*>(language_model_);
+    }
+    if (dictionary != nullptr) {
+        delete static_cast<fst::StdVectorFst*>(dictionary);
+    }
 }
 
 void Scorer::setup(const std::string& lm_path,
                    const std::vector<std::string>& vocab_list) {
-  // load language model
-  load_lm(lm_path);
-  // set char map for scorer
-  set_char_map(vocab_list);
-  // fill the dictionary for FST
-  if (!is_character_based()) {
-    fill_dictionary(true);
-  }
+    // load language model
+    load_lm(lm_path);
+    // set char map for scorer
+    set_char_map(vocab_list);
+    // fill the dictionary for FST
+    if (!is_character_based()) {
+        fill_dictionary(true);
+    }
 }
 
 void Scorer::load_lm(const std::string& lm_path) {
-  const char* filename = lm_path.c_str();
-  VALID_CHECK_EQ(access(filename, F_OK), 0, "Invalid language model path");
-
-  RetriveStrEnumerateVocab enumerate;
-  lm::ngram::Config config;
-  config.enumerate_vocab = &enumerate;
-  language_model_ = lm::ngram::LoadVirtual(filename, config);
-  max_order_ = static_cast<lm::base::Model*>(language_model_)->Order();
-  vocabulary_ = enumerate.vocabulary;
-  for (size_t i = 0; i < vocabulary_.size(); ++i) {
-    if (is_character_based_ && vocabulary_[i] != UNK_TOKEN &&
-        vocabulary_[i] != START_TOKEN && vocabulary_[i] != END_TOKEN &&
-        get_utf8_str_len(enumerate.vocabulary[i]) > 1) {
-      is_character_based_ = false;
+    const char* filename = lm_path.c_str();
+    VALID_CHECK_EQ(access(filename, F_OK), 0, "Invalid language model path");
+
+    RetriveStrEnumerateVocab enumerate;
+    lm::ngram::Config config;
+    config.enumerate_vocab = &enumerate;
+    language_model_ = lm::ngram::LoadVirtual(filename, config);
+    max_order_ = static_cast<lm::base::Model*>(language_model_)->Order();
+    vocabulary_ = enumerate.vocabulary;
+    for (size_t i = 0; i < vocabulary_.size(); ++i) {
+        if (is_character_based_ && vocabulary_[i] != UNK_TOKEN &&
+            vocabulary_[i] != START_TOKEN && vocabulary_[i] != END_TOKEN &&
+            get_utf8_str_len(enumerate.vocabulary[i]) > 1) {
+            is_character_based_ = false;
+        }
     }
-  }
 }
 
 double Scorer::get_log_cond_prob(const std::vector<std::string>& words) {
-  lm::base::Model* model = static_cast<lm::base::Model*>(language_model_);
-  double cond_prob;
-  lm::ngram::State state, tmp_state, out_state;
-  // avoid to inserting <s> in begin
-  model->NullContextWrite(&state);
-  for (size_t i = 0; i < words.size(); ++i) {
-    lm::WordIndex word_index = model->BaseVocabulary().Index(words[i]);
-    // encounter OOV
-    if (word_index == 0) {
-      return OOV_SCORE;
+    lm::base::Model* model = static_cast<lm::base::Model*>(language_model_);
+    double cond_prob;
+    lm::ngram::State state, tmp_state, out_state;
+    // avoid inserting <s> at the beginning
+    model->NullContextWrite(&state);
+    for (size_t i = 0; i < words.size(); ++i) {
+        lm::WordIndex word_index = model->BaseVocabulary().Index(words[i]);
+        // encounter OOV
+        if (word_index == 0) {
+            return OOV_SCORE;
+        }
+        cond_prob = model->BaseScore(&state, word_index, &out_state);
+        tmp_state = state;
+        state = out_state;
+        out_state = tmp_state;
     }
-    cond_prob = model->BaseScore(&state, word_index, &out_state);
-    tmp_state = state;
-    state = out_state;
-    out_state = tmp_state;
-  }
 
-  // return log10 prob
-  return cond_prob;
+    // return log10 prob
+    return cond_prob;
 }
 
 double Scorer::get_sent_log_prob(const std::vector<std::string>& words) {
-  std::vector<std::string> sentence;
-  if (words.size() == 0) {
-    for (size_t i = 0; i < max_order_; ++i) {
-      sentence.push_back(START_TOKEN);
-    }
-  } else {
-    for (size_t i = 0; i < max_order_ - 1; ++i) {
-      sentence.push_back(START_TOKEN);
+    std::vector<std::string> sentence;
+    if (words.size() == 0) {
+        for (size_t i = 0; i < max_order_; ++i) {
+            sentence.push_back(START_TOKEN);
+        }
+    } else {
+        for (size_t i = 0; i < max_order_ - 1; ++i) {
+            sentence.push_back(START_TOKEN);
+        }
+        sentence.insert(sentence.end(), words.begin(), words.end());
    }
-    sentence.insert(sentence.end(), words.begin(), words.end());
-  }
-  sentence.push_back(END_TOKEN);
-  return get_log_prob(sentence);
+    sentence.push_back(END_TOKEN);
+    return get_log_prob(sentence);
 }
 
 double Scorer::get_log_prob(const std::vector<std::string>& words) {
-  assert(words.size() > max_order_);
-  double score = 0.0;
-  for (size_t i = 0; i < words.size() - max_order_ + 1; ++i) {
-    std::vector<std::string> ngram(words.begin() + i,
-                                   words.begin() + i + max_order_);
-    score += get_log_cond_prob(ngram);
-  }
-  return score;
+    assert(words.size() > max_order_);
+    double score = 0.0;
+    for (size_t i = 0; i < words.size() - max_order_ + 1; ++i) {
+        std::vector<std::string> ngram(words.begin() + i,
+                                       words.begin() + i + max_order_);
+        score += get_log_cond_prob(ngram);
+    }
+    return score;
 }
 
 void Scorer::reset_params(float alpha, float beta) {
-  this->alpha = alpha;
-  this->beta = beta;
+    this->alpha = alpha;
+    this->beta = beta;
 }
 
 std::string Scorer::vec2str(const std::vector<int>& input) {
-  std::string word;
-  for (auto ind : input) {
-    word += char_list_[ind];
-  }
-  return word;
+    std::string word;
+    for (auto ind : input) {
+        word += char_list_[ind];
+    }
+    return word;
 }
 
 std::vector<std::string> Scorer::split_labels(const std::vector<int>& labels) {
-  if (labels.empty()) return {};
-
-  std::string s = vec2str(labels);
-  std::vector<std::string> words;
-  if (is_character_based_) {
-    words = split_utf8_str(s);
-  } else {
-    words = split_str(s, " ");
-  }
-  return words;
+    if (labels.empty()) return {};
+
+    std::string s = vec2str(labels);
+    std::vector<std::string> words;
+    if (is_character_based_) {
+        words = split_utf8_str(s);
+    } else {
+        words = split_str(s, " ");
+    }
+    return words;
 }
 
 void Scorer::set_char_map(const std::vector<std::string>& char_list) {
-  char_list_ = char_list;
-  char_map_.clear();
-
-  // Set the char map for the FST for spelling correction
-  for (size_t i = 0; i < char_list_.size(); i++) {
-    if (char_list_[i] == " ") {
-      SPACE_ID_ = i;
+    char_list_ = char_list;
+    char_map_.clear();
+
+    // Set the char map for the FST for spelling correction
+    for (size_t i = 0; i < char_list_.size(); i++) {
+        if (char_list_[i] == " ") {
+            SPACE_ID_ = i;
+        }
+        // The initial state of FST is state 0, hence the index of chars in
+        // the FST should start from 1 to avoid the conflict with the initial
+        // state, otherwise wrong decoding results would be given.
+        char_map_[char_list_[i]] = i + 1;
     }
-    // The initial state of FST is state 0, hence the index of chars in
-    // the FST should start from 1 to avoid the conflict with the initial
-    // state, otherwise wrong decoding results would be given.
-    char_map_[char_list_[i]] = i + 1;
-  }
 }
 
 std::vector<std::string> Scorer::make_ngram(PathTrie* prefix) {
-  std::vector<std::string> ngram;
-  PathTrie* current_node = prefix;
-  PathTrie* new_node = nullptr;
-
-  for (int order = 0; order < max_order_; order++) {
-    std::vector<int> prefix_vec;
-
-    if (is_character_based_) {
-      new_node = current_node->get_path_vec(prefix_vec, SPACE_ID_, 1);
-      current_node = new_node;
-    } else {
-      new_node = current_node->get_path_vec(prefix_vec, SPACE_ID_);
-      current_node = new_node->parent;  // Skipping spaces
+    std::vector<std::string> ngram;
+    PathTrie* current_node = prefix;
+    PathTrie* new_node = nullptr;
+
+    for (int order = 0; order < max_order_; order++) {
+        std::vector<int> prefix_vec;
+
+        if (is_character_based_) {
+            new_node = current_node->get_path_vec(prefix_vec, SPACE_ID_, 1);
+            current_node = new_node;
+        } else {
+            new_node = current_node->get_path_vec(prefix_vec, SPACE_ID_);
+            current_node = new_node->parent;  // Skipping spaces
+        }
+
+        // reconstruct word
+        std::string word = vec2str(prefix_vec);
+        ngram.push_back(word);
+
+        if (new_node->character == -1) {
+            // No more spaces, but still need order
+            for (int i = 0; i < max_order_ - order - 1; i++) {
+                ngram.push_back(START_TOKEN);
+            }
+            break;
+        }
     }
-
-    // reconstruct word
-    std::string word = vec2str(prefix_vec);
-    ngram.push_back(word);
-
-    if (new_node->character == -1) {
-      // No more spaces, but still need order
-      for (int i = 0; i < max_order_ - order - 1; i++) {
-        ngram.push_back(START_TOKEN);
-      }
-      break;
-    }
-  }
-  std::reverse(ngram.begin(), ngram.end());
-  return ngram;
+    std::reverse(ngram.begin(), ngram.end());
+    return ngram;
 }
 
 void Scorer::fill_dictionary(bool add_space) {
-  fst::StdVectorFst dictionary;
-  // For each unigram convert to ints and put in trie
-  int dict_size = 0;
-  for (const auto& word : vocabulary_) {
-    bool added = add_word_to_dictionary(
-        word, char_map_, add_space, SPACE_ID_ + 1, &dictionary);
-    dict_size += added ? 1 : 0;
-  }
-
-  dict_size_ = dict_size;
-
-  /* Simplify FST
-
-   * This gets rid of "epsilon" transitions in the FST.
-   * These are transitions that don't require a string input to be taken.
-   * Getting rid of them is necessary to make the FST determinisitc, but
-   * can greatly increase the size of the FST
-   */
-  fst::RmEpsilon(&dictionary);
-  fst::StdVectorFst* new_dict = new fst::StdVectorFst;
-
-  /* This makes the FST deterministic, meaning for any string input there's
-   * only one possible state the FST could be in. It is assumed our
-   * dictionary is deterministic when using it.
-   * (lest we'd have to check for multiple transitions at each state)
-   */
-  fst::Determinize(dictionary, new_dict);
-
-  /* Finds the simplest equivalent fst. This is unnecessary but decreases
-   * memory usage of the dictionary
-   */
-  fst::Minimize(new_dict);
-  this->dictionary = new_dict;
+    fst::StdVectorFst dictionary;
+    // For each unigram convert to ints and put in trie
+    int dict_size = 0;
+    for (const auto& word : vocabulary_) {
+        bool added = add_word_to_dictionary(
+            word, char_map_, add_space, SPACE_ID_ + 1, &dictionary);
+        dict_size += added ? 1 : 0;
+    }
+
+    dict_size_ = dict_size;
+
+    /* Simplify FST
+
+     * This gets rid of "epsilon" transitions in the FST.
+     * These are transitions that don't require a string input to be taken.
+     * Getting rid of them is necessary to make the FST deterministic, but
+     * can greatly increase the size of the FST
+     */
+    fst::RmEpsilon(&dictionary);
+    fst::StdVectorFst* new_dict = new fst::StdVectorFst;
+
+    /* This makes the FST deterministic, meaning for any string input there's
+     * only one possible state the FST could be in. It is assumed our
+     * dictionary is deterministic when using it.
+     * (lest we'd have to check for multiple transitions at each state)
+     */
+    fst::Determinize(dictionary, new_dict);
+
+    /* Finds the simplest equivalent fst. This is unnecessary but decreases
+     * memory usage of the dictionary
+     */
+    fst::Minimize(new_dict);
+    this->dictionary = new_dict;
 }
diff --git a/deepspeech/decoders/swig/scorer.h b/deepspeech/decoders/swig/scorer.h
index 66c4cb123eba92b99e4e1add3cf071a73088734d..3f3001e77c222af41ee60b57d6f0bc38e81cb9ef 100644
--- a/deepspeech/decoders/swig/scorer.h
+++ b/deepspeech/decoders/swig/scorer.h
@@ -34,14 +34,14 @@ const std::string END_TOKEN = "</s>";
 
 // Implement a callback to retrive the dictionary of language model.
 class RetriveStrEnumerateVocab : public lm::EnumerateVocab {
-public:
-  RetriveStrEnumerateVocab() {}
+  public:
+    RetriveStrEnumerateVocab() {}
 
-  void Add(lm::WordIndex index, const StringPiece &str) {
-    vocabulary.push_back(std::string(str.data(), str.length()));
-  }
+    void Add(lm::WordIndex index, const StringPiece &str) {
+        vocabulary.push_back(std::string(str.data(), str.length()));
+    }
 
-  std::vector<std::string> vocabulary;
+    std::vector<std::string> vocabulary;
 };
 
 /* External scorer to query score for n-gram or sentence, including language
@@ -53,74 +53,74 @@ public:
 * scorer.get_sent_log_prob({ "WORD1", "WORD2", "WORD3" });
 */
 class Scorer {
-public:
-  Scorer(double alpha,
-         double beta,
-         const std::string &lm_path,
-         const std::vector<std::string> &vocabulary);
-  ~Scorer();
+  public:
+    Scorer(double alpha,
+           double beta,
+           const std::string &lm_path,
+           const std::vector<std::string> &vocabulary);
+    ~Scorer();
 
-  double get_log_cond_prob(const std::vector<std::string> &words);
+    double get_log_cond_prob(const std::vector<std::string> &words);
 
-  double get_sent_log_prob(const std::vector<std::string> &words);
+    double get_sent_log_prob(const std::vector<std::string> &words);
 
-  // return the max order
-  size_t get_max_order() const { return max_order_; }
+    // return the max order
+    size_t get_max_order() const { return max_order_; }
 
-  // return the dictionary size of language model
-  size_t get_dict_size() const { return dict_size_; }
+    // return the dictionary size of language model
+    size_t get_dict_size() const { return dict_size_; }
 
-  // retrun true if the language model is character based
-  bool is_character_based() const { return is_character_based_; }
+    // return true if the language model is character based
+    bool is_character_based() const { return is_character_based_; }
 
-  // reset params alpha & beta
-  void reset_params(float alpha, float beta);
+    // reset params alpha & beta
+    void reset_params(float alpha, float beta);
 
-  // make ngram for a given prefix
-  std::vector<std::string> make_ngram(PathTrie *prefix);
+    // make ngram for a given prefix
+    std::vector<std::string> make_ngram(PathTrie *prefix);
 
-  // trransform the labels in index to the vector of words (word based lm) or
-  // the vector of characters (character based lm)
-  std::vector<std::string> split_labels(const std::vector<int> &labels);
+    // transform the labels in index to the vector of words (word based lm) or
+    // the vector of characters (character based lm)
+    std::vector<std::string> split_labels(const std::vector<int> &labels);
 
-  // language model weight
-  double alpha;
-  // word insertion weight
-  double beta;
+    // language model weight
+    double alpha;
+    // word insertion weight
+    double beta;
 
-  // pointer to the dictionary of FST
-  void *dictionary;
+    // pointer to the dictionary of FST
+    void *dictionary;
 
-protected:
-  // necessary setup: load language model, set char map, fill FST's dictionary
-  void setup(const std::string &lm_path,
-             const std::vector<std::string> &vocab_list);
+  protected:
+    // necessary setup: load language model, set char map, fill FST's dictionary
+    void setup(const std::string &lm_path,
+               const std::vector<std::string> &vocab_list);
 
-  // load language model from given path
-  void load_lm(const std::string &lm_path);
+    // load language model from given path
+    void load_lm(const std::string &lm_path);
 
-  // fill dictionary for FST
-  void fill_dictionary(bool add_space);
+    // fill dictionary for FST
+    void fill_dictionary(bool add_space);
 
-  // set char map
-  void set_char_map(const std::vector<std::string> &char_list);
+    // set char map
+    void set_char_map(const std::vector<std::string> &char_list);
 
-  double get_log_prob(const std::vector<std::string> &words);
+    double get_log_prob(const std::vector<std::string> &words);
 
-  // translate the vector in index to string
-  std::string vec2str(const std::vector<int> &input);
+    // translate the vector in index to string
+    std::string vec2str(const std::vector<int> &input);
 
-private:
-  void *language_model_;
-  bool is_character_based_;
-  size_t max_order_;
-  size_t dict_size_;
+  private:
+    void *language_model_;
+    bool is_character_based_;
+    size_t max_order_;
+    size_t dict_size_;
 
-  int SPACE_ID_;
-  std::vector<std::string> char_list_;
-  std::unordered_map<std::string, int> char_map_;
+    int SPACE_ID_;
+    std::vector<std::string> char_list_;
+    std::unordered_map<std::string, int> char_map_;
 
-  std::vector<std::string> vocabulary_;
+    std::vector<std::string> vocabulary_;
 };
 
 #endif  // SCORER_H_
diff --git a/third_party/pymmseg-cpp/bin/pymmseg b/third_party/pymmseg-cpp/bin/pymmseg
index 706e59702702b9930e02454192adb2d13de81b32..5faa9985585d09f41f24816509e1c65015447af7 100755
--- a/third_party/pymmseg-cpp/bin/pymmseg
+++ b/third_party/pymmseg-cpp/bin/pymmseg
@@ -1,12 +1,12 @@
 #!/usr/bin/env python3
-
-import sys
-import pstats
 import cProfile
-from io import StringIO
 import getopt
 import os
-from os.path import dirname, join
+import pstats
+import sys
+from io import StringIO
+from os.path import dirname
+from os.path import join
 
 import mmseg
diff --git a/third_party/python-pinyin/pinyin-data/CHANGELOG.md b/third_party/python-pinyin/pinyin-data/CHANGELOG.md
index d65d4c15c7e158045f8f43666193c3b0160e6928..8a202b1915229362b142d4dd91cc6602b11a8c4b 100644
--- a/third_party/python-pinyin/pinyin-data/CHANGELOG.md
+++ b/third_party/python-pinyin/pinyin-data/CHANGELOG.md
@@ -94,7 +94,7 @@
 * Update to the latest version of [Unihan Database](http://www.unicode.org/charts/unihan.html):
 
-  > Date: 2016-06-01 07:01:48 GMT [JHJ]
+  > Date: 2016-06-01 07:01:48 GMT [JHJ]
 
   > Unicode version: 9.0.0
 
diff --git a/third_party/python-pinyin/pinyin-data/README.md b/third_party/python-pinyin/pinyin-data/README.md
index c954ec71d93f58ea13a1e0ae74b06fae927feda5..46a4ccee7fbcad013eab90725a196a9ac8cac637 100644
--- a/third_party/python-pinyin/pinyin-data/README.md
+++ b/third_party/python-pinyin/pinyin-data/README.md
@@ -19,7 +19,7 @@
 [Unihan Database][unihan] 数据版本:
 
-> Date: 2020-02-18 18:27:33 GMT [JHJ]
+> Date: 2020-02-18 18:27:33 GMT [JHJ]
 
 > Unicode version: 13.0.0
 
 * `kTGHZ2013.txt`: [Unihan Database][unihan] 中 [kTGHZ2013](http://www.unicode.org/reports/tr38/#kTGHZ2013) 部分的拼音数据(来源于《通用规范汉字字典》的拼音数据)
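Reviewer note, not part of the patch: the hunks above are formatting-only (2-space to 4-space indentation, line reflowing, import ordering, trailing whitespace), so the decoder entry points keep their existing signatures. For context, here is a minimal sketch of a call site for `ctc_beam_search_decoder` as declared above. The header name and all numeric values are illustrative assumptions; note that this build checks `probs_seq[i].size() == vocabulary.size()` and hardcodes `blank_id = 0`, i.e. index 0 of the vocabulary is treated as the CTC blank.

```cpp
#include <string>
#include <utility>
#include <vector>

#include "ctc_beam_search_decoder.h"  // assumed header for the decoder above

int main() {
  // Toy 3-symbol vocabulary; entry 0 acts as the blank (blank_id = 0).
  std::vector<std::string> vocabulary = {"-", " ", "a"};

  // Two time steps; each row is a probability distribution over the symbols.
  std::vector<std::vector<double>> probs_seq = {{0.2, 0.1, 0.7},
                                                {0.3, 0.1, 0.6}};

  // nullptr scorer: plain prefix beam search, no language model rescoring.
  std::vector<std::pair<double, std::string>> results =
      ctc_beam_search_decoder(probs_seq,
                              vocabulary,
                              /*beam_size=*/16,
                              /*cutoff_prob=*/1.0,
                              /*cutoff_top_n=*/40,
                              /*ext_scorer=*/nullptr);

  // Results come back best-first: results[0].second is the top transcript,
  // and results[0].first is the negated approximate CTC score that
  // get_beam_search_result attaches to each beam entry.
  return 0;
}
```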