提交 3ee02039 编写于 作者: Y yangyaming

Refactor scorer and move utility functions to decoder_util.h

上级 d1189a79
......@@ -7,6 +7,8 @@ wget http://www.openfst.org/twiki/pub/FST/FstDownload/openfst-1.6.3.tar.gz
tar -xzvf openfst-1.6.3.tar.gz
```
Compiling for python interface requires swig, please make sure swig being installed.
Then run the setup
```shell
......
......@@ -9,29 +9,6 @@
typedef double log_prob_type;
template <typename T1, typename T2>
bool pair_comp_first_rev(const std::pair<T1, T2> a, const std::pair<T1, T2> b)
{
return a.first > b.first;
}
template <typename T1, typename T2>
bool pair_comp_second_rev(const std::pair<T1, T2> a, const std::pair<T1, T2> b)
{
return a.second > b.second;
}
template <typename T>
T log_sum_exp(T x, T y)
{
static T num_min = -std::numeric_limits<T>::max();
if (x <= num_min) return y;
if (y <= num_min) return x;
T xmax = std::max(x, y);
return std::log(std::exp(x-xmax) + std::exp(y-xmax)) + xmax;
}
std::string ctc_best_path_decoder(std::vector<std::vector<double> > probs_seq,
std::vector<std::string> vocabulary) {
// dimension check
......
......@@ -3,3 +3,10 @@
#include <cmath>
#include "decoder_utils.h"
size_t get_utf8_str_len(const std::string& str) {
size_t str_len = 0;
for (char c : str) {
str_len += ((c & 0xc0) != 0x80);
}
return str_len;
}
#ifndef DECODER_UTILS_H
#define DECODER_UTILS_H
#pragma once
#ifndef DECODER_UTILS_H_
#define DECODER_UTILS_H_
#include <utility>
/*
template <typename T1, typename T2>
bool pair_comp_first_rev(const std::pair<T1, T2> a, const std::pair<T1, T2> b);
bool pair_comp_first_rev(const std::pair<T1, T2> &a, const std::pair<T1, T2> &b)
{
return a.first > b.first;
}
template <typename T1, typename T2>
bool pair_comp_second_rev(const std::pair<T1, T2> a, const std::pair<T1, T2> b);
bool pair_comp_second_rev(const std::pair<T1, T2> &a, const std::pair<T1, T2> &b)
{
return a.second > b.second;
}
template <typename T>
T log_sum_exp(const T &x, const T &y)
{
static T num_min = -std::numeric_limits<T>::max();
if (x <= num_min) return y;
if (y <= num_min) return x;
T xmax = std::max(x, y);
return std::log(std::exp(x-xmax) + std::exp(y-xmax)) + xmax;
}
// Get length of utf8 encoding string
// See: http://stackoverflow.com/a/4063229
size_t get_utf8_str_len(const std::string& str);
template <typename T> T log_sum_exp(T x, T y);
*/
#endif // DECODER_UTILS_H
......@@ -2,13 +2,15 @@
%{
#include "scorer.h"
#include "ctc_decoders.h"
#include "decoder_utils.h"
%}
%include "std_vector.i"
%include "std_pair.i"
%include "std_string.i"
%import "decoder_utils.h"
namespace std{
namespace std {
%template(DoubleVector) std::vector<double>;
%template(IntVector) std::vector<int>;
%template(StringVector) std::vector<std::string>;
......@@ -19,6 +21,9 @@ namespace std{
%template(PairDoubleStringVector) std::vector<std::pair<double, std::string> >;
}
%import decoder_utils.h
%template(IntDoublePairCompSecondRev) pair_comp_second_rev<int, double>;
%template(StringDoublePairCompSecondRev) pair_comp_second_rev<std::string, double>;
%template(DoubleStringPairCompFirstRev) pair_comp_first_rev<double, std::string>;
%include "scorer.h"
%include "ctc_decoders.h"
#include <iostream>
#include <unistd.h>
#include "scorer.h"
#include "lm/model.hh"
#include "util/tokenize_piece.hh"
#include "util/string_piece.hh"
#include "decoder_utils.h"
using namespace lm::ngram;
Scorer::Scorer(float alpha, float beta, std::string lm_model_path) {
this->_alpha = alpha;
this->_beta = beta;
if (access(lm_model_path.c_str(), F_OK) != 0) {
std::cout<<"Invalid language model path!"<<std::endl;
exit(1);
}
this->_language_model = LoadVirtual(lm_model_path.c_str());
Scorer::Scorer(double alpha, double beta, const std::string& lm_path) {
this->alpha = alpha;
this->beta = beta;
_is_character_based = true;
_language_model = nullptr;
_max_order = 0;
// load language model
load_LM(lm_path.c_str());
}
Scorer::~Scorer(){
delete (lm::base::Model *)this->_language_model;
Scorer::~Scorer() {
if (_language_model != nullptr)
delete static_cast<lm::base::Model*>(_language_model);
}
/* Strip a input sentence
* Parameters:
* str: A reference to the objective string
* ch: The character to prune
* Return:
* void
*/
inline void strip(std::string &str, char ch=' ') {
if (str.size() == 0) return;
int start = 0;
int end = str.size()-1;
for (int i=0; i<str.size(); i++){
if (str[i] == ch) {
start ++;
} else {
break;
}
}
for (int i=str.size()-1; i>=0; i--) {
if (str[i] == ch) {
end --;
} else {
break;
void Scorer::load_LM(const char* filename) {
if (access(filename, F_OK) != 0) {
std::cerr << "Invalid language model file !!!" << std::endl;
exit(1);
}
RetriveStrEnumerateVocab enumerate;
Config config;
config.enumerate_vocab = &enumerate;
_language_model = lm::ngram::LoadVirtual(filename, config);
_max_order = static_cast<lm::base::Model*>(_language_model)->Order();
_vocabulary = enumerate.vocabulary;
for (size_t i = 0; i < _vocabulary.size(); ++i) {
if (_is_character_based
&& _vocabulary[i] != UNK_TOKEN
&& _vocabulary[i] != START_TOKEN
&& _vocabulary[i] != END_TOKEN
&& get_utf8_str_len(enumerate.vocabulary[i]) > 1) {
_is_character_based = false;
}
if (start == 0 && end == str.size()-1) return;
if (start > end) {
std::string emp_str;
str = emp_str;
} else {
str = str.substr(start, end-start+1);
}
}
int Scorer::word_count(std::string sentence) {
strip(sentence);
int cnt = 1;
for (int i=0; i<sentence.size(); i++) {
if (sentence[i] == ' ' && sentence[i-1] != ' ') {
cnt ++;
double Scorer::get_log_cond_prob(const std::vector<std::string>& words) {
lm::base::Model* model = static_cast<lm::base::Model*>(_language_model);
double cond_prob;
State state, tmp_state, out_state;
// avoid to inserting <s> in begin
model->NullContextWrite(&state);
for (size_t i = 0; i < words.size(); ++i) {
lm::WordIndex word_index = model->BaseVocabulary().Index(words[i]);
// encounter OOV
if (word_index == 0) {
return OOV_SCOER;
}
}
return cnt;
}
double Scorer::language_model_score(std::string sentence) {
lm::base::Model *model = (lm::base::Model *)this->_language_model;
State state, out_state;
lm::FullScoreReturn ret;
model->BeginSentenceWrite(&state);
for (util::TokenIter<util::SingleCharacter, true> it(sentence, ' '); it; ++it){
lm::WordIndex wid = model->BaseVocabulary().Index(*it);
ret = model->BaseFullScore(&state, wid, &out_state);
cond_prob = model->BaseScore(&state, word_index, &out_state);
tmp_state = state;
state = out_state;
out_state = tmp_state;
}
//log10 prob
double log_prob = ret.prob;
return log_prob;
// log10 prob
return cond_prob;
}
void Scorer::reset_params(float alpha, float beta) {
this->_alpha = alpha;
this->_beta = beta;
double Scorer::get_sent_log_prob(const std::vector<std::string>& words) {
std::vector<std::string> sentence;
if (words.size() == 0) {
for (size_t i = 0; i < _max_order; ++i) {
sentence.push_back(START_TOKEN);
}
} else {
for (size_t i = 0; i < _max_order - 1; ++i) {
sentence.push_back(START_TOKEN);
}
sentence.insert(sentence.end(), words.begin(), words.end());
}
sentence.push_back(END_TOKEN);
return get_log_prob(sentence);
}
double Scorer::get_score(std::string sentence, bool log) {
double lm_score = language_model_score(sentence);
int word_cnt = word_count(sentence);
double final_score = 0.0;
if (log == false) {
final_score = pow(10, _alpha*lm_score) * pow(word_cnt, _beta);
} else {
final_score = _alpha*lm_score*std::log(10) + _beta*std::log(word_cnt);
double Scorer::get_log_prob(const std::vector<std::string>& words) {
assert(words.size() > _max_order);
double score = 0.0;
for (size_t i = 0; i < words.size() - _max_order + 1; ++i) {
std::vector<std::string> ngram(words.begin() + i,
words.begin() + i + _max_order);
score += get_log_cond_prob(ngram);
}
return final_score;
return score;
}
......@@ -2,35 +2,58 @@
#define SCORER_H_
#include <string>
#include <memory>
#include <vector>
#include "lm/enumerate_vocab.hh"
#include "lm/word_index.hh"
#include "lm/virtual_interface.hh"
#include "util/string_piece.hh"
/* External scorer to evaluate a prefix or a complete sentence
* when a new word appended during decoding, consisting of word
* count and language model scoring.
const double OOV_SCOER = -1000.0;
const std::string START_TOKEN = "<s>";
const std::string UNK_TOKEN = "<unk>";
const std::string END_TOKEN = "</s>";
* Example:
* Scorer ext_scorer(alpha, beta, "path_to_language_model.klm");
* double score = ext_scorer.get_score("sentence_to_score");
*/
class Scorer{
private:
float _alpha;
float _beta;
void *_language_model;
// Implement a callback to retrive string vocabulary.
class RetriveStrEnumerateVocab : public lm::EnumerateVocab {
public:
RetriveStrEnumerateVocab() {}
// word insertion term
int word_count(std::string);
// n-gram language model scoring
double language_model_score(std::string);
void Add(lm::WordIndex index, const StringPiece& str) {
vocabulary.push_back(std::string(str.data(), str.length()));
}
std::vector<std::string> vocabulary;
};
// External scorer to query languange score for n-gram or sentence.
// Example:
// Scorer scorer(alpha, beta, "path_of_language_model");
// scorer.get_log_cond_prob({ "WORD1", "WORD2", "WORD3" });
// scorer.get_sent_log_prob({ "WORD1", "WORD2", "WORD3" });
class Scorer{
public:
Scorer(){}
Scorer(float alpha, float beta, std::string lm_model_path);
Scorer(double alpha, double beta, const std::string& lm_path);
~Scorer();
double get_log_cond_prob(const std::vector<std::string>& words);
double get_sent_log_prob(const std::vector<std::string>& words);
size_t get_max_order() { return _max_order; }
bool is_character_based() { return _is_character_based; }
std::vector<std::string> get_vocab() { return _vocabulary; }
// expose to decoder
double alpha;
double beta;
// reset params alpha & beta
void reset_params(float alpha, float beta);
// get the final score
double get_score(std::string, bool log=false);
protected:
void load_LM(const char* filename);
double get_log_prob(const std::vector<std::string>& words);
private:
void* _language_model;
bool _is_character_based;
size_t _max_order;
std::vector<std::string> _vocabulary;
};
#endif //SCORER_H_
#endif // SCORER_H_
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册