提交 eef364d1 编写于 作者: Y Yibing Liu

adapt to the last three commits

上级 8dc0b2b0
...@@ -14,7 +14,7 @@ wget http://www.openfst.org/twiki/pub/FST/FstDownload/openfst-1.6.3.tar.gz ...@@ -14,7 +14,7 @@ wget http://www.openfst.org/twiki/pub/FST/FstDownload/openfst-1.6.3.tar.gz
tar -xzvf openfst-1.6.3.tar.gz tar -xzvf openfst-1.6.3.tar.gz
``` ```
- [**swig**]: Compiling for python interface requires swig, please make sure swig being installed. - [**SWIG**](http://www.swig.org): Compiling for python interface requires swig, please make sure swig is installed.
- [**ThreadPool**](http://progsch.net/wordpress/): A library for C++ thread pool - [**ThreadPool**](http://progsch.net/wordpress/): A library for C++ thread pool
......
...@@ -3,9 +3,13 @@ ...@@ -3,9 +3,13 @@
#include "lm/config.hh" #include "lm/config.hh"
#include "lm/state.hh" #include "lm/state.hh"
#include "lm/model.hh" #include "lm/model.hh"
#include "util/tokenize_piece.hh"
#include "util/string_piece.hh"
#include "scorer.h" #include "scorer.h"
#include "decoder_utils.h" #include "decoder_utils.h"
using namespace lm::ngram;
Scorer::Scorer(double alpha, double beta, const std::string& lm_path) { Scorer::Scorer(double alpha, double beta, const std::string& lm_path) {
this->alpha = alpha; this->alpha = alpha;
this->beta = beta; this->beta = beta;
...@@ -90,3 +94,84 @@ double Scorer::get_log_prob(const std::vector<std::string>& words) { ...@@ -90,3 +94,84 @@ double Scorer::get_log_prob(const std::vector<std::string>& words) {
} }
return score; return score;
} }
/* Strip leading and trailing occurrences of a character from a string,
 * in place.
 * Parameters:
 *     str: A reference to the string to trim; it is modified in place
 *     ch: The character to remove from both ends (a space by default)
 * Return:
 *     void
 */
inline void strip(std::string &str, char ch=' ') {
    // npos means the string is empty or made up solely of `ch`; either way
    // nothing survives trimming.
    const std::string::size_type first = str.find_first_not_of(ch);
    if (first == std::string::npos) {
        str.clear();
        return;
    }
    // A non-`ch` character exists, so this lookup cannot return npos.
    const std::string::size_type last = str.find_last_not_of(ch);
    // Nothing to trim: skip the substring copy.
    if (first == 0 && last == str.size() - 1) return;
    str = str.substr(first, last - first + 1);
}
/* Count the words in a sentence (the word insertion term).
 * Parameters:
 *     sentence: The sentence to count; taken by value because it is
 *               trimmed locally with strip()
 * Return:
 *     One plus the number of spaces that directly follow a non-space
 *     character, so runs of consecutive spaces count as one separator.
 *     An empty (or all-space) sentence therefore yields 1, not 0.
 */
int Scorer::word_count(std::string sentence) {
    strip(sentence);
    int cnt = 1;
    // After strip() the first character is never a space, so index 0 can
    // never be a separator; starting at 1 also keeps sentence[i-1] in range.
    for (std::string::size_type i = 1; i < sentence.size(); ++i) {
        if (sentence[i] == ' ' && sentence[i-1] != ' ') {
            ++cnt;
        }
    }
    return cnt;
}
double Scorer::get_log_cond_prob(std::string sentence) {
lm::base::Model *model = (lm::base::Model *)this->_language_model;
State state, out_state;
lm::FullScoreReturn ret;
model->BeginSentenceWrite(&state);
for (util::TokenIter<util::SingleCharacter, true> it(sentence, ' '); it; ++it){
lm::WordIndex wid = model->BaseVocabulary().Index(*it);
ret = model->BaseFullScore(&state, wid, &out_state);
state = out_state;
}
//log10 prob
double log_prob = ret.prob;
return log_prob;
}
/* Overwrite the scorer's two weights.
 * Parameters:
 *     alpha: New value stored in this->alpha
 *     beta: New value stored in this->beta
 * Return:
 *     void
 */
void Scorer::reset_params(float alpha, float beta) {
    // The parameters shadow the members of the same name, hence `this->`.
    this->beta = beta;
    this->alpha = alpha;
}
/* Combine the language-model score and the word count into a final score.
 * Parameters:
 *     sentence: The sentence to score
 *     log: When true the score is returned in the natural-log domain,
 *          otherwise in the linear domain
 * Return:
 *     log == false: 10^(alpha * lm_score) * word_cnt^beta
 *     log == true:  alpha * lm_score * ln(10) + beta * ln(word_cnt)
 *     where lm_score comes from get_log_cond_prob() and word_cnt from
 *     word_count().
 */
double Scorer::get_score(std::string sentence, bool log) {
    const double lm_score = get_log_cond_prob(sentence);
    const int word_cnt = word_count(sentence);
    if (log) {
        // lm_score is a log10 value; multiplying by ln(10) converts it to a
        // natural logarithm so both terms share the same base.
        return alpha * lm_score * std::log(10)
               + beta * std::log(word_cnt);
    }
    return pow(10, alpha * lm_score) * pow(word_cnt, beta);
}
...@@ -30,6 +30,7 @@ public: ...@@ -30,6 +30,7 @@ public:
// Example: // Example:
// Scorer scorer(alpha, beta, "path_of_language_model"); // Scorer scorer(alpha, beta, "path_of_language_model");
// scorer.get_log_cond_prob({ "WORD1", "WORD2", "WORD3" }); // scorer.get_log_cond_prob({ "WORD1", "WORD2", "WORD3" });
// scorer.get_log_cond_prob("this a sentence");
// scorer.get_sent_log_prob({ "WORD1", "WORD2", "WORD3" }); // scorer.get_sent_log_prob({ "WORD1", "WORD2", "WORD3" });
class Scorer{ class Scorer{
public: public:
...@@ -40,7 +41,14 @@ public: ...@@ -40,7 +41,14 @@ public:
size_t get_max_order() { return _max_order; } size_t get_max_order() { return _max_order; }
bool is_character_based() { return _is_character_based; } bool is_character_based() { return _is_character_based; }
std::vector<std::string> get_vocab() { return _vocabulary; } std::vector<std::string> get_vocab() { return _vocabulary; }
// word insertion term
int word_count(std::string);
// get the log cond prob of the last word
double get_log_cond_prob(std::string);
// reset params alpha & beta
void reset_params(float alpha, float beta);
// get the final score
double get_score(std::string, bool log=false);
// expose to decoder // expose to decoder
double alpha; double alpha;
double beta; double beta;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册