提交 955d2932 编写于 作者: Y Yibing Liu

enable finite-state transducer in beam search decoding

上级 d68732b7
......@@ -18,7 +18,7 @@ import time
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--num_samples",
default=5,
default=4,
type=int,
help="Number of samples for inference. (default: %(default)s)")
parser.add_argument(
......@@ -89,7 +89,8 @@ parser.add_argument(
help="Number of output per sample in beam search. (default: %(default)d)")
parser.add_argument(
"--language_model_path",
default="lm/data/common_crawl_00.prune01111.trie.klm",
default="/home/work/liuyibing/lm_bak/common_crawl_00.prune01111.trie.klm",
#default="ptb_all.arpa",
type=str,
help="Path for language model. (default: %(default)s)")
parser.add_argument(
......@@ -183,8 +184,7 @@ def infer():
vocabulary=data_generator.vocab_list,
blank_id=len(data_generator.vocab_list),
cutoff_prob=args.cutoff_prob,
# ext_scoring_func=ext_scorer,
)
ext_scoring_func=ext_scorer, )
batch_beam_results += [beam_result]
else:
batch_beam_results = ctc_beam_search_decoder_batch(
......
......@@ -103,10 +103,13 @@ std::vector<std::pair<double, std::string> >
prefixes.push_back(&root);
if ( ext_scorer != nullptr && !ext_scorer->is_character_based()) {
if (ext_scorer->dictionary == nullptr) {
if (ext_scorer->_dictionary == nullptr) {
// TODO: init dictionary
ext_scorer->set_char_map(vocabulary);
// add_space should be true?
ext_scorer->fill_dictionary(true);
}
auto fst_dict = static_cast<fst::StdVectorFst*>(ext_scorer->dictionary);
auto fst_dict = static_cast<fst::StdVectorFst*>(ext_scorer->_dictionary);
fst::StdVectorFst* dict_ptr = fst_dict->Copy(true);
root.set_dictionary(dict_ptr);
auto matcher = std::make_shared<FSTMATCH>(*dict_ptr, fst::MATCH_INPUT);
......@@ -288,6 +291,14 @@ std::vector<std::vector<std::pair<double, std::string>>>
ThreadPool pool(num_processes);
// number of samples
int batch_size = probs_split.size();
// dictionary init
if ( ext_scorer != nullptr) {
if (ext_scorer->_dictionary == nullptr) {
// TODO: init dictionary
ext_scorer->set_char_map(vocabulary);
ext_scorer->fill_dictionary(true);
}
}
// enqueue the tasks of decoding
std::vector<std::future<std::vector<std::pair<double, std::string>>>> res;
for (int i = 0; i < batch_size; i++) {
......
......@@ -11,6 +11,32 @@ size_t get_utf8_str_len(const std::string& str) {
return str_len;
}
//------------------------------------------------------
// Splits a string into a vector of strings, each holding
// the bytes of one UTF-8 character (a character may span
// several chars). An empty input yields an empty vector.
//------------------------------------------------------
std::vector<std::string> UTF8_split(const std::string& str)
{
    std::vector<std::string> result;
    std::string out_str;
    for (char c : str)
    {
        // A byte that does not match the continuation pattern 10xxxxxx
        // starts a new character, so flush what has been collected so far.
        if ((c & 0xc0) != 0x80 && !out_str.empty())
        {
            result.push_back(out_str);
            out_str.clear();
        }
        out_str.append(1, c);
    }
    // Flush the trailing character; skipped for empty input so that ""
    // does not produce a spurious empty element.
    if (!out_str.empty())
    {
        result.push_back(out_str);
    }
    return result;
}
//-------------------------------------------------------
// Overriding less than operator for sorting
//-------------------------------------------------------
......@@ -49,12 +75,11 @@ void add_word_to_fst(const std::vector<int>& word,
// ---------------------------------------------------------
// Adds a word to the dictionary FST based on char_map
// ---------------------------------------------------------
bool addWordToDictionary(const std::string& word,
bool add_word_to_dictionary(const std::string& word,
const std::unordered_map<std::string, int>& char_map,
bool add_space,
int SPACE,
fst::StdVectorFst* dictionary) {
/*
auto characters = UTF8_split(word);
std::vector<int> int_word;
......@@ -77,6 +102,5 @@ bool addWordToDictionary(const std::string& word,
}
add_word_to_fst(int_word, dictionary);
*/
return true;
} // -------------- End of addWordToDictionary ------------
......@@ -35,10 +35,12 @@ bool prefix_compare(const PathTrie* x, const PathTrie* y);
// See: http://stackoverflow.com/a/4063229
size_t get_utf8_str_len(const std::string& str);
// Splits str into a vector of strings, one per UTF-8 character
// (a character may consist of several chars).
std::vector<std::string> UTF8_split(const std::string &str);
// NOTE(review): definition not shown here; presumably inserts the
// integer-encoded word into the given dictionary FST -- confirm.
void add_word_to_fst(const std::vector<int>& word,
fst::StdVectorFst* dictionary);
bool addWordToDictionary(const std::string& word,
bool add_word_to_dictionary(const std::string& word,
const std::unordered_map<std::string, int>& char_map,
bool add_space,
int SPACE,
......
......@@ -15,7 +15,9 @@ Scorer::Scorer(double alpha, double beta, const std::string& lm_path) {
this->beta = beta;
_is_character_based = true;
_language_model = nullptr;
_dictionary = nullptr;
_max_order = 0;
_SPACE = -1;
// load language model
load_LM(lm_path.c_str());
}
......@@ -23,6 +25,8 @@ Scorer::Scorer(double alpha, double beta, const std::string& lm_path) {
// Destroys the language model and the FST dictionary, if present.
// Each member is cast back to its concrete type before deletion.
Scorer::~Scorer() {
  auto* model = static_cast<lm::base::Model*>(_language_model);
  if (model != nullptr) {
    delete model;
  }
  auto* dict = static_cast<fst::StdVectorFst*>(_dictionary);
  if (dict != nullptr) {
    delete dict;
  }
}
void Scorer::load_LM(const char* filename) {
......@@ -176,11 +180,83 @@ double Scorer::get_score(std::string sentence, bool log) {
return final_score;
}
//--------------------------------------------------
// Concatenates the _char_list entries selected by the
// given indices, in order, into a single string.
//--------------------------------------------------
std::string Scorer::vec2str(const std::vector<int>& input) {
  std::string joined;
  for (auto it = input.cbegin(); it != input.cend(); ++it) {
    joined += _char_list[*it];
  }
  return joined;
}
// Converts label indices to text via vec2str and splits the text:
// per UTF-8 character when the scorer is character based, otherwise
// on single spaces. Empty input gives an empty vector.
std::vector<std::string>
Scorer::split_labels(const std::vector<int> &labels) {
  if (labels.empty()) {
    return std::vector<std::string>();
  }
  const std::string text = vec2str(labels);
  if (!_is_character_based) {
    return split_str(text, " ");
  }
  return UTF8_split(text);
}
// Split a string into a list of strings on a given string
// delimiter. Empty pieces are dropped, so delimiters at the
// beginning / end of the string are trimmed and consecutive
// delimiters collapse. Eg, "FooBarFoo" split on "Foo" returns ["Bar"].
// An empty delimiter cannot split anything: the whole string is
// returned as a single piece (or nothing if the string is empty).
std::vector<std::string> Scorer::split_str(const std::string &s,
const std::string &delim) {
  std::vector<std::string> result;
  // std::string::find with an empty needle always matches at `start`,
  // which would keep `start` from ever advancing. Handle it up front.
  if (delim.empty()) {
    if (!s.empty()) {
      result.push_back(s);
    }
    return result;
  }
  std::size_t start = 0, delim_len = delim.size();
  while (true) {
    std::size_t end = s.find(delim, start);
    if (end == std::string::npos) {
      // No further delimiter: keep the non-empty tail, then stop.
      if (start < s.size()) {
        result.push_back(s.substr(start));
      }
      break;
    }
    // Skip empty pieces produced by leading or adjacent delimiters.
    if (end > start) {
      result.push_back(s.substr(start, end - start));
    }
    start = end + delim_len;
  }
  return result;
}
//---------------------------------------------------
// Add index to char list for searching language model
//---------------------------------------------------
void Scorer::set_char_map(std::vector<std::string> char_list) {
_char_list = char_list;
std::string _SPACE_STR = " ";
for (unsigned int i = 0; i < _char_list.size(); i++) {
// if (_char_list[i] == _BLANK_STR) {
// _BLANK = i;
// } else
if (_char_list[i] == _SPACE_STR) {
_SPACE = i;
}
}
_char_map.clear();
for(unsigned int i = 0; i < _char_list.size(); i++)
{
if(i == (unsigned int)_SPACE){
_char_map[' '] = i;
}
else if(_char_list[i].size() == 1){
_char_map[_char_list[i][0]] = i;
}
}
} //------------- End of set_char_map ----------------
std::vector<std::string> Scorer::make_ngram(PathTrie* prefix) {
/*
std::vector<std::string> ngram;
PathTrie* current_node = prefix;
PathTrie* new_node = nullptr;
......@@ -189,10 +265,10 @@ std::vector<std::string> Scorer::make_ngram(PathTrie* prefix) {
std::vector<int> prefix_vec;
if (_is_character_based) {
new_node = current_node->get_path_vec(prefix_vec, ' ', 1);
new_node = current_node->get_path_vec(prefix_vec, _SPACE, 1);
current_node = new_node;
} else {
new_node = current_node->getPathVec(prefix_vec, ' ');
new_node = current_node->get_path_vec(prefix_vec, _SPACE);
current_node = new_node->_parent; // Skipping spaces
}
......@@ -202,15 +278,60 @@ std::vector<std::string> Scorer::make_ngram(PathTrie* prefix) {
if (new_node->_character == -1) {
// No more spaces, but still need order
for (int i = 0; i < max_order - order - 1; i++) {
for (int i = 0; i < _max_order - order - 1; i++) {
ngram.push_back("<s>");
}
break;
}
}
std::reverse(ngram.begin(), ngram.end());
*/
std::vector<std::string> ngram;
ngram.push_back("this");
return ngram;
} //---------------- End makeNgrams ------------------
}
//---------------------------------------------------------
// Helper function to populate Trie with a vocab using the
// char_list for maping from string to int
//---------------------------------------------------------
void Scorer::fill_dictionary(bool add_space) {
fst::StdVectorFst dictionary;
// First reverse char_list so ints can be accessed by chars
std::unordered_map<std::string, int> char_map;
for (unsigned int i = 0; i < _char_list.size(); i++) {
char_map[_char_list[i]] = i;
}
// For each unigram convert to ints and put in trie
int vocab_size = 0;
for (const auto& word : _vocabulary) {
bool added = add_word_to_dictionary(word,
char_map,
add_space,
_SPACE,
&dictionary);
vocab_size += added ? 1 : 0;
}
std::cerr << "Vocab Size " << vocab_size << std::endl;
// Simplify FST
// This gets rid of "epsilon" transitions in the FST.
// These are transitions that don't require a string input to be taken.
// Getting rid of them is necessary to make the FST determinisitc, but
// can greatly increase the size of the FST
fst::RmEpsilon(&dictionary);
fst::StdVectorFst* new_dict = new fst::StdVectorFst;
// This makes the FST deterministic, meaning for any string input there's
// only one possible state the FST could be in. It is assumed our
// dictionary is deterministic when using it.
// (lest we'd have to check for multiple transitions at each state)
fst::Determinize(dictionary, new_dict);
// Finds the simplest equivalent fst. This is unnecessary but decreases
// memory usage of the dictionary
fst::Minimize(new_dict);
_dictionary = new_dict;
}
......@@ -53,15 +53,23 @@ public:
double get_score(std::string, bool log=false);
// make ngram
std::vector<std::string> make_ngram(PathTrie* prefix);
// fill dictionary for fst
void fill_dictionary(bool add_space);
// set char map
void set_char_map(std::vector<std::string> char_list);
// expose to decoder
double alpha;
double beta;
// fst dictionary
void* dictionary;
void* _dictionary;
protected:
void load_LM(const char* filename);
double get_log_prob(const std::vector<std::string>& words);
std::string vec2str(const std::vector<int> &input);
std::vector<std::string> split_labels(const std::vector<int> &labels);
std::vector<std::string> split_str(const std::string &s,
const std::string &delim);
private:
void _init_char_list();
......@@ -71,6 +79,7 @@ private:
bool _is_character_based;
size_t _max_order;
unsigned int _SPACE;
std::vector<std::string> _char_list;
std::unordered_map<char, int> _char_map;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册