提交 b5602054 编写于 作者: Y Yibing Liu

convert data structure for prefix from map to trie tree

上级 eef364d1
......@@ -18,7 +18,7 @@ import time
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--num_samples",
default=32,
default=5,
type=int,
help="Number of samples for inference. (default: %(default)s)")
parser.add_argument(
......@@ -79,7 +79,7 @@ parser.add_argument(
"(default: %(default)s)")
parser.add_argument(
"--beam_size",
default=200,
default=20,
type=int,
help="Width for beam search decoding. (default: %(default)d)")
parser.add_argument(
......@@ -104,7 +104,7 @@ parser.add_argument(
help="Parameter associated with word count. (default: %(default)f)")
parser.add_argument(
"--cutoff_prob",
default=0.99,
default=1.0,
type=float,
help="The cutoff probability of pruning"
"in beam search. (default: %(default)f)")
......@@ -183,7 +183,8 @@ def infer():
vocabulary=data_generator.vocab_list,
blank_id=len(data_generator.vocab_list),
cutoff_prob=args.cutoff_prob,
ext_scoring_func=ext_scorer, )
# ext_scoring_func=ext_scorer,
)
batch_beam_results += [beam_result]
else:
batch_beam_results = ctc_beam_search_decoder_batch(
......
......@@ -4,11 +4,13 @@
#include <utility>
#include <cmath>
#include <limits>
#include "fst/fstlib.h"
#include "ctc_decoders.h"
#include "decoder_utils.h"
#include "path_trie.h"
#include "ThreadPool.h"
typedef double log_prob_type;
typedef float log_prob_type;
std::string ctc_best_path_decoder(std::vector<std::vector<double> > probs_seq,
std::vector<std::string> vocabulary)
......@@ -89,24 +91,30 @@ std::vector<std::pair<double, std::string> >
exit(1);
}
// initialize
// two sets containing selected and candidate prefixes respectively
std::map<std::string, log_prob_type> prefix_set_prev, prefix_set_next;
// probability of prefixes ending with blank and non-blank
std::map<std::string, log_prob_type> log_probs_b_prev, log_probs_nb_prev;
std::map<std::string, log_prob_type> log_probs_b_cur, log_probs_nb_cur;
static log_prob_type NUM_MAX = std::numeric_limits<log_prob_type>::max();
prefix_set_prev["\t"] = 0.0;
log_probs_b_prev["\t"] = 0.0;
log_probs_nb_prev["\t"] = -NUM_MAX;
for (int time_step=0; time_step<num_time_steps; time_step++) {
prefix_set_next.clear();
log_probs_b_cur.clear();
log_probs_nb_cur.clear();
std::vector<double> prob = probs_seq[time_step];
static log_prob_type POS_INF = std::numeric_limits<log_prob_type>::max();
static log_prob_type NEG_INF = -POS_INF;
static log_prob_type NUM_MIN = std::numeric_limits<log_prob_type>::min();
// init
PathTrie root;
root._log_prob_b_prev = 0.0;
root._score = 0.0;
std::vector<PathTrie*> prefixes;
prefixes.push_back(&root);
if ( ext_scorer != nullptr && !ext_scorer->is_character_based()) {
if (ext_scorer->dictionary == nullptr) {
// TODO: init dictionary
}
auto fst_dict = static_cast<fst::StdVectorFst*>(ext_scorer->dictionary);
fst::StdVectorFst* dict_ptr = fst_dict->Copy(true);
root.set_dictionary(dict_ptr);
auto matcher = std::make_shared<FSTMATCH>(*dict_ptr, fst::MATCH_INPUT);
root.set_matcher(matcher);
}
for (int time_step = 0; time_step < num_time_steps; time_step++) {
std::vector<double> prob = probs_seq[time_step];
std::vector<std::pair<int, double> > prob_idx;
for (int i=0; i<prob.size(); i++) {
prob_idx.push_back(std::pair<int, double>(i, prob[i]));
......@@ -132,113 +140,134 @@ std::vector<std::pair<double, std::string> >
std::vector<std::pair<int, log_prob_type> > log_prob_idx;
for (int i=0; i<cutoff_len; i++) {
log_prob_idx.push_back(std::pair<int, log_prob_type>
(prob_idx[i].first, log(prob_idx[i].second)));
(prob_idx[i].first, log(prob_idx[i].second + NUM_MIN)));
}
// extend prefix
for (std::map<std::string, log_prob_type>::iterator
it = prefix_set_prev.begin();
it != prefix_set_prev.end(); it++) {
std::string l = it->first;
if( prefix_set_next.find(l) == prefix_set_next.end()) {
log_probs_b_cur[l] = log_probs_nb_cur[l] = -NUM_MAX;
}
// loop over chars
for (int index = 0; index < log_prob_idx.size(); index++) {
auto c = log_prob_idx[index].first;
log_prob_type log_prob_c = log_prob_idx[index].second;
//log_prob_type log_probs_prev;
for (int index=0; index<log_prob_idx.size(); index++) {
int c = log_prob_idx[index].first;
log_prob_type log_prob_c = log_prob_idx[index].second;
log_prob_type log_probs_prev;
for (int i = 0; i < prefixes.size() && i<beam_size; i++) {
auto prefix = prefixes[i];
// blank
if (c == blank_id) {
log_probs_prev = log_sum_exp(log_probs_b_prev[l],
log_probs_nb_prev[l]);
log_probs_b_cur[l] = log_sum_exp(log_probs_b_cur[l],
log_prob_c+log_probs_prev);
} else {
std::string last_char = l.substr(l.size()-1, 1);
std::string new_char = vocabulary[c];
std::string l_plus = l + new_char;
if( prefix_set_next.find(l_plus) == prefix_set_next.end()) {
log_probs_b_cur[l_plus] = -NUM_MAX;
log_probs_nb_cur[l_plus] = -NUM_MAX;
prefix->_log_prob_b_cur = log_sum_exp(
prefix->_log_prob_b_cur,
log_prob_c + prefix->_score);
continue;
}
// repeated character
if (c == prefix->_character) {
prefix->_log_prob_nb_cur = log_sum_exp(
prefix->_log_prob_nb_cur,
log_prob_c + prefix->_log_prob_nb_prev
);
}
// get new prefix
auto prefix_new = prefix->get_path_trie(c);
if (prefix_new != nullptr) {
float log_p = NEG_INF;
if (c == prefix->_character
&& prefix->_log_prob_b_prev > NEG_INF) {
log_p = log_prob_c + prefix->_log_prob_b_prev;
} else if (c != prefix->_character) {
log_p = log_prob_c + prefix->_score;
}
if (last_char == new_char) {
log_probs_nb_cur[l_plus] = log_sum_exp(
log_probs_nb_cur[l_plus],
log_prob_c+log_probs_b_prev[l]
);
log_probs_nb_cur[l] = log_sum_exp(
log_probs_nb_cur[l],
log_prob_c+log_probs_nb_prev[l]
);
} else if (new_char == " ") {
float score = 0.0;
if (ext_scorer != NULL && l.size() > 1) {
score = ext_scorer->get_score(l.substr(1), true);
// language model scoring
if (ext_scorer != nullptr &&
(c == space_id || ext_scorer->is_character_based()) ) {
PathTrie *prefix_to_score = nullptr;
// don't score the space
if (ext_scorer->is_character_based()) {
prefix_to_score = prefix_new;
} else {
prefix_to_score = prefix;
}
log_probs_prev = log_sum_exp(log_probs_b_prev[l],
log_probs_nb_prev[l]);
log_probs_nb_cur[l_plus] = log_sum_exp(
log_probs_nb_cur[l_plus],
score + log_prob_c + log_probs_prev
);
} else {
log_probs_prev = log_sum_exp(log_probs_b_prev[l],
log_probs_nb_prev[l]);
log_probs_nb_cur[l_plus] = log_sum_exp(
log_probs_nb_cur[l_plus],
log_prob_c+log_probs_prev
);
double score = 0.0;
std::vector<std::string> ngram;
ngram = ext_scorer->make_ngram(prefix_to_score);
score = ext_scorer->get_log_cond_prob(ngram) *
ext_scorer->alpha;
log_p += score;
log_p += ext_scorer->beta;
}
prefix_set_next[l_plus] = log_sum_exp(
log_probs_nb_cur[l_plus],
log_probs_b_cur[l_plus]
);
prefix_new->_log_prob_nb_cur = log_sum_exp(
prefix_new->_log_prob_nb_cur, log_p);
}
}
prefix_set_next[l] = log_sum_exp(log_probs_b_cur[l],
log_probs_nb_cur[l]);
} // end of loop over chars
prefixes.clear();
// update log probabilities
root.iterate_to_vec(prefixes);
// sort prefixes by score
if (prefixes.size() >= beam_size) {
std::nth_element(prefixes.begin(),
prefixes.begin() + beam_size,
prefixes.end(),
prefix_compare);
for (size_t i = beam_size; i < prefixes.size(); i++) {
prefixes[i]->remove();
}
}
}
for (size_t i = 0; i < beam_size && i < prefixes.size(); i++) {
double approx_ctc = prefixes[i]->_score;
// remove word insert:
std::vector<int> output;
prefixes[i]->get_path_vec(output);
size_t prefix_length = output.size();
// remove language model weight:
if (ext_scorer != nullptr) {
// auto words = split_labels(output);
// approx_ctc = approx_ctc - path_length * ext_scorer->beta;
// approx_ctc -= (_lm->get_sent_log_prob(words)) * ext_scorer->alpha;
}
log_probs_b_prev = log_probs_b_cur;
log_probs_nb_prev = log_probs_nb_cur;
std::vector<std::pair<std::string, log_prob_type> >
prefix_vec_next(prefix_set_next.begin(),
prefix_set_next.end());
std::sort(prefix_vec_next.begin(),
prefix_vec_next.end(),
pair_comp_second_rev<std::string, log_prob_type>);
int num_prefixes_next = prefix_vec_next.size();
int k = beam_size<num_prefixes_next ? beam_size : num_prefixes_next;
prefix_set_prev = std::map<std::string, log_prob_type> (
prefix_vec_next.begin(),
prefix_vec_next.begin() + k
);
prefixes[i]->_approx_ctc = approx_ctc;
}
// post processing
std::vector<std::pair<double, std::string> > beam_result;
for (std::map<std::string, log_prob_type>::iterator
it = prefix_set_prev.begin(); it != prefix_set_prev.end(); it++) {
if (it->second > -NUM_MAX && it->first.size() > 1) {
log_prob_type log_prob = it->second;
std::string sentence = it->first.substr(1);
// scoring the last word
if (ext_scorer != NULL && sentence[sentence.size()-1] != ' ') {
log_prob = log_prob + ext_scorer->get_score(sentence, true);
}
if (log_prob > -NUM_MAX) {
std::pair<double, std::string> cur_result(log_prob, sentence);
beam_result.push_back(cur_result);
}
// allow for the post processing
std::vector<PathTrie*> space_prefixes;
if (space_prefixes.empty()) {
for (size_t i = 0; i < beam_size && i< prefixes.size(); i++) {
space_prefixes.push_back(prefixes[i]);
}
}
// sort the result and return
std::sort(beam_result.begin(), beam_result.end(),
pair_comp_first_rev<double, std::string>);
return beam_result;
}
std::sort(space_prefixes.begin(), space_prefixes.end(), prefix_compare);
std::vector<std::pair<double, std::string> > output_vecs;
for (size_t i = 0; i < beam_size && i < space_prefixes.size(); i++) {
std::vector<int> output;
space_prefixes[i]->get_path_vec(output);
// convert index to string
std::string output_str;
for (int j = 0; j < output.size(); j++) {
output_str += vocabulary[output[j]];
}
std::pair<double, std::string> output_pair(space_prefixes[i]->_score,
output_str);
output_vecs.emplace_back(
output_pair
);
}
return output_vecs;
}
std::vector<std::vector<std::pair<double, std::string>>>
......@@ -250,8 +279,7 @@ std::vector<std::vector<std::pair<double, std::string>>>
int num_processes,
double cutoff_prob,
Scorer *ext_scorer
)
{
) {
if (num_processes <= 0) {
std::cout << "num_processes must be nonnegative!" << std::endl;
exit(1);
......
......@@ -10,3 +10,73 @@ size_t get_utf8_str_len(const std::string& str) {
}
return str_len;
}
//-------------------------------------------------------
// Overriding less than operator for sorting
//-------------------------------------------------------
bool prefix_compare(const PathTrie* x, const PathTrie* y) {
if (x->_score == y->_score) {
if (x->_character == y->_character) {
return false;
} else {
return (x->_character < y->_character);
}
} else {
return x->_score > y->_score;
}
} //---------- End path_compare ---------------------------
// --------------------------------------------------------------
// Adds word to fst without copying entire dictionary
// --------------------------------------------------------------
void add_word_to_fst(const std::vector<int>& word,
fst::StdVectorFst* dictionary) {
if (dictionary->NumStates() == 0) {
fst::StdVectorFst::StateId start = dictionary->AddState();
assert(start == 0);
dictionary->SetStart(start);
}
fst::StdVectorFst::StateId src = dictionary->Start();
fst::StdVectorFst::StateId dst;
for (auto c : word) {
dst = dictionary->AddState();
dictionary->AddArc(src, fst::StdArc(c, c, 0, dst));
src = dst;
}
dictionary->SetFinal(dst, fst::StdArc::Weight::One());
} // ------------ End of add_word_to_fst -----------------------
// ---------------------------------------------------------
// Adds a word to the dictionary FST based on char_map
// ---------------------------------------------------------
bool addWordToDictionary(const std::string& word,
const std::unordered_map<std::string, int>& char_map,
bool add_space,
int SPACE,
fst::StdVectorFst* dictionary) {
/*
auto characters = UTF8_split(word);
std::vector<int> int_word;
for (auto& c : characters) {
if (c == " ") {
int_word.push_back(SPACE);
} else {
auto int_c = char_map.find(c);
if (int_c != char_map.end()) {
int_word.push_back(int_c->second);
} else {
return false; // return without adding
}
}
}
if (add_space) {
int_word.push_back(SPACE);
}
add_word_to_fst(int_word, dictionary);
*/
return true;
} // -------------- End of addWordToDictionary ------------
......@@ -2,6 +2,7 @@
#define DECODER_UTILS_H_
#include <utility>
#include "path_trie.h"
template <typename T1, typename T2>
bool pair_comp_first_rev(const std::pair<T1, T2> &a, const std::pair<T1, T2> &b)
......@@ -25,8 +26,21 @@ T log_sum_exp(const T &x, const T &y)
return std::log(std::exp(x-xmax) + std::exp(y-xmax)) + xmax;
}
//-------------------------------------------------------
// Overriding less than operator for sorting
//-------------------------------------------------------
bool prefix_compare(const PathTrie* x, const PathTrie* y);
// Get length of utf8 encoding string
// See: http://stackoverflow.com/a/4063229
size_t get_utf8_str_len(const std::string& str);
void add_word_to_fst(const std::vector<int>& word,
fst::StdVectorFst* dictionary);
bool addWordToDictionary(const std::string& word,
const std::unordered_map<std::string, int>& char_map,
bool add_space,
int SPACE,
fst::StdVectorFst* dictionary);
#endif // DECODER_UTILS_H
#include <algorithm>
#include <limits>
#include <memory>
#include <utility>
#include <vector>
#include "path_trie.h"
#include "decoder_utils.h"
PathTrie::PathTrie() {
float lowest = -1.0*std::numeric_limits<float>::max();
_log_prob_b_prev = lowest;
_log_prob_nb_prev = lowest;
_log_prob_b_cur = lowest;
_log_prob_nb_cur = lowest;
_score = lowest;
_ROOT = -1;
_character = _ROOT;
_exists = true;
_parent = nullptr;
_dictionary = nullptr;
_dictionary_state = 0;
_has_dictionary = false;
_matcher = nullptr; // finds arcs in FST
}
PathTrie::~PathTrie() {
for (auto child : _children) {
delete child.second;
}
}
PathTrie* PathTrie::get_path_trie(int new_char, bool reset) {
auto child = _children.begin();
for (child = _children.begin(); child != _children.end(); ++child) {
if (child->first == new_char) {
break;
}
}
if ( child != _children.end() ) {
if (!child->second->_exists) {
child->second->_exists = true;
float lowest = -1.0*std::numeric_limits<float>::max();
child->second->_log_prob_b_prev = lowest;
child->second->_log_prob_nb_prev = lowest;
child->second->_log_prob_b_cur = lowest;
child->second->_log_prob_nb_cur = lowest;
}
return (child->second);
} else {
if (_has_dictionary) {
_matcher->SetState(_dictionary_state);
bool found = _matcher->Find(new_char);
if (!found) {
// Adding this character causes word outside dictionary
auto FSTZERO = fst::TropicalWeight::Zero();
auto final_weight = _dictionary->Final(_dictionary_state);
bool is_final = (final_weight != FSTZERO);
if (is_final && reset) {
_dictionary_state = _dictionary->Start();
}
return nullptr;
} else {
PathTrie* new_path = new PathTrie;
new_path->_character = new_char;
new_path->_parent = this;
new_path->_dictionary = _dictionary;
new_path->_dictionary_state = _matcher->Value().nextstate;
new_path->_has_dictionary = true;
new_path->_matcher = _matcher;
_children.push_back(std::make_pair(new_char, new_path));
return new_path;
}
} else {
PathTrie* new_path = new PathTrie;
new_path->_character = new_char;
new_path->_parent = this;
_children.push_back(std::make_pair(new_char, new_path));
return new_path;
}
}
}
PathTrie* PathTrie::get_path_vec(std::vector<int>& output) {
return get_path_vec(output, _ROOT);
}
PathTrie* PathTrie::get_path_vec(std::vector<int>& output,
int stop,
size_t max_steps /*= std::numeric_limits<size_t>::max() */) {
if (_character == stop ||
_character == _ROOT ||
output.size() == max_steps) {
std::reverse(output.begin(), output.end());
return this;
} else {
output.push_back(_character);
return _parent->get_path_vec(output, stop, max_steps);
}
}
void PathTrie::iterate_to_vec(
std::vector<PathTrie*>& output) {
if (_exists) {
_log_prob_b_prev = _log_prob_b_cur;
_log_prob_nb_prev = _log_prob_nb_cur;
_log_prob_b_cur = -1.0 * std::numeric_limits<float>::max();
_log_prob_nb_cur = -1.0 * std::numeric_limits<float>::max();
_score = log_sum_exp(_log_prob_b_prev, _log_prob_nb_prev);
output.push_back(this);
}
for (auto child : _children) {
child.second->iterate_to_vec(output);
}
}
//-------------------------------------------------------
// Effectively removes node
//-------------------------------------------------------
void PathTrie::remove() {
_exists = false;
if (_children.size() == 0) {
auto child = _parent->_children.begin();
for (child = _parent->_children.begin();
child != _parent->_children.end(); ++child) {
if (child->first == _character) {
_parent->_children.erase(child);
break;
}
}
if ( _parent->_children.size() == 0 && !_parent->_exists ) {
_parent->remove();
}
delete this;
}
}
void PathTrie::set_dictionary(fst::StdVectorFst* dictionary) {
_dictionary = dictionary;
_dictionary_state = dictionary->Start();
_has_dictionary = true;
}
using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;
void PathTrie::set_matcher(std::shared_ptr<FSTMATCH> matcher) {
_matcher = matcher;
}
#ifndef PATH_TRIE_H
#define PATH_TRIE_H
#pragma once
#include <algorithm>
#include <limits>
#include <memory>
#include <utility>
#include <vector>
#include <fst/fstlib.h>
using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;
class PathTrie {
public:
PathTrie();
~PathTrie();
PathTrie* get_path_trie(int new_char, bool reset = true);
PathTrie* get_path_vec(std::vector<int> &output);
PathTrie* get_path_vec(std::vector<int>& output,
int stop,
size_t max_steps = std::numeric_limits<size_t>::max());
void iterate_to_vec(std::vector<PathTrie*> &output);
void set_dictionary(fst::StdVectorFst* dictionary);
void set_matcher(std::shared_ptr<FSTMATCH> matcher);
bool is_empty() {
return _ROOT == _character;
}
void remove();
float _log_prob_b_prev;
float _log_prob_nb_prev;
float _log_prob_b_cur;
float _log_prob_nb_cur;
float _score;
float _approx_ctc;
int _ROOT;
int _character;
bool _exists;
PathTrie *_parent;
std::vector<std::pair<int, PathTrie*> > _children;
fst::StdVectorFst* _dictionary;
fst::StdVectorFst::StateId _dictionary_state;
bool _has_dictionary;
std::shared_ptr<FSTMATCH> _matcher;
};
#endif // PATH_TRIE_H
......@@ -175,3 +175,42 @@ double Scorer::get_score(std::string sentence, bool log) {
}
return final_score;
}
//--------------------------------------------------
// Turn indices back into strings of chars
//--------------------------------------------------
std::vector<std::string> Scorer::make_ngram(PathTrie* prefix) {
/*
std::vector<std::string> ngram;
PathTrie* current_node = prefix;
PathTrie* new_node = nullptr;
for (int order = 0; order < _max_order; order++) {
std::vector<int> prefix_vec;
if (_is_character_based) {
new_node = current_node->get_path_vec(prefix_vec, ' ', 1);
current_node = new_node;
} else {
new_node = current_node->getPathVec(prefix_vec, ' ');
current_node = new_node->_parent; // Skipping spaces
}
// reconstruct word
std::string word = vec2str(prefix_vec);
ngram.push_back(word);
if (new_node->_character == -1) {
// No more spaces, but still need order
for (int i = 0; i < max_order - order - 1; i++) {
ngram.push_back("<s>");
}
break;
}
}
std::reverse(ngram.begin(), ngram.end());
*/
std::vector<std::string> ngram;
ngram.push_back("this");
return ngram;
} //---------------- End makeNgrams ------------------
......@@ -4,10 +4,12 @@
#include <string>
#include <memory>
#include <vector>
#include <unordered_map>
#include "lm/enumerate_vocab.hh"
#include "lm/word_index.hh"
#include "lm/virtual_interface.hh"
#include "util/string_piece.hh"
#include "path_trie.h"
const double OOV_SCOER = -1000.0;
const std::string START_TOKEN = "<s>";
......@@ -49,18 +51,29 @@ public:
void reset_params(float alpha, float beta);
// get the final score
double get_score(std::string, bool log=false);
// make ngram
std::vector<std::string> make_ngram(PathTrie* prefix);
// expose to decoder
double alpha;
double beta;
// fst dictionary
void* dictionary;
protected:
void load_LM(const char* filename);
double get_log_prob(const std::vector<std::string>& words);
private:
void _init_char_list();
void _init_char_map();
void* _language_model;
bool _is_character_based;
size_t _max_order;
std::vector<std::string> _char_list;
std::unordered_map<char, int> _char_map;
std::vector<std::string> _vocabulary;
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册