#include "scorer.h" #include #include #include "lm/config.hh" #include "lm/model.hh" #include "lm/state.hh" #include "util/string_piece.hh" #include "util/tokenize_piece.hh" #include "decoder_utils.h" using namespace lm::ngram; Scorer::Scorer(double alpha, double beta, const std::string& lm_path, const std::vector& vocab_list) { this->alpha = alpha; this->beta = beta; _is_character_based = true; _language_model = nullptr; dictionary = nullptr; _max_order = 0; _dict_size = 0; _SPACE_ID = -1; setup(lm_path, vocab_list); } Scorer::~Scorer() { if (_language_model != nullptr) { delete static_cast(_language_model); } if (dictionary != nullptr) { delete static_cast(dictionary); } } void Scorer::setup(const std::string& lm_path, const std::vector& vocab_list) { // load language model load_lm(lm_path); // set char map for scorer set_char_map(vocab_list); // fill the dictionary for FST if (!is_character_based()) { fill_dictionary(true); } } void Scorer::load_lm(const std::string& lm_path) { const char* filename = lm_path.c_str(); VALID_CHECK_EQ(access(filename, F_OK), 0, "Invalid language model path"); RetriveStrEnumerateVocab enumerate; lm::ngram::Config config; config.enumerate_vocab = &enumerate; _language_model = lm::ngram::LoadVirtual(filename, config); _max_order = static_cast(_language_model)->Order(); _vocabulary = enumerate.vocabulary; for (size_t i = 0; i < _vocabulary.size(); ++i) { if (_is_character_based && _vocabulary[i] != UNK_TOKEN && _vocabulary[i] != START_TOKEN && _vocabulary[i] != END_TOKEN && get_utf8_str_len(enumerate.vocabulary[i]) > 1) { _is_character_based = false; } } } double Scorer::get_log_cond_prob(const std::vector& words) { lm::base::Model* model = static_cast(_language_model); double cond_prob; lm::ngram::State state, tmp_state, out_state; // avoid to inserting in begin model->NullContextWrite(&state); for (size_t i = 0; i < words.size(); ++i) { lm::WordIndex word_index = model->BaseVocabulary().Index(words[i]); // encounter OOV if (word_index == 0) { return OOV_SCORE; } cond_prob = model->BaseScore(&state, word_index, &out_state); tmp_state = state; state = out_state; out_state = tmp_state; } // return log10 prob return cond_prob; } double Scorer::get_sent_log_prob(const std::vector& words) { std::vector sentence; if (words.size() == 0) { for (size_t i = 0; i < _max_order; ++i) { sentence.push_back(START_TOKEN); } } else { for (size_t i = 0; i < _max_order - 1; ++i) { sentence.push_back(START_TOKEN); } sentence.insert(sentence.end(), words.begin(), words.end()); } sentence.push_back(END_TOKEN); return get_log_prob(sentence); } double Scorer::get_log_prob(const std::vector& words) { assert(words.size() > _max_order); double score = 0.0; for (size_t i = 0; i < words.size() - _max_order + 1; ++i) { std::vector ngram(words.begin() + i, words.begin() + i + _max_order); score += get_log_cond_prob(ngram); } return score; } void Scorer::reset_params(float alpha, float beta) { this->alpha = alpha; this->beta = beta; } std::string Scorer::vec2str(const std::vector& input) { std::string word; for (auto ind : input) { word += _char_list[ind]; } return word; } std::vector Scorer::split_labels(const std::vector& labels) { if (labels.empty()) return {}; std::string s = vec2str(labels); std::vector words; if (_is_character_based) { words = split_utf8_str(s); } else { words = split_str(s, " "); } return words; } void Scorer::set_char_map(const std::vector& char_list) { _char_list = char_list; _char_map.clear(); for (unsigned int i = 0; i < _char_list.size(); i++) { if (_char_list[i] 
== " ") { _SPACE_ID = i; _char_map[' '] = i; } else if (_char_list[i].size() == 1) { _char_map[_char_list[i][0]] = i; } } } std::vector Scorer::make_ngram(PathTrie* prefix) { std::vector ngram; PathTrie* current_node = prefix; PathTrie* new_node = nullptr; for (int order = 0; order < _max_order; order++) { std::vector prefix_vec; if (_is_character_based) { new_node = current_node->get_path_vec(prefix_vec, _SPACE_ID, 1); current_node = new_node; } else { new_node = current_node->get_path_vec(prefix_vec, _SPACE_ID); current_node = new_node->parent; // Skipping spaces } // reconstruct word std::string word = vec2str(prefix_vec); ngram.push_back(word); if (new_node->character == -1) { // No more spaces, but still need order for (int i = 0; i < _max_order - order - 1; i++) { ngram.push_back(START_TOKEN); } break; } } std::reverse(ngram.begin(), ngram.end()); return ngram; } void Scorer::fill_dictionary(bool add_space) { fst::StdVectorFst dictionary; // First reverse char_list so ints can be accessed by chars std::unordered_map char_map; for (unsigned int i = 0; i < _char_list.size(); i++) { char_map[_char_list[i]] = i; } // For each unigram convert to ints and put in trie int dict_size = 0; for (const auto& word : _vocabulary) { bool added = add_word_to_dictionary( word, char_map, add_space, _SPACE_ID, &dictionary); dict_size += added ? 1 : 0; } _dict_size = dict_size; /* Simplify FST * This gets rid of "epsilon" transitions in the FST. * These are transitions that don't require a string input to be taken. * Getting rid of them is necessary to make the FST determinisitc, but * can greatly increase the size of the FST */ fst::RmEpsilon(&dictionary); fst::StdVectorFst* new_dict = new fst::StdVectorFst; /* This makes the FST deterministic, meaning for any string input there's * only one possible state the FST could be in. It is assumed our * dictionary is deterministic when using it. * (lest we'd have to check for multiple transitions at each state) */ fst::Determinize(dictionary, new_dict); /* Finds the simplest equivalent fst. This is unnecessary but decreases * memory usage of the dictionary */ fst::Minimize(new_dict); this->dictionary = new_dict; }