提交 09f4c6e1 编写于 作者: Y Yibing Liu

remove unused functions in Scorer

上级 20d13a4d
...@@ -96,13 +96,13 @@ std::vector<std::pair<double, std::string> > ...@@ -96,13 +96,13 @@ std::vector<std::pair<double, std::string> >
prefixes.push_back(&root); prefixes.push_back(&root);
if ( ext_scorer != nullptr && !ext_scorer->is_character_based()) { if ( ext_scorer != nullptr && !ext_scorer->is_character_based()) {
if (ext_scorer->_dictionary == nullptr) { if (ext_scorer->dictionary == nullptr) {
// TODO: init dictionary // TODO: init dictionary
ext_scorer->set_char_map(vocabulary); ext_scorer->set_char_map(vocabulary);
// add_space should be true? // add_space should be true?
ext_scorer->fill_dictionary(true); ext_scorer->fill_dictionary(true);
} }
auto fst_dict = static_cast<fst::StdVectorFst*>(ext_scorer->_dictionary); auto fst_dict = static_cast<fst::StdVectorFst*>(ext_scorer->dictionary);
fst::StdVectorFst* dict_ptr = fst_dict->Copy(true); fst::StdVectorFst* dict_ptr = fst_dict->Copy(true);
root.set_dictionary(dict_ptr); root.set_dictionary(dict_ptr);
auto matcher = std::make_shared<FSTMATCH>(*dict_ptr, fst::MATCH_INPUT); auto matcher = std::make_shared<FSTMATCH>(*dict_ptr, fst::MATCH_INPUT);
...@@ -285,7 +285,7 @@ std::vector<std::vector<std::pair<double, std::string> > > ...@@ -285,7 +285,7 @@ std::vector<std::vector<std::pair<double, std::string> > >
// dictionary init // dictionary init
if ( ext_scorer != nullptr if ( ext_scorer != nullptr
&& !ext_scorer->is_character_based() && !ext_scorer->is_character_based()
&& ext_scorer->_dictionary == nullptr) { && ext_scorer->dictionary == nullptr) {
// init dictionary // init dictionary
ext_scorer->set_char_map(vocabulary); ext_scorer->set_char_map(vocabulary);
ext_scorer->fill_dictionary(true); ext_scorer->fill_dictionary(true);
......
...@@ -15,7 +15,7 @@ Scorer::Scorer(double alpha, double beta, const std::string& lm_path) { ...@@ -15,7 +15,7 @@ Scorer::Scorer(double alpha, double beta, const std::string& lm_path) {
this->beta = beta; this->beta = beta;
_is_character_based = true; _is_character_based = true;
_language_model = nullptr; _language_model = nullptr;
_dictionary = nullptr; dictionary = nullptr;
_max_order = 0; _max_order = 0;
_SPACE_ID = -1; _SPACE_ID = -1;
// load language model // load language model
...@@ -25,8 +25,8 @@ Scorer::Scorer(double alpha, double beta, const std::string& lm_path) { ...@@ -25,8 +25,8 @@ Scorer::Scorer(double alpha, double beta, const std::string& lm_path) {
Scorer::~Scorer() { Scorer::~Scorer() {
if (_language_model != nullptr) if (_language_model != nullptr)
delete static_cast<lm::base::Model*>(_language_model); delete static_cast<lm::base::Model*>(_language_model);
if (_dictionary != nullptr) if (dictionary != nullptr)
delete static_cast<fst::StdVectorFst*>(_dictionary); delete static_cast<fst::StdVectorFst*>(dictionary);
} }
void Scorer::load_LM(const char* filename) { void Scorer::load_LM(const char* filename) {
...@@ -99,87 +99,11 @@ double Scorer::get_log_prob(const std::vector<std::string>& words) { ...@@ -99,87 +99,11 @@ double Scorer::get_log_prob(const std::vector<std::string>& words) {
return score; return score;
} }
/* Strip every leading and trailing occurrence of a character from a string.
 * Parameters:
 *     str: The string to trim; it is modified in place
 *     ch: The character to remove from both ends (a space by default)
 * Return:
 *     void
 */
inline void strip(std::string &str, char ch=' ') {
    // Positions are kept as std::string::size_type so that no index is ever
    // narrowed to int or compared across signedness.
    const std::string::size_type first = str.find_first_not_of(ch);
    if (first == std::string::npos) {
        // Empty input, or input made only of `ch`: nothing survives.
        str.clear();
        return;
    }
    // A character other than `ch` exists, so this search cannot fail.
    const std::string::size_type last = str.find_last_not_of(ch);
    // Already trimmed: leave the string untouched.
    if (first == 0 && last + 1 == str.size()) return;
    str = str.substr(first, last - first + 1);
}
// Count the space-separated words in a sentence.
// The sentence is taken by value, so trimming it here does not affect the
// caller's string. The tally starts at one, so an empty sentence (or one made
// only of spaces) is reported as a single word.
int Scorer::word_count(std::string sentence) {
    strip(sentence);
    int words = 1;
    // After strip() the first character is never a space, so the scan can
    // begin at index 1; each space preceded by a non-space ends one word.
    for (std::string::size_type pos = 1; pos < sentence.size(); ++pos) {
        const bool closes_word =
            sentence[pos] == ' ' && sentence[pos - 1] != ' ';
        if (closes_word) {
            ++words;
        }
    }
    return words;
}
double Scorer::get_log_cond_prob(std::string sentence) {
lm::base::Model *model = (lm::base::Model *)this->_language_model;
State state, out_state;
lm::FullScoreReturn ret;
model->BeginSentenceWrite(&state);
for (util::TokenIter<util::SingleCharacter, true> it(sentence, ' '); it; ++it){
lm::WordIndex wid = model->BaseVocabulary().Index(*it);
ret = model->BaseFullScore(&state, wid, &out_state);
state = out_state;
}
//log10 prob
double log_prob = ret.prob;
return log_prob;
}
void Scorer::reset_params(float alpha, float beta) { void Scorer::reset_params(float alpha, float beta) {
this->alpha = alpha; this->alpha = alpha;
this->beta = beta; this->beta = beta;
} }
// Combine the language-model score of a sentence with its word count,
// weighted by the scorer's alpha and beta.
// Parameters:
//     sentence: the sentence to score
//     log: when true, return the score in natural-log space; otherwise
//          return it in linear space
// Return:
//     log == true:  alpha * lm_score * ln(10) + beta * ln(word_cnt)
//     log == false: 10^(alpha * lm_score) * word_cnt^beta
double Scorer::get_score(std::string sentence, bool log) {
    const double lm_score = get_log_cond_prob(sentence);
    const int word_cnt = word_count(sentence);
    if (log) {
        // lm_score is a log10 value; multiplying by ln(10) converts it to
        // a natural logarithm before the word-count term is added.
        return alpha * lm_score * std::log(10)
               + beta * std::log(word_cnt);
    }
    return pow(10, alpha * lm_score) * pow(word_cnt, beta);
}
std::string Scorer::vec2str(const std::vector<int>& input) { std::string Scorer::vec2str(const std::vector<int>& input) {
std::string word; std::string word;
for (auto ind : input) { for (auto ind : input) {
...@@ -188,7 +112,6 @@ std::string Scorer::vec2str(const std::vector<int>& input) { ...@@ -188,7 +112,6 @@ std::string Scorer::vec2str(const std::vector<int>& input) {
return word; return word;
} }
std::vector<std::string> std::vector<std::string>
Scorer::split_labels(const std::vector<int> &labels) { Scorer::split_labels(const std::vector<int> &labels) {
if (labels.empty()) if (labels.empty())
...@@ -291,6 +214,6 @@ void Scorer::fill_dictionary(bool add_space) { ...@@ -291,6 +214,6 @@ void Scorer::fill_dictionary(bool add_space) {
// Finds the simplest equivalent fst. This is unnecessary but decreases // Finds the simplest equivalent fst. This is unnecessary but decreases
// memory usage of the dictionary // memory usage of the dictionary
fst::Minimize(new_dict); fst::Minimize(new_dict);
_dictionary = new_dict; this->dictionary = new_dict;
} }
...@@ -42,15 +42,8 @@ public: ...@@ -42,15 +42,8 @@ public:
double get_sent_log_prob(const std::vector<std::string>& words); double get_sent_log_prob(const std::vector<std::string>& words);
size_t get_max_order() { return _max_order; } size_t get_max_order() { return _max_order; }
bool is_character_based() { return _is_character_based; } bool is_character_based() { return _is_character_based; }
std::vector<std::string> get_vocab() { return _vocabulary; }
// word insertion term
int word_count(std::string);
// get the log cond prob of the last word
double get_log_cond_prob(std::string);
// reset params alpha & beta // reset params alpha & beta
void reset_params(float alpha, float beta); void reset_params(float alpha, float beta);
// get the final score
double get_score(std::string, bool log=false);
// make ngram // make ngram
std::vector<std::string> make_ngram(PathTrie* prefix); std::vector<std::string> make_ngram(PathTrie* prefix);
// fill dictionary for fst // fill dictionary for fst
...@@ -61,7 +54,7 @@ public: ...@@ -61,7 +54,7 @@ public:
double alpha; double alpha;
double beta; double beta;
// fst dictionary // fst dictionary
void* _dictionary; void* dictionary;
protected: protected:
void load_LM(const char* filename); void load_LM(const char* filename);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册