提交 5208b8e4 编写于 作者: Y Yibing Liu

format C++ source code

上级 a2ddfe8d
此差异已折叠。
#ifndef CTC_BEAM_SEARCH_DECODER_H_ #ifndef CTC_BEAM_SEARCH_DECODER_H_
#define CTC_BEAM_SEARCH_DECODER_H_ #define CTC_BEAM_SEARCH_DECODER_H_
#include <vector>
#include <string> #include <string>
#include <utility> #include <utility>
#include <vector>
#include "scorer.h" #include "scorer.h"
/* CTC Best Path Decoder /* CTC Best Path Decoder
...@@ -16,8 +16,8 @@ ...@@ -16,8 +16,8 @@
* A vector that each element is a pair of score and decoding result, * A vector that each element is a pair of score and decoding result,
* in desending order. * in desending order.
*/ */
std::string ctc_best_path_decoder(std::vector<std::vector<double> > probs_seq, std::string ctc_best_path_decoder(std::vector<std::vector<double>> probs_seq,
std::vector<std::string> vocabulary); std::vector<std::string> vocabulary);
/* CTC Beam Search Decoder /* CTC Beam Search Decoder
...@@ -34,15 +34,14 @@ std::string ctc_best_path_decoder(std::vector<std::vector<double> > probs_seq, ...@@ -34,15 +34,14 @@ std::string ctc_best_path_decoder(std::vector<std::vector<double> > probs_seq,
* A vector that each element is a pair of score and decoding result, * A vector that each element is a pair of score and decoding result,
* in desending order. * in desending order.
*/ */
std::vector<std::pair<double, std::string> > std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
ctc_beam_search_decoder(std::vector<std::vector<double> > probs_seq, std::vector<std::vector<double>> probs_seq,
int beam_size, int beam_size,
std::vector<std::string> vocabulary, std::vector<std::string> vocabulary,
int blank_id, int blank_id,
double cutoff_prob=1.0, double cutoff_prob = 1.0,
int cutoff_top_n=40, int cutoff_top_n = 40,
Scorer *ext_scorer=NULL Scorer *ext_scorer = NULL);
);
/* CTC Beam Search Decoder for batch data, the interface is consistent with the /* CTC Beam Search Decoder for batch data, the interface is consistent with the
* original decoder in Python version. * original decoder in Python version.
...@@ -63,15 +62,14 @@ std::vector<std::pair<double, std::string> > ...@@ -63,15 +62,14 @@ std::vector<std::pair<double, std::string> >
* sample. * sample.
*/ */
std::vector<std::vector<std::pair<double, std::string>>> std::vector<std::vector<std::pair<double, std::string>>>
ctc_beam_search_decoder_batch(std::vector<std::vector<std::vector<double>>> probs_split, ctc_beam_search_decoder_batch(
int beam_size, std::vector<std::vector<std::vector<double>>> probs_split,
std::vector<std::string> vocabulary, int beam_size,
int blank_id, std::vector<std::string> vocabulary,
int num_processes, int blank_id,
double cutoff_prob=1.0, int num_processes,
int cutoff_top_n=40, double cutoff_prob = 1.0,
Scorer *ext_scorer=NULL int cutoff_top_n = 40,
); Scorer *ext_scorer = NULL);
#endif // CTC_BEAM_SEARCH_DECODER_H_ #endif // CTC_BEAM_SEARCH_DECODER_H_
#include <limits> #include "decoder_utils.h"
#include <algorithm> #include <algorithm>
#include <cmath> #include <cmath>
#include "decoder_utils.h" #include <limits>
size_t get_utf8_str_len(const std::string& str) { size_t get_utf8_str_len(const std::string& str) {
size_t str_len = 0; size_t str_len = 0;
for (char c : str) { for (char c : str) {
str_len += ((c & 0xc0) != 0x80); str_len += ((c & 0xc0) != 0x80);
} }
return str_len; return str_len;
} }
std::vector<std::string> split_utf8_str(const std::string& str) std::vector<std::string> split_utf8_str(const std::string& str) {
{
std::vector<std::string> result; std::vector<std::string> result;
std::string out_str; std::string out_str;
for (char c : str) for (char c : str) {
if ((c & 0xc0) != 0x80) // new UTF-8 character
{ {
if ((c & 0xc0) != 0x80) //new UTF-8 character if (!out_str.empty()) {
{ result.push_back(out_str);
if (!out_str.empty()) out_str.clear();
{ }
result.push_back(out_str);
out_str.clear();
}
}
out_str.append(1, c);
} }
out_str.append(1, c);
}
result.push_back(out_str); result.push_back(out_str);
return result; return result;
} }
std::vector<std::string> split_str(const std::string &s, std::vector<std::string> split_str(const std::string& s,
const std::string &delim) { const std::string& delim) {
std::vector<std::string> result; std::vector<std::string> result;
std::size_t start = 0, delim_len = delim.size(); std::size_t start = 0, delim_len = delim.size();
while (true) { while (true) {
std::size_t end = s.find(delim, start); std::size_t end = s.find(delim, start);
if (end == std::string::npos) { if (end == std::string::npos) {
if (start < s.size()) { if (start < s.size()) {
result.push_back(s.substr(start)); result.push_back(s.substr(start));
} }
break; break;
} }
if (end > start) { if (end > start) {
result.push_back(s.substr(start, end - start)); result.push_back(s.substr(start, end - start));
}
start = end + delim_len;
} }
return result; start = end + delim_len;
}
return result;
} }
bool prefix_compare(const PathTrie* x, const PathTrie* y) { bool prefix_compare(const PathTrie* x, const PathTrie* y) {
if (x->_score == y->_score) { if (x->score == y->score) {
if (x->_character == y->_character) { if (x->character == y->character) {
return false; return false;
} else {
return (x->_character < y->_character);
}
} else { } else {
return x->_score > y->_score; return (x->character < y->character);
} }
} else {
return x->score > y->score;
}
} }
void add_word_to_fst(const std::vector<int>& word, void add_word_to_fst(const std::vector<int>& word,
fst::StdVectorFst* dictionary) { fst::StdVectorFst* dictionary) {
if (dictionary->NumStates() == 0) { if (dictionary->NumStates() == 0) {
fst::StdVectorFst::StateId start = dictionary->AddState(); fst::StdVectorFst::StateId start = dictionary->AddState();
assert(start == 0); assert(start == 0);
dictionary->SetStart(start); dictionary->SetStart(start);
} }
fst::StdVectorFst::StateId src = dictionary->Start(); fst::StdVectorFst::StateId src = dictionary->Start();
fst::StdVectorFst::StateId dst; fst::StdVectorFst::StateId dst;
for (auto c : word) { for (auto c : word) {
dst = dictionary->AddState(); dst = dictionary->AddState();
dictionary->AddArc(src, fst::StdArc(c, c, 0, dst)); dictionary->AddArc(src, fst::StdArc(c, c, 0, dst));
src = dst; src = dst;
} }
dictionary->SetFinal(dst, fst::StdArc::Weight::One()); dictionary->SetFinal(dst, fst::StdArc::Weight::One());
} }
bool add_word_to_dictionary(const std::string& word, bool add_word_to_dictionary(
const std::unordered_map<std::string, int>& char_map, const std::string& word,
bool add_space, const std::unordered_map<std::string, int>& char_map,
int SPACE_ID, bool add_space,
fst::StdVectorFst* dictionary) { int SPACE_ID,
auto characters = split_utf8_str(word); fst::StdVectorFst* dictionary) {
auto characters = split_utf8_str(word);
std::vector<int> int_word; std::vector<int> int_word;
for (auto& c : characters) { for (auto& c : characters) {
if (c == " ") { if (c == " ") {
int_word.push_back(SPACE_ID); int_word.push_back(SPACE_ID);
} else { } else {
auto int_c = char_map.find(c); auto int_c = char_map.find(c);
if (int_c != char_map.end()) { if (int_c != char_map.end()) {
int_word.push_back(int_c->second); int_word.push_back(int_c->second);
} else { } else {
return false; // return without adding return false; // return without adding
} }
}
} }
}
if (add_space) { if (add_space) {
int_word.push_back(SPACE_ID); int_word.push_back(SPACE_ID);
} }
add_word_to_fst(int_word, dictionary); add_word_to_fst(int_word, dictionary);
return true; return true;
} }
...@@ -10,34 +10,31 @@ const float NUM_FLT_MIN = std::numeric_limits<float>::min(); ...@@ -10,34 +10,31 @@ const float NUM_FLT_MIN = std::numeric_limits<float>::min();
// Function template for comparing two pairs // Function template for comparing two pairs
template <typename T1, typename T2> template <typename T1, typename T2>
bool pair_comp_first_rev(const std::pair<T1, T2> &a, bool pair_comp_first_rev(const std::pair<T1, T2> &a,
const std::pair<T1, T2> &b) const std::pair<T1, T2> &b) {
{ return a.first > b.first;
return a.first > b.first;
} }
template <typename T1, typename T2> template <typename T1, typename T2>
bool pair_comp_second_rev(const std::pair<T1, T2> &a, bool pair_comp_second_rev(const std::pair<T1, T2> &a,
const std::pair<T1, T2> &b) const std::pair<T1, T2> &b) {
{ return a.second > b.second;
return a.second > b.second;
} }
template <typename T> template <typename T>
T log_sum_exp(const T &x, const T &y) T log_sum_exp(const T &x, const T &y) {
{ static T num_min = -std::numeric_limits<T>::max();
static T num_min = -std::numeric_limits<T>::max(); if (x <= num_min) return y;
if (x <= num_min) return y; if (y <= num_min) return x;
if (y <= num_min) return x; T xmax = std::max(x, y);
T xmax = std::max(x, y); return std::log(std::exp(x - xmax) + std::exp(y - xmax)) + xmax;
return std::log(std::exp(x-xmax) + std::exp(y-xmax)) + xmax;
} }
// Functor for prefix comparsion // Functor for prefix comparsion
bool prefix_compare(const PathTrie* x, const PathTrie* y); bool prefix_compare(const PathTrie *x, const PathTrie *y);
// Get length of utf8 encoding string // Get length of utf8 encoding string
// See: http://stackoverflow.com/a/4063229 // See: http://stackoverflow.com/a/4063229
size_t get_utf8_str_len(const std::string& str); size_t get_utf8_str_len(const std::string &str);
// Split a string into a list of strings on a given string // Split a string into a list of strings on a given string
// delimiter. NB: delimiters on beginning / end of string are // delimiter. NB: delimiters on beginning / end of string are
...@@ -50,13 +47,14 @@ std::vector<std::string> split_str(const std::string &s, ...@@ -50,13 +47,14 @@ std::vector<std::string> split_str(const std::string &s,
std::vector<std::string> split_utf8_str(const std::string &str); std::vector<std::string> split_utf8_str(const std::string &str);
// Add a word in index to the dicionary of fst // Add a word in index to the dicionary of fst
void add_word_to_fst(const std::vector<int>& word, void add_word_to_fst(const std::vector<int> &word,
fst::StdVectorFst* dictionary); fst::StdVectorFst *dictionary);
// Add a word in string to dictionary // Add a word in string to dictionary
bool add_word_to_dictionary(const std::string& word, bool add_word_to_dictionary(
const std::unordered_map<std::string, int>& char_map, const std::string &word,
bool add_space, const std::unordered_map<std::string, int> &char_map,
int SPACE_ID, bool add_space,
fst::StdVectorFst* dictionary); int SPACE_ID,
#endif // DECODER_UTILS_H fst::StdVectorFst *dictionary);
#endif // DECODER_UTILS_H
...@@ -4,145 +4,142 @@ ...@@ -4,145 +4,142 @@
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "path_trie.h"
#include "decoder_utils.h" #include "decoder_utils.h"
#include "path_trie.h"
PathTrie::PathTrie() { PathTrie::PathTrie() {
_log_prob_b_prev = -NUM_FLT_INF; log_prob_b_prev = -NUM_FLT_INF;
_log_prob_nb_prev = -NUM_FLT_INF; log_prob_nb_prev = -NUM_FLT_INF;
_log_prob_b_cur = -NUM_FLT_INF; log_prob_b_cur = -NUM_FLT_INF;
_log_prob_nb_cur = -NUM_FLT_INF; log_prob_nb_cur = -NUM_FLT_INF;
_score = -NUM_FLT_INF; score = -NUM_FLT_INF;
_ROOT = -1; _ROOT = -1;
_character = _ROOT; character = _ROOT;
_exists = true; _exists = true;
_parent = nullptr; parent = nullptr;
_dictionary = nullptr; _dictionary = nullptr;
_dictionary_state = 0; _dictionary_state = 0;
_has_dictionary = false; _has_dictionary = false;
_matcher = nullptr; // finds arcs in FST _matcher = nullptr; // finds arcs in FST
} }
PathTrie::~PathTrie() { PathTrie::~PathTrie() {
for (auto child : _children) { for (auto child : _children) {
delete child.second; delete child.second;
} }
} }
PathTrie* PathTrie::get_path_trie(int new_char, bool reset) { PathTrie* PathTrie::get_path_trie(int new_char, bool reset) {
auto child = _children.begin(); auto child = _children.begin();
for (child = _children.begin(); child != _children.end(); ++child) { for (child = _children.begin(); child != _children.end(); ++child) {
if (child->first == new_char) { if (child->first == new_char) {
break; break;
}
} }
if ( child != _children.end() ) { }
if (!child->second->_exists) { if (child != _children.end()) {
child->second->_exists = true; if (!child->second->_exists) {
child->second->_log_prob_b_prev = -NUM_FLT_INF; child->second->_exists = true;
child->second->_log_prob_nb_prev = -NUM_FLT_INF; child->second->log_prob_b_prev = -NUM_FLT_INF;
child->second->_log_prob_b_cur = -NUM_FLT_INF; child->second->log_prob_nb_prev = -NUM_FLT_INF;
child->second->_log_prob_nb_cur = -NUM_FLT_INF; child->second->log_prob_b_cur = -NUM_FLT_INF;
child->second->log_prob_nb_cur = -NUM_FLT_INF;
}
return (child->second);
} else {
if (_has_dictionary) {
_matcher->SetState(_dictionary_state);
bool found = _matcher->Find(new_char);
if (!found) {
// Adding this character causes word outside dictionary
auto FSTZERO = fst::TropicalWeight::Zero();
auto final_weight = _dictionary->Final(_dictionary_state);
bool is_final = (final_weight != FSTZERO);
if (is_final && reset) {
_dictionary_state = _dictionary->Start();
} }
return (child->second); return nullptr;
} else {
PathTrie* new_path = new PathTrie;
new_path->character = new_char;
new_path->parent = this;
new_path->_dictionary = _dictionary;
new_path->_dictionary_state = _matcher->Value().nextstate;
new_path->_has_dictionary = true;
new_path->_matcher = _matcher;
_children.push_back(std::make_pair(new_char, new_path));
return new_path;
}
} else { } else {
if (_has_dictionary) { PathTrie* new_path = new PathTrie;
_matcher->SetState(_dictionary_state); new_path->character = new_char;
bool found = _matcher->Find(new_char); new_path->parent = this;
if (!found) { _children.push_back(std::make_pair(new_char, new_path));
// Adding this character causes word outside dictionary return new_path;
auto FSTZERO = fst::TropicalWeight::Zero();
auto final_weight = _dictionary->Final(_dictionary_state);
bool is_final = (final_weight != FSTZERO);
if (is_final && reset) {
_dictionary_state = _dictionary->Start();
}
return nullptr;
} else {
PathTrie* new_path = new PathTrie;
new_path->_character = new_char;
new_path->_parent = this;
new_path->_dictionary = _dictionary;
new_path->_dictionary_state = _matcher->Value().nextstate;
new_path->_has_dictionary = true;
new_path->_matcher = _matcher;
_children.push_back(std::make_pair(new_char, new_path));
return new_path;
}
} else {
PathTrie* new_path = new PathTrie;
new_path->_character = new_char;
new_path->_parent = this;
_children.push_back(std::make_pair(new_char, new_path));
return new_path;
}
} }
}
} }
PathTrie* PathTrie::get_path_vec(std::vector<int>& output) { PathTrie* PathTrie::get_path_vec(std::vector<int>& output) {
return get_path_vec(output, _ROOT); return get_path_vec(output, _ROOT);
} }
PathTrie* PathTrie::get_path_vec(std::vector<int>& output, PathTrie* PathTrie::get_path_vec(std::vector<int>& output,
int stop, int stop,
size_t max_steps) { size_t max_steps) {
if (_character == stop || if (character == stop || character == _ROOT || output.size() == max_steps) {
_character == _ROOT || std::reverse(output.begin(), output.end());
output.size() == max_steps) { return this;
std::reverse(output.begin(), output.end()); } else {
return this; output.push_back(character);
} else { return parent->get_path_vec(output, stop, max_steps);
output.push_back(_character); }
return _parent->get_path_vec(output, stop, max_steps);
}
} }
void PathTrie::iterate_to_vec( void PathTrie::iterate_to_vec(std::vector<PathTrie*>& output) {
std::vector<PathTrie*>& output) { if (_exists) {
if (_exists) { log_prob_b_prev = log_prob_b_cur;
_log_prob_b_prev = _log_prob_b_cur; log_prob_nb_prev = log_prob_nb_cur;
_log_prob_nb_prev = _log_prob_nb_cur;
_log_prob_b_cur = -NUM_FLT_INF; log_prob_b_cur = -NUM_FLT_INF;
_log_prob_nb_cur = -NUM_FLT_INF; log_prob_nb_cur = -NUM_FLT_INF;
_score = log_sum_exp(_log_prob_b_prev, _log_prob_nb_prev); score = log_sum_exp(log_prob_b_prev, log_prob_nb_prev);
output.push_back(this); output.push_back(this);
} }
for (auto child : _children) { for (auto child : _children) {
child.second->iterate_to_vec(output); child.second->iterate_to_vec(output);
} }
} }
void PathTrie::remove() { void PathTrie::remove() {
_exists = false; _exists = false;
if (_children.size() == 0) { if (_children.size() == 0) {
auto child = _parent->_children.begin(); auto child = parent->_children.begin();
for (child = _parent->_children.begin(); for (child = parent->_children.begin(); child != parent->_children.end();
child != _parent->_children.end(); ++child) { ++child) {
if (child->first == _character) { if (child->first == character) {
_parent->_children.erase(child); parent->_children.erase(child);
break; break;
} }
} }
if ( _parent->_children.size() == 0 && !_parent->_exists ) {
_parent->remove();
}
delete this; if (parent->_children.size() == 0 && !parent->_exists) {
parent->remove();
} }
delete this;
}
} }
void PathTrie::set_dictionary(fst::StdVectorFst* dictionary) { void PathTrie::set_dictionary(fst::StdVectorFst* dictionary) {
_dictionary = dictionary; _dictionary = dictionary;
_dictionary_state = dictionary->Start(); _dictionary_state = dictionary->Start();
_has_dictionary = true; _has_dictionary = true;
} }
using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>; using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;
void PathTrie::set_matcher(std::shared_ptr<FSTMATCH> matcher) { void PathTrie::set_matcher(std::shared_ptr<FSTMATCH> matcher) {
_matcher = matcher; _matcher = matcher;
} }
#ifndef PATH_TRIE_H #ifndef PATH_TRIE_H
#define PATH_TRIE_H #define PATH_TRIE_H
#pragma once #pragma once
#include <fst/fstlib.h>
#include <algorithm> #include <algorithm>
#include <limits> #include <limits>
#include <memory> #include <memory>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include <fst/fstlib.h>
using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>; using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;
class PathTrie { class PathTrie {
public: public:
PathTrie(); PathTrie();
~PathTrie(); ~PathTrie();
PathTrie* get_path_trie(int new_char, bool reset = true);
PathTrie* get_path_vec(std::vector<int> &output); PathTrie* get_path_trie(int new_char, bool reset = true);
PathTrie* get_path_vec(std::vector<int>& output, PathTrie* get_path_vec(std::vector<int>& output);
int stop,
size_t max_steps = std::numeric_limits<size_t>::max());
void iterate_to_vec(std::vector<PathTrie*> &output); PathTrie* get_path_vec(std::vector<int>& output,
int stop,
size_t max_steps = std::numeric_limits<size_t>::max());
void set_dictionary(fst::StdVectorFst* dictionary); void iterate_to_vec(std::vector<PathTrie*>& output);
void set_matcher(std::shared_ptr<FSTMATCH> matcher); void set_dictionary(fst::StdVectorFst* dictionary);
bool is_empty() { void set_matcher(std::shared_ptr<FSTMATCH> matcher);
return _ROOT == _character;
}
void remove(); bool is_empty() { return _ROOT == character; }
float _log_prob_b_prev; void remove();
float _log_prob_nb_prev;
float _log_prob_b_cur;
float _log_prob_nb_cur;
float _score;
float _approx_ctc;
float log_prob_b_prev;
float log_prob_nb_prev;
float log_prob_b_cur;
float log_prob_nb_cur;
float score;
float approx_ctc;
int character;
PathTrie* parent;
int _ROOT; private:
int _character; int _ROOT;
bool _exists; bool _exists;
PathTrie *_parent; std::vector<std::pair<int, PathTrie*>> _children;
std::vector<std::pair<int, PathTrie*> > _children;
fst::StdVectorFst* _dictionary; fst::StdVectorFst* _dictionary;
fst::StdVectorFst::StateId _dictionary_state; fst::StdVectorFst::StateId _dictionary_state;
bool _has_dictionary; bool _has_dictionary;
std::shared_ptr<FSTMATCH> _matcher; std::shared_ptr<FSTMATCH> _matcher;
}; };
#endif // PATH_TRIE_H #endif // PATH_TRIE_H
#include <iostream> #include "scorer.h"
#include <unistd.h> #include <unistd.h>
#include <iostream>
#include "decoder_utils.h"
#include "lm/config.hh" #include "lm/config.hh"
#include "lm/state.hh"
#include "lm/model.hh" #include "lm/model.hh"
#include "util/tokenize_piece.hh" #include "lm/state.hh"
#include "util/string_piece.hh" #include "util/string_piece.hh"
#include "scorer.h" #include "util/tokenize_piece.hh"
#include "decoder_utils.h"
using namespace lm::ngram; using namespace lm::ngram;
Scorer::Scorer(double alpha, double beta, const std::string& lm_path) { Scorer::Scorer(double alpha, double beta, const std::string& lm_path) {
this->alpha = alpha; this->alpha = alpha;
this->beta = beta; this->beta = beta;
_is_character_based = true; _is_character_based = true;
_language_model = nullptr; _language_model = nullptr;
dictionary = nullptr; dictionary = nullptr;
_max_order = 0; _max_order = 0;
_SPACE_ID = -1; _SPACE_ID = -1;
// load language model // load language model
load_LM(lm_path.c_str()); load_LM(lm_path.c_str());
} }
Scorer::~Scorer() { Scorer::~Scorer() {
if (_language_model != nullptr) if (_language_model != nullptr)
delete static_cast<lm::base::Model*>(_language_model); delete static_cast<lm::base::Model*>(_language_model);
if (dictionary != nullptr) if (dictionary != nullptr) delete static_cast<fst::StdVectorFst*>(dictionary);
delete static_cast<fst::StdVectorFst*>(dictionary);
} }
void Scorer::load_LM(const char* filename) { void Scorer::load_LM(const char* filename) {
if (access(filename, F_OK) != 0) { if (access(filename, F_OK) != 0) {
std::cerr << "Invalid language model file !!!" << std::endl; std::cerr << "Invalid language model file !!!" << std::endl;
exit(1); exit(1);
} }
RetriveStrEnumerateVocab enumerate; RetriveStrEnumerateVocab enumerate;
lm::ngram::Config config; lm::ngram::Config config;
config.enumerate_vocab = &enumerate; config.enumerate_vocab = &enumerate;
_language_model = lm::ngram::LoadVirtual(filename, config); _language_model = lm::ngram::LoadVirtual(filename, config);
_max_order = static_cast<lm::base::Model*>(_language_model)->Order(); _max_order = static_cast<lm::base::Model*>(_language_model)->Order();
_vocabulary = enumerate.vocabulary; _vocabulary = enumerate.vocabulary;
for (size_t i = 0; i < _vocabulary.size(); ++i) { for (size_t i = 0; i < _vocabulary.size(); ++i) {
if (_is_character_based if (_is_character_based && _vocabulary[i] != UNK_TOKEN &&
&& _vocabulary[i] != UNK_TOKEN _vocabulary[i] != START_TOKEN && _vocabulary[i] != END_TOKEN &&
&& _vocabulary[i] != START_TOKEN get_utf8_str_len(enumerate.vocabulary[i]) > 1) {
&& _vocabulary[i] != END_TOKEN _is_character_based = false;
&& get_utf8_str_len(enumerate.vocabulary[i]) > 1) {
_is_character_based = false;
}
} }
}
} }
double Scorer::get_log_cond_prob(const std::vector<std::string>& words) { double Scorer::get_log_cond_prob(const std::vector<std::string>& words) {
lm::base::Model* model = static_cast<lm::base::Model*>(_language_model); lm::base::Model* model = static_cast<lm::base::Model*>(_language_model);
double cond_prob; double cond_prob;
lm::ngram::State state, tmp_state, out_state; lm::ngram::State state, tmp_state, out_state;
// avoid to inserting <s> in begin // avoid to inserting <s> in begin
model->NullContextWrite(&state); model->NullContextWrite(&state);
for (size_t i = 0; i < words.size(); ++i) { for (size_t i = 0; i < words.size(); ++i) {
lm::WordIndex word_index = model->BaseVocabulary().Index(words[i]); lm::WordIndex word_index = model->BaseVocabulary().Index(words[i]);
// encounter OOV // encounter OOV
if (word_index == 0) { if (word_index == 0) {
return OOV_SCORE; return OOV_SCORE;
}
cond_prob = model->BaseScore(&state, word_index, &out_state);
tmp_state = state;
state = out_state;
out_state = tmp_state;
} }
// log10 prob cond_prob = model->BaseScore(&state, word_index, &out_state);
return cond_prob; tmp_state = state;
state = out_state;
out_state = tmp_state;
}
// log10 prob
return cond_prob;
} }
double Scorer::get_sent_log_prob(const std::vector<std::string>& words) { double Scorer::get_sent_log_prob(const std::vector<std::string>& words) {
std::vector<std::string> sentence; std::vector<std::string> sentence;
if (words.size() == 0) { if (words.size() == 0) {
for (size_t i = 0; i < _max_order; ++i) { for (size_t i = 0; i < _max_order; ++i) {
sentence.push_back(START_TOKEN); sentence.push_back(START_TOKEN);
}
} else {
for (size_t i = 0; i < _max_order - 1; ++i) {
sentence.push_back(START_TOKEN);
}
sentence.insert(sentence.end(), words.begin(), words.end());
} }
sentence.push_back(END_TOKEN); } else {
return get_log_prob(sentence); for (size_t i = 0; i < _max_order - 1; ++i) {
sentence.push_back(START_TOKEN);
}
sentence.insert(sentence.end(), words.begin(), words.end());
}
sentence.push_back(END_TOKEN);
return get_log_prob(sentence);
} }
double Scorer::get_log_prob(const std::vector<std::string>& words) { double Scorer::get_log_prob(const std::vector<std::string>& words) {
assert(words.size() > _max_order); assert(words.size() > _max_order);
double score = 0.0; double score = 0.0;
for (size_t i = 0; i < words.size() - _max_order + 1; ++i) { for (size_t i = 0; i < words.size() - _max_order + 1; ++i) {
std::vector<std::string> ngram(words.begin() + i, std::vector<std::string> ngram(words.begin() + i,
words.begin() + i + _max_order); words.begin() + i + _max_order);
score += get_log_cond_prob(ngram); score += get_log_cond_prob(ngram);
} }
return score; return score;
} }
void Scorer::reset_params(float alpha, float beta) { void Scorer::reset_params(float alpha, float beta) {
this->alpha = alpha; this->alpha = alpha;
this->beta = beta; this->beta = beta;
} }
std::string Scorer::vec2str(const std::vector<int>& input) { std::string Scorer::vec2str(const std::vector<int>& input) {
std::string word; std::string word;
for (auto ind : input) { for (auto ind : input) {
word += _char_list[ind]; word += _char_list[ind];
} }
return word; return word;
} }
std::vector<std::string> std::vector<std::string> Scorer::split_labels(const std::vector<int>& labels) {
Scorer::split_labels(const std::vector<int> &labels) { if (labels.empty()) return {};
if (labels.empty())
return {}; std::string s = vec2str(labels);
std::vector<std::string> words;
std::string s = vec2str(labels); if (_is_character_based) {
std::vector<std::string> words; words = split_utf8_str(s);
if (_is_character_based) { } else {
words = split_utf8_str(s); words = split_str(s, " ");
} else { }
words = split_str(s, " "); return words;
}
return words;
} }
void Scorer::set_char_map(std::vector<std::string> char_list) { void Scorer::set_char_map(std::vector<std::string> char_list) {
_char_list = char_list; _char_list = char_list;
_char_map.clear(); _char_map.clear();
for(unsigned int i = 0; i < _char_list.size(); i++) for (unsigned int i = 0; i < _char_list.size(); i++) {
{ if (_char_list[i] == " ") {
if (_char_list[i] == " ") { _SPACE_ID = i;
_SPACE_ID = i; _char_map[' '] = i;
_char_map[' '] = i; } else if (_char_list[i].size() == 1) {
} else if(_char_list[i].size() == 1){ _char_map[_char_list[i][0]] = i;
_char_map[_char_list[i][0]] = i;
}
} }
}
} }
std::vector<std::string> Scorer::make_ngram(PathTrie* prefix) { std::vector<std::string> Scorer::make_ngram(PathTrie* prefix) {
std::vector<std::string> ngram; std::vector<std::string> ngram;
PathTrie* current_node = prefix; PathTrie* current_node = prefix;
PathTrie* new_node = nullptr; PathTrie* new_node = nullptr;
for (int order = 0; order < _max_order; order++) {
std::vector<int> prefix_vec;
if (_is_character_based) {
new_node = current_node->get_path_vec(prefix_vec, _SPACE_ID, 1);
current_node = new_node;
} else {
new_node = current_node->get_path_vec(prefix_vec, _SPACE_ID);
current_node = new_node->_parent; // Skipping spaces
}
// reconstruct word
std::string word = vec2str(prefix_vec);
ngram.push_back(word);
if (new_node->_character == -1) {
// No more spaces, but still need order
for (int i = 0; i < _max_order - order - 1; i++) {
ngram.push_back(START_TOKEN);
}
break;
}
}
std::reverse(ngram.begin(), ngram.end());
return ngram;
}
void Scorer::fill_dictionary(bool add_space) {
fst::StdVectorFst dictionary; for (int order = 0; order < _max_order; order++) {
// First reverse char_list so ints can be accessed by chars std::vector<int> prefix_vec;
std::unordered_map<std::string, int> char_map;
for (unsigned int i = 0; i < _char_list.size(); i++) {
char_map[_char_list[i]] = i;
}
// For each unigram convert to ints and put in trie if (_is_character_based) {
int vocab_size = 0; new_node = current_node->get_path_vec(prefix_vec, _SPACE_ID, 1);
for (const auto& word : _vocabulary) { current_node = new_node;
bool added = add_word_to_dictionary(word, } else {
char_map, new_node = current_node->get_path_vec(prefix_vec, _SPACE_ID);
add_space, current_node = new_node->parent; // Skipping spaces
_SPACE_ID,
&dictionary);
vocab_size += added ? 1 : 0;
} }
std::cerr << "Vocab Size " << vocab_size << std::endl; // reconstruct word
std::string word = vec2str(prefix_vec);
// Simplify FST ngram.push_back(word);
// This gets rid of "epsilon" transitions in the FST.
// These are transitions that don't require a string input to be taken.
// Getting rid of them is necessary to make the FST determinisitc, but
// can greatly increase the size of the FST
fst::RmEpsilon(&dictionary);
fst::StdVectorFst* new_dict = new fst::StdVectorFst;
// This makes the FST deterministic, meaning for any string input there's if (new_node->character == -1) {
// only one possible state the FST could be in. It is assumed our // No more spaces, but still need order
// dictionary is deterministic when using it. for (int i = 0; i < _max_order - order - 1; i++) {
// (lest we'd have to check for multiple transitions at each state) ngram.push_back(START_TOKEN);
fst::Determinize(dictionary, new_dict); }
break;
// Finds the simplest equivalent fst. This is unnecessary but decreases }
// memory usage of the dictionary }
fst::Minimize(new_dict); std::reverse(ngram.begin(), ngram.end());
this->dictionary = new_dict; return ngram;
}
void Scorer::fill_dictionary(bool add_space) {
fst::StdVectorFst dictionary;
// First reverse char_list so ints can be accessed by chars
std::unordered_map<std::string, int> char_map;
for (unsigned int i = 0; i < _char_list.size(); i++) {
char_map[_char_list[i]] = i;
}
// For each unigram convert to ints and put in trie
int vocab_size = 0;
for (const auto& word : _vocabulary) {
bool added = add_word_to_dictionary(
word, char_map, add_space, _SPACE_ID, &dictionary);
vocab_size += added ? 1 : 0;
}
std::cerr << "Vocab Size " << vocab_size << std::endl;
// Simplify FST
// This gets rid of "epsilon" transitions in the FST.
// These are transitions that don't require a string input to be taken.
// Getting rid of them is necessary to make the FST determinisitc, but
// can greatly increase the size of the FST
fst::RmEpsilon(&dictionary);
fst::StdVectorFst* new_dict = new fst::StdVectorFst;
// This makes the FST deterministic, meaning for any string input there's
// only one possible state the FST could be in. It is assumed our
// dictionary is deterministic when using it.
// (lest we'd have to check for multiple transitions at each state)
fst::Determinize(dictionary, new_dict);
// Finds the simplest equivalent fst. This is unnecessary but decreases
// memory usage of the dictionary
fst::Minimize(new_dict);
this->dictionary = new_dict;
} }
#ifndef SCORER_H_ #ifndef SCORER_H_
#define SCORER_H_ #define SCORER_H_
#include <string>
#include <memory> #include <memory>
#include <vector> #include <string>
#include <unordered_map> #include <unordered_map>
#include <vector>
#include "lm/enumerate_vocab.hh" #include "lm/enumerate_vocab.hh"
#include "lm/word_index.hh"
#include "lm/virtual_interface.hh" #include "lm/virtual_interface.hh"
#include "util/string_piece.hh" #include "lm/word_index.hh"
#include "path_trie.h" #include "path_trie.h"
#include "util/string_piece.hh"
const double OOV_SCORE = -1000.0; const double OOV_SCORE = -1000.0;
const std::string START_TOKEN = "<s>"; const std::string START_TOKEN = "<s>";
const std::string UNK_TOKEN = "<unk>"; const std::string UNK_TOKEN = "<unk>";
const std::string END_TOKEN = "</s>"; const std::string END_TOKEN = "</s>";
// Implement a callback to retrive string vocabulary. // Implement a callback to retrive string vocabulary.
class RetriveStrEnumerateVocab : public lm::EnumerateVocab { class RetriveStrEnumerateVocab : public lm::EnumerateVocab {
public: public:
RetriveStrEnumerateVocab() {} RetriveStrEnumerateVocab() {}
void Add(lm::WordIndex index, const StringPiece& str) { void Add(lm::WordIndex index, const StringPiece& str) {
vocabulary.push_back(std::string(str.data(), str.length())); vocabulary.push_back(std::string(str.data(), str.length()));
} }
std::vector<std::string> vocabulary; std::vector<std::string> vocabulary;
}; };
// External scorer to query languange score for n-gram or sentence. // External scorer to query languange score for n-gram or sentence.
...@@ -33,59 +33,59 @@ public: ...@@ -33,59 +33,59 @@ public:
// Scorer scorer(alpha, beta, "path_of_language_model"); // Scorer scorer(alpha, beta, "path_of_language_model");
// scorer.get_log_cond_prob({ "WORD1", "WORD2", "WORD3" }); // scorer.get_log_cond_prob({ "WORD1", "WORD2", "WORD3" });
// scorer.get_sent_log_prob({ "WORD1", "WORD2", "WORD3" }); // scorer.get_sent_log_prob({ "WORD1", "WORD2", "WORD3" });
class Scorer{ class Scorer {
public: public:
Scorer(double alpha, double beta, const std::string& lm_path); Scorer(double alpha, double beta, const std::string& lm_path);
~Scorer(); ~Scorer();
double get_log_cond_prob(const std::vector<std::string>& words); double get_log_cond_prob(const std::vector<std::string>& words);
double get_sent_log_prob(const std::vector<std::string>& words); double get_sent_log_prob(const std::vector<std::string>& words);
size_t get_max_order() { return _max_order; } size_t get_max_order() { return _max_order; }
bool is_char_map_empty() {return _char_map.size() == 0; } bool is_char_map_empty() { return _char_map.size() == 0; }
bool is_character_based() { return _is_character_based; } bool is_character_based() { return _is_character_based; }
// reset params alpha & beta // reset params alpha & beta
void reset_params(float alpha, float beta); void reset_params(float alpha, float beta);
// make ngram // make ngram
std::vector<std::string> make_ngram(PathTrie* prefix); std::vector<std::string> make_ngram(PathTrie* prefix);
// fill dictionary for fst // fill dictionary for fst
void fill_dictionary(bool add_space); void fill_dictionary(bool add_space);
// set char map // set char map
void set_char_map(std::vector<std::string> char_list); void set_char_map(std::vector<std::string> char_list);
std::vector<std::string> split_labels(const std::vector<int> &labels); std::vector<std::string> split_labels(const std::vector<int>& labels);
// expose to decoder // expose to decoder
double alpha; double alpha;
double beta; double beta;
// fst dictionary // fst dictionary
void* dictionary; void* dictionary;
protected: protected:
void load_LM(const char* filename); void load_LM(const char* filename);
double get_log_prob(const std::vector<std::string>& words); double get_log_prob(const std::vector<std::string>& words);
std::string vec2str(const std::vector<int> &input); std::string vec2str(const std::vector<int>& input);
private: private:
void* _language_model; void* _language_model;
bool _is_character_based; bool _is_character_based;
size_t _max_order; size_t _max_order;
int _SPACE_ID; int _SPACE_ID;
std::vector<std::string> _char_list; std::vector<std::string> _char_list;
std::unordered_map<char, int> _char_map; std::unordered_map<char, int> _char_map;
std::vector<std::string> _vocabulary; std::vector<std::string> _vocabulary;
}; };
#endif // SCORER_H_ #endif // SCORER_H_
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册