提交 3ee02039 编写于 作者: Y yangyaming

Refactor scorer and move utility functions to decoder_util.h

上级 d1189a79
...@@ -7,6 +7,8 @@ wget http://www.openfst.org/twiki/pub/FST/FstDownload/openfst-1.6.3.tar.gz ...@@ -7,6 +7,8 @@ wget http://www.openfst.org/twiki/pub/FST/FstDownload/openfst-1.6.3.tar.gz
tar -xzvf openfst-1.6.3.tar.gz tar -xzvf openfst-1.6.3.tar.gz
``` ```
Compiling the python interface requires swig; please make sure swig is installed.
Then run the setup Then run the setup
```shell ```shell
......
...@@ -9,29 +9,6 @@ ...@@ -9,29 +9,6 @@
typedef double log_prob_type; typedef double log_prob_type;
// Comparator: orders pairs by their first element in descending order.
// Passing by const reference avoids copying both pair members on every
// comparison (the by-value form copied them each call).
template <typename T1, typename T2>
bool pair_comp_first_rev(const std::pair<T1, T2> &a, const std::pair<T1, T2> &b)
{
    return a.first > b.first;
}
// Comparator: orders pairs by their second element in descending order.
// Passing by const reference avoids copying both pair members on every
// comparison (the by-value form copied them each call).
template <typename T1, typename T2>
bool pair_comp_second_rev(const std::pair<T1, T2> &a, const std::pair<T1, T2> &b)
{
    return a.second > b.second;
}
// Numerically stable log(exp(x) + exp(y)): factor out the larger argument so
// the remaining exponentials are <= 1 and cannot overflow.
template <typename T>
T log_sum_exp(T x, T y)
{
    // Plain const, not `static`: a function-local static carries a per-call
    // thread-safe-initialization guard, which is wasted on a constant.
    const T num_min = -std::numeric_limits<T>::max();
    // Arguments at (or below) the representable minimum act as log(0),
    // i.e. the identity element of log-sum-exp.
    if (x <= num_min) return y;
    if (y <= num_min) return x;
    T xmax = std::max(x, y);
    return std::log(std::exp(x - xmax) + std::exp(y - xmax)) + xmax;
}
std::string ctc_best_path_decoder(std::vector<std::vector<double> > probs_seq, std::string ctc_best_path_decoder(std::vector<std::vector<double> > probs_seq,
std::vector<std::string> vocabulary) { std::vector<std::string> vocabulary) {
// dimension check // dimension check
......
...@@ -3,3 +3,10 @@ ...@@ -3,3 +3,10 @@
#include <cmath> #include <cmath>
#include "decoder_utils.h" #include "decoder_utils.h"
// Return the number of Unicode code points in a UTF-8 encoded string.
// In UTF-8, continuation bytes match the bit pattern 10xxxxxx; every code
// point contributes exactly one byte that does NOT match it, so counting
// the non-continuation bytes counts the code points.
size_t get_utf8_str_len(const std::string& str) {
  size_t len = 0;
  for (size_t i = 0; i < str.size(); ++i) {
    if ((str[i] & 0xc0) != 0x80) {
      ++len;
    }
  }
  return len;
}
#ifndef DECODER_UTILS_H #ifndef DECODER_UTILS_H_
#define DECODER_UTILS_H #define DECODER_UTILS_H_
#pragma once
#include <utility> #include <utility>
/*
template <typename T1, typename T2> template <typename T1, typename T2>
bool pair_comp_first_rev(const std::pair<T1, T2> a, const std::pair<T1, T2> b); bool pair_comp_first_rev(const std::pair<T1, T2> &a, const std::pair<T1, T2> &b)
{
return a.first > b.first;
}
template <typename T1, typename T2> template <typename T1, typename T2>
bool pair_comp_second_rev(const std::pair<T1, T2> a, const std::pair<T1, T2> b); bool pair_comp_second_rev(const std::pair<T1, T2> &a, const std::pair<T1, T2> &b)
{
return a.second > b.second;
}
template <typename T>
T log_sum_exp(const T &x, const T &y)
{
static T num_min = -std::numeric_limits<T>::max();
if (x <= num_min) return y;
if (y <= num_min) return x;
T xmax = std::max(x, y);
return std::log(std::exp(x-xmax) + std::exp(y-xmax)) + xmax;
}
// Get length of utf8 encoding string
// See: http://stackoverflow.com/a/4063229
size_t get_utf8_str_len(const std::string& str);
template <typename T> T log_sum_exp(T x, T y);
*/
#endif // DECODER_UTILS_H #endif // DECODER_UTILS_H
...@@ -2,13 +2,15 @@ ...@@ -2,13 +2,15 @@
%{ %{
#include "scorer.h" #include "scorer.h"
#include "ctc_decoders.h" #include "ctc_decoders.h"
#include "decoder_utils.h"
%} %}
%include "std_vector.i" %include "std_vector.i"
%include "std_pair.i" %include "std_pair.i"
%include "std_string.i" %include "std_string.i"
%import "decoder_utils.h"
namespace std{ namespace std {
%template(DoubleVector) std::vector<double>; %template(DoubleVector) std::vector<double>;
%template(IntVector) std::vector<int>; %template(IntVector) std::vector<int>;
%template(StringVector) std::vector<std::string>; %template(StringVector) std::vector<std::string>;
...@@ -19,6 +21,9 @@ namespace std{ ...@@ -19,6 +21,9 @@ namespace std{
%template(PairDoubleStringVector) std::vector<std::pair<double, std::string> >; %template(PairDoubleStringVector) std::vector<std::pair<double, std::string> >;
} }
%import decoder_utils.h %template(IntDoublePairCompSecondRev) pair_comp_second_rev<int, double>;
%template(StringDoublePairCompSecondRev) pair_comp_second_rev<std::string, double>;
%template(DoubleStringPairCompFirstRev) pair_comp_first_rev<double, std::string>;
%include "scorer.h" %include "scorer.h"
%include "ctc_decoders.h" %include "ctc_decoders.h"
#include <iostream> #include <iostream>
#include <unistd.h> #include <unistd.h>
#include "scorer.h" #include "scorer.h"
#include "lm/model.hh" #include "decoder_utils.h"
#include "util/tokenize_piece.hh"
#include "util/string_piece.hh"
using namespace lm::ngram; Scorer::Scorer(double alpha, double beta, const std::string& lm_path) {
this->alpha = alpha;
Scorer::Scorer(float alpha, float beta, std::string lm_model_path) { this->beta = beta;
this->_alpha = alpha; _is_character_based = true;
this->_beta = beta; _language_model = nullptr;
_max_order = 0;
if (access(lm_model_path.c_str(), F_OK) != 0) { // load language model
std::cout<<"Invalid language model path!"<<std::endl; load_LM(lm_path.c_str());
exit(1);
}
this->_language_model = LoadVirtual(lm_model_path.c_str());
} }
Scorer::~Scorer(){ Scorer::~Scorer() {
delete (lm::base::Model *)this->_language_model; if (_language_model != nullptr)
delete static_cast<lm::base::Model*>(_language_model);
} }
/* Strip a input sentence void Scorer::load_LM(const char* filename) {
* Parameters: if (access(filename, F_OK) != 0) {
* str: A reference to the objective string std::cerr << "Invalid language model file !!!" << std::endl;
* ch: The character to prune exit(1);
* Return:
* void
*/
inline void strip(std::string &str, char ch=' ') {
if (str.size() == 0) return;
int start = 0;
int end = str.size()-1;
for (int i=0; i<str.size(); i++){
if (str[i] == ch) {
start ++;
} else {
break;
}
} }
for (int i=str.size()-1; i>=0; i--) { RetriveStrEnumerateVocab enumerate;
if (str[i] == ch) { Config config;
end --; config.enumerate_vocab = &enumerate;
} else { _language_model = lm::ngram::LoadVirtual(filename, config);
break; _max_order = static_cast<lm::base::Model*>(_language_model)->Order();
_vocabulary = enumerate.vocabulary;
for (size_t i = 0; i < _vocabulary.size(); ++i) {
if (_is_character_based
&& _vocabulary[i] != UNK_TOKEN
&& _vocabulary[i] != START_TOKEN
&& _vocabulary[i] != END_TOKEN
&& get_utf8_str_len(enumerate.vocabulary[i]) > 1) {
_is_character_based = false;
} }
} }
if (start == 0 && end == str.size()-1) return;
if (start > end) {
std::string emp_str;
str = emp_str;
} else {
str = str.substr(start, end-start+1);
}
} }
int Scorer::word_count(std::string sentence) { double Scorer::get_log_cond_prob(const std::vector<std::string>& words) {
strip(sentence); lm::base::Model* model = static_cast<lm::base::Model*>(_language_model);
int cnt = 1; double cond_prob;
for (int i=0; i<sentence.size(); i++) { State state, tmp_state, out_state;
if (sentence[i] == ' ' && sentence[i-1] != ' ') { // avoid to inserting <s> in begin
cnt ++; model->NullContextWrite(&state);
for (size_t i = 0; i < words.size(); ++i) {
lm::WordIndex word_index = model->BaseVocabulary().Index(words[i]);
// encounter OOV
if (word_index == 0) {
return OOV_SCOER;
} }
} cond_prob = model->BaseScore(&state, word_index, &out_state);
return cnt; tmp_state = state;
}
double Scorer::language_model_score(std::string sentence) {
lm::base::Model *model = (lm::base::Model *)this->_language_model;
State state, out_state;
lm::FullScoreReturn ret;
model->BeginSentenceWrite(&state);
for (util::TokenIter<util::SingleCharacter, true> it(sentence, ' '); it; ++it){
lm::WordIndex wid = model->BaseVocabulary().Index(*it);
ret = model->BaseFullScore(&state, wid, &out_state);
state = out_state; state = out_state;
out_state = tmp_state;
} }
//log10 prob // log10 prob
double log_prob = ret.prob; return cond_prob;
return log_prob;
} }
void Scorer::reset_params(float alpha, float beta) { double Scorer::get_sent_log_prob(const std::vector<std::string>& words) {
this->_alpha = alpha; std::vector<std::string> sentence;
this->_beta = beta; if (words.size() == 0) {
for (size_t i = 0; i < _max_order; ++i) {
sentence.push_back(START_TOKEN);
}
} else {
for (size_t i = 0; i < _max_order - 1; ++i) {
sentence.push_back(START_TOKEN);
}
sentence.insert(sentence.end(), words.begin(), words.end());
}
sentence.push_back(END_TOKEN);
return get_log_prob(sentence);
} }
double Scorer::get_score(std::string sentence, bool log) { double Scorer::get_log_prob(const std::vector<std::string>& words) {
double lm_score = language_model_score(sentence); assert(words.size() > _max_order);
int word_cnt = word_count(sentence); double score = 0.0;
for (size_t i = 0; i < words.size() - _max_order + 1; ++i) {
double final_score = 0.0; std::vector<std::string> ngram(words.begin() + i,
if (log == false) { words.begin() + i + _max_order);
final_score = pow(10, _alpha*lm_score) * pow(word_cnt, _beta); score += get_log_cond_prob(ngram);
} else {
final_score = _alpha*lm_score*std::log(10) + _beta*std::log(word_cnt);
} }
return final_score; return score;
} }
...@@ -2,35 +2,58 @@ ...@@ -2,35 +2,58 @@
#define SCORER_H_ #define SCORER_H_
#include <string> #include <string>
#include <memory>
#include <vector>
#include "lm/enumerate_vocab.hh"
#include "lm/word_index.hh"
#include "lm/virtual_interface.hh"
#include "util/string_piece.hh"
/* External scorer to evaluate a prefix or a complete sentence const double OOV_SCOER = -1000.0;
* when a new word appended during decoding, consisting of word const std::string START_TOKEN = "<s>";
* count and language model scoring. const std::string UNK_TOKEN = "<unk>";
const std::string END_TOKEN = "</s>";
* Example: // Implement a callback to retrive string vocabulary.
* Scorer ext_scorer(alpha, beta, "path_to_language_model.klm"); class RetriveStrEnumerateVocab : public lm::EnumerateVocab {
* double score = ext_scorer.get_score("sentence_to_score"); public:
*/ RetriveStrEnumerateVocab() {}
class Scorer{
private:
float _alpha;
float _beta;
void *_language_model;
// word insertion term void Add(lm::WordIndex index, const StringPiece& str) {
int word_count(std::string); vocabulary.push_back(std::string(str.data(), str.length()));
// n-gram language model scoring }
double language_model_score(std::string);
std::vector<std::string> vocabulary;
};
// External scorer to query languange score for n-gram or sentence.
// Example:
// Scorer scorer(alpha, beta, "path_of_language_model");
// scorer.get_log_cond_prob({ "WORD1", "WORD2", "WORD3" });
// scorer.get_sent_log_prob({ "WORD1", "WORD2", "WORD3" });
class Scorer{
public: public:
Scorer(){} Scorer(double alpha, double beta, const std::string& lm_path);
Scorer(float alpha, float beta, std::string lm_model_path);
~Scorer(); ~Scorer();
double get_log_cond_prob(const std::vector<std::string>& words);
double get_sent_log_prob(const std::vector<std::string>& words);
size_t get_max_order() { return _max_order; }
bool is_character_based() { return _is_character_based; }
std::vector<std::string> get_vocab() { return _vocabulary; }
// expose to decoder
double alpha;
double beta;
// reset params alpha & beta protected:
void reset_params(float alpha, float beta); void load_LM(const char* filename);
// get the final score double get_log_prob(const std::vector<std::string>& words);
double get_score(std::string, bool log=false);
private:
void* _language_model;
bool _is_character_based;
size_t _max_order;
std::vector<std::string> _vocabulary;
}; };
#endif //SCORER_H_ #endif // SCORER_H_
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册