提交 20d13a4d 编写于 作者: Y Yibing Liu

streamline source code

上级 955d2932
......@@ -10,8 +10,6 @@
#include "path_trie.h"
#include "ThreadPool.h"
typedef float log_prob_type;
std::string ctc_best_path_decoder(std::vector<std::vector<double> > probs_seq,
std::vector<std::string> vocabulary)
{
......@@ -19,8 +17,8 @@ std::string ctc_best_path_decoder(std::vector<std::vector<double> > probs_seq,
int num_time_steps = probs_seq.size();
for (int i=0; i<num_time_steps; i++) {
if (probs_seq[i].size() != vocabulary.size()+1) {
std::cout<<"The shape of probs_seq does not match"
<<" with the shape of the vocabulary!"<<std::endl;
std::cout << "The shape of probs_seq does not match"
<< " with the shape of the vocabulary!" << std::endl;
exit(1);
}
}
......@@ -30,8 +28,8 @@ std::string ctc_best_path_decoder(std::vector<std::vector<double> > probs_seq,
std::vector<int> max_idx_vec;
double max_prob = 0.0;
int max_idx = 0;
for (int i=0; i<num_time_steps; i++) {
for (int j=0; j<probs_seq[i].size(); j++) {
for (int i = 0; i < num_time_steps; i++) {
for (int j = 0; j < probs_seq[i].size(); j++) {
if (max_prob < probs_seq[i][j]) {
max_idx = j;
max_prob = probs_seq[i][j];
......@@ -43,14 +41,14 @@ std::string ctc_best_path_decoder(std::vector<std::vector<double> > probs_seq,
}
std::vector<int> idx_vec;
for (int i=0; i<max_idx_vec.size(); i++) {
if ((i == 0) || ((i>0) && max_idx_vec[i]!=max_idx_vec[i-1])) {
for (int i = 0; i < max_idx_vec.size(); i++) {
if ((i == 0) || ((i > 0) && max_idx_vec[i] != max_idx_vec[i-1])) {
idx_vec.push_back(max_idx_vec[i]);
}
}
std::string best_path_result;
for (int i=0; i<idx_vec.size(); i++) {
for (int i = 0; i < idx_vec.size(); i++) {
if (idx_vec[i] != blank_id) {
best_path_result += vocabulary[idx_vec[i]];
}
......@@ -68,8 +66,8 @@ std::vector<std::pair<double, std::string> >
{
// dimension check
int num_time_steps = probs_seq.size();
for (int i=0; i<num_time_steps; i++) {
if (probs_seq[i].size() != vocabulary.size()+1) {
for (int i = 0; i < num_time_steps; i++) {
if (probs_seq[i].size() != vocabulary.size() + 1) {
std::cout << " The shape of probs_seq does not match"
<< " with the shape of the vocabulary!" << std::endl;
exit(1);
......@@ -86,19 +84,14 @@ std::vector<std::pair<double, std::string> >
std::vector<std::string>::iterator it = std::find(vocabulary.begin(),
vocabulary.end(), " ");
int space_id = it - vocabulary.begin();
// if no space in vocabulary
if(space_id >= vocabulary.size()) {
std::cout << " The character space is not in the vocabulary!"<<std::endl;
exit(1);
space_id = -2;
}
static log_prob_type POS_INF = std::numeric_limits<log_prob_type>::max();
static log_prob_type NEG_INF = -POS_INF;
static log_prob_type NUM_MIN = std::numeric_limits<log_prob_type>::min();
// init
PathTrie root;
root._log_prob_b_prev = 0.0;
root._score = 0.0;
root._score = root._log_prob_b_prev = 0.0;
std::vector<PathTrie*> prefixes;
prefixes.push_back(&root);
......@@ -140,17 +133,17 @@ std::vector<std::pair<double, std::string> >
prob_idx.begin() + cutoff_len);
}
std::vector<std::pair<int, log_prob_type> > log_prob_idx;
for (int i=0; i<cutoff_len; i++) {
log_prob_idx.push_back(std::pair<int, log_prob_type>
(prob_idx[i].first, log(prob_idx[i].second + NUM_MIN)));
std::vector<std::pair<int, float> > log_prob_idx;
for (int i = 0; i < cutoff_len; i++) {
log_prob_idx.push_back(std::pair<int, float>
(prob_idx[i].first, log(prob_idx[i].second + NUM_FLT_MIN)));
}
// loop over chars
for (int index = 0; index < log_prob_idx.size(); index++) {
auto c = log_prob_idx[index].first;
log_prob_type log_prob_c = log_prob_idx[index].second;
//log_prob_type log_probs_prev;
float log_prob_c = log_prob_idx[index].second;
//float log_probs_prev;
for (int i = 0; i < prefixes.size() && i<beam_size; i++) {
auto prefix = prefixes[i];
......@@ -165,17 +158,16 @@ std::vector<std::pair<double, std::string> >
if (c == prefix->_character) {
prefix->_log_prob_nb_cur = log_sum_exp(
prefix->_log_prob_nb_cur,
log_prob_c + prefix->_log_prob_nb_prev
);
log_prob_c + prefix->_log_prob_nb_prev);
}
// get new prefix
auto prefix_new = prefix->get_path_trie(c);
if (prefix_new != nullptr) {
float log_p = NEG_INF;
float log_p = -NUM_FLT_INF;
if (c == prefix->_character
&& prefix->_log_prob_b_prev > NEG_INF) {
&& prefix->_log_prob_b_prev > -NUM_FLT_INF) {
log_p = log_prob_c + prefix->_log_prob_b_prev;
} else if (c != prefix->_character) {
log_p = log_prob_c + prefix->_score;
......@@ -201,7 +193,6 @@ std::vector<std::pair<double, std::string> >
log_p += score;
log_p += ext_scorer->beta;
}
prefix_new->_log_prob_nb_cur = log_sum_exp(
prefix_new->_log_prob_nb_cur, log_p);
......@@ -273,7 +264,7 @@ std::vector<std::pair<double, std::string> >
}
std::vector<std::vector<std::pair<double, std::string>>>
std::vector<std::vector<std::pair<double, std::string> > >
ctc_beam_search_decoder_batch(
std::vector<std::vector<std::vector<double>>> probs_split,
int beam_size,
......@@ -292,12 +283,12 @@ std::vector<std::vector<std::pair<double, std::string>>>
// number of samples
int batch_size = probs_split.size();
// dictionary init
if ( ext_scorer != nullptr) {
if (ext_scorer->_dictionary == nullptr) {
// TODO: init dictionary
ext_scorer->set_char_map(vocabulary);
ext_scorer->fill_dictionary(true);
}
if ( ext_scorer != nullptr
&& !ext_scorer->is_character_based()
&& ext_scorer->_dictionary == nullptr) {
// init dictionary
ext_scorer->set_char_map(vocabulary);
ext_scorer->fill_dictionary(true);
}
// enqueue the tasks of decoding
std::vector<std::future<std::vector<std::pair<double, std::string>>>> res;
......@@ -308,7 +299,7 @@ std::vector<std::vector<std::pair<double, std::string>>>
);
}
// get decoding results
std::vector<std::vector<std::pair<double, std::string>>> batch_results;
std::vector<std::vector<std::pair<double, std::string> > > batch_results;
for (int i = 0; i < batch_size; i++) {
batch_results.emplace_back(res[i].get());
}
......
......@@ -15,7 +15,7 @@ size_t get_utf8_str_len(const std::string& str) {
//Splits string into vector of strings representing
//UTF-8 characters (not same as chars)
//------------------------------------------------------
std::vector<std::string> UTF8_split(const std::string& str)
std::vector<std::string> split_utf8_str(const std::string& str)
{
std::vector<std::string> result;
std::string out_str;
......@@ -37,6 +37,29 @@ std::vector<std::string> UTF8_split(const std::string& str)
return result;
}
// Split a string into a list of strings on a given string
// delimiter. NB: delimiters on beginning / end of string are
// trimmed. Eg, "FooBarFoo" split on "Foo" returns ["Bar"].
std::vector<std::string> split_str(const std::string &s,
const std::string &delim) {
std::vector<std::string> result;
std::size_t start = 0, delim_len = delim.size();
while (true) {
std::size_t end = s.find(delim, start);
if (end == std::string::npos) {
if (start < s.size()) {
result.push_back(s.substr(start));
}
break;
}
if (end > start) {
result.push_back(s.substr(start, end - start));
}
start = end + delim_len;
}
return result;
}
//-------------------------------------------------------
// Overriding less than operator for sorting
//-------------------------------------------------------
......@@ -80,7 +103,7 @@ bool add_word_to_dictionary(const std::string& word,
bool add_space,
int SPACE,
fst::StdVectorFst* dictionary) {
auto characters = UTF8_split(word);
auto characters = split_utf8_str(word);
std::vector<int> int_word;
......
......@@ -4,14 +4,19 @@
#include <utility>
#include "path_trie.h"
const float NUM_FLT_INF = std::numeric_limits<float>::max();
const float NUM_FLT_MIN = std::numeric_limits<float>::min();
template <typename T1, typename T2>
bool pair_comp_first_rev(const std::pair<T1, T2> &a, const std::pair<T1, T2> &b)
bool pair_comp_first_rev(const std::pair<T1, T2> &a,
const std::pair<T1, T2> &b)
{
return a.first > b.first;
}
template <typename T1, typename T2>
bool pair_comp_second_rev(const std::pair<T1, T2> &a, const std::pair<T1, T2> &b)
bool pair_comp_second_rev(const std::pair<T1, T2> &a,
const std::pair<T1, T2> &b)
{
return a.second > b.second;
}
......@@ -26,16 +31,18 @@ T log_sum_exp(const T &x, const T &y)
return std::log(std::exp(x-xmax) + std::exp(y-xmax)) + xmax;
}
//-------------------------------------------------------
// Overriding less than operator for sorting
//-------------------------------------------------------
// Functor for prefix comparsion
bool prefix_compare(const PathTrie* x, const PathTrie* y);
// Get length of utf8 encoding string
// See: http://stackoverflow.com/a/4063229
size_t get_utf8_str_len(const std::string& str);
std::vector<std::string> UTF8_split(const std::string &str);
std::vector<std::string> split_str(const std::string &s,
const std::string &delim);
std::vector<std::string> split_utf8_str(const std::string &str);
void add_word_to_fst(const std::vector<int>& word,
fst::StdVectorFst* dictionary);
......
......@@ -8,12 +8,11 @@
#include "decoder_utils.h"
PathTrie::PathTrie() {
float lowest = -1.0*std::numeric_limits<float>::max();
_log_prob_b_prev = lowest;
_log_prob_nb_prev = lowest;
_log_prob_b_cur = lowest;
_log_prob_nb_cur = lowest;
_score = lowest;
_log_prob_b_prev = -NUM_FLT_INF;
_log_prob_nb_prev = -NUM_FLT_INF;
_log_prob_b_cur = -NUM_FLT_INF;
_log_prob_nb_cur = -NUM_FLT_INF;
_score = -NUM_FLT_INF;
_ROOT = -1;
_character = _ROOT;
......@@ -41,11 +40,10 @@ PathTrie* PathTrie::get_path_trie(int new_char, bool reset) {
if ( child != _children.end() ) {
if (!child->second->_exists) {
child->second->_exists = true;
float lowest = -1.0*std::numeric_limits<float>::max();
child->second->_log_prob_b_prev = lowest;
child->second->_log_prob_nb_prev = lowest;
child->second->_log_prob_b_cur = lowest;
child->second->_log_prob_nb_cur = lowest;
child->second->_log_prob_b_prev = -NUM_FLT_INF;
child->second->_log_prob_nb_prev = -NUM_FLT_INF;
child->second->_log_prob_b_cur = -NUM_FLT_INF;
child->second->_log_prob_nb_cur = -NUM_FLT_INF;
}
return (child->second);
} else {
......@@ -106,8 +104,8 @@ void PathTrie::iterate_to_vec(
_log_prob_b_prev = _log_prob_b_cur;
_log_prob_nb_prev = _log_prob_nb_cur;
_log_prob_b_cur = -1.0 * std::numeric_limits<float>::max();
_log_prob_nb_cur = -1.0 * std::numeric_limits<float>::max();
_log_prob_b_cur = -NUM_FLT_INF;
_log_prob_nb_cur = -NUM_FLT_INF;
_score = log_sum_exp(_log_prob_b_prev, _log_prob_nb_prev);
output.push_back(this);
......@@ -117,9 +115,6 @@ void PathTrie::iterate_to_vec(
}
}
//-------------------------------------------------------
// Effectively removes node
//-------------------------------------------------------
void PathTrie::remove() {
_exists = false;
......
......@@ -17,7 +17,7 @@ Scorer::Scorer(double alpha, double beta, const std::string& lm_path) {
_language_model = nullptr;
_dictionary = nullptr;
_max_order = 0;
_SPACE = -1;
_SPACE_ID = -1;
// load language model
load_LM(lm_path.c_str());
}
......@@ -61,7 +61,7 @@ double Scorer::get_log_cond_prob(const std::vector<std::string>& words) {
lm::WordIndex word_index = model->BaseVocabulary().Index(words[i]);
// encounter OOV
if (word_index == 0) {
return OOV_SCOER;
return OOV_SCORE;
}
cond_prob = model->BaseScore(&state, word_index, &out_state);
tmp_state = state;
......@@ -197,64 +197,27 @@ Scorer::split_labels(const std::vector<int> &labels) {
std::string s = vec2str(labels);
std::vector<std::string> words;
if (_is_character_based) {
words = UTF8_split(s);
words = split_utf8_str(s);
} else {
words = split_str(s, " ");
}
return words;
}
// Split a string into a list of strings on a given string
// delimiter. NB: delimiters on beginning / end of string are
// trimmed. Eg, "FooBarFoo" split on "Foo" returns ["Bar"].
std::vector<std::string> Scorer::split_str(const std::string &s,
const std::string &delim) {
std::vector<std::string> result;
std::size_t start = 0, delim_len = delim.size();
while (true) {
std::size_t end = s.find(delim, start);
if (end == std::string::npos) {
if (start < s.size()) {
result.push_back(s.substr(start));
}
break;
}
if (end > start) {
result.push_back(s.substr(start, end - start));
}
start = end + delim_len;
}
return result;
}
//---------------------------------------------------
// Add index to char list for searching language model
//---------------------------------------------------
void Scorer::set_char_map(std::vector<std::string> char_list) {
_char_list = char_list;
std::string _SPACE_STR = " ";
for (unsigned int i = 0; i < _char_list.size(); i++) {
// if (_char_list[i] == _BLANK_STR) {
// _BLANK = i;
// } else
if (_char_list[i] == _SPACE_STR) {
_SPACE = i;
}
}
_char_map.clear();
for(unsigned int i = 0; i < _char_list.size(); i++)
{
if(i == (unsigned int)_SPACE){
if (_char_list[i] == " ") {
_SPACE_ID = i;
_char_map[' '] = i;
}
else if(_char_list[i].size() == 1){
} else if(_char_list[i].size() == 1){
_char_map[_char_list[i][0]] = i;
}
}
} //------------- End of set_char_map ----------------
}
std::vector<std::string> Scorer::make_ngram(PathTrie* prefix) {
std::vector<std::string> ngram;
......@@ -265,10 +228,10 @@ std::vector<std::string> Scorer::make_ngram(PathTrie* prefix) {
std::vector<int> prefix_vec;
if (_is_character_based) {
new_node = current_node->get_path_vec(prefix_vec, _SPACE, 1);
new_node = current_node->get_path_vec(prefix_vec, _SPACE_ID, 1);
current_node = new_node;
} else {
new_node = current_node->get_path_vec(prefix_vec, _SPACE);
new_node = current_node->get_path_vec(prefix_vec, _SPACE_ID);
current_node = new_node->_parent; // Skipping spaces
}
......@@ -279,7 +242,7 @@ std::vector<std::string> Scorer::make_ngram(PathTrie* prefix) {
if (new_node->_character == -1) {
// No more spaces, but still need order
for (int i = 0; i < _max_order - order - 1; i++) {
ngram.push_back("<s>");
ngram.push_back(START_TOKEN);
}
break;
}
......@@ -288,10 +251,6 @@ std::vector<std::string> Scorer::make_ngram(PathTrie* prefix) {
return ngram;
}
//---------------------------------------------------------
// Helper function to populate Trie with a vocab using the
// char_list for maping from string to int
//---------------------------------------------------------
void Scorer::fill_dictionary(bool add_space) {
fst::StdVectorFst dictionary;
......@@ -307,7 +266,7 @@ void Scorer::fill_dictionary(bool add_space) {
bool added = add_word_to_dictionary(word,
char_map,
add_space,
_SPACE,
_SPACE_ID,
&dictionary);
vocab_size += added ? 1 : 0;
}
......
......@@ -11,7 +11,7 @@
#include "util/string_piece.hh"
#include "path_trie.h"
const double OOV_SCOER = -1000.0;
const double OOV_SCORE = -1000.0;
const std::string START_TOKEN = "<s>";
const std::string UNK_TOKEN = "<unk>";
const std::string END_TOKEN = "</s>";
......@@ -68,18 +68,13 @@ protected:
double get_log_prob(const std::vector<std::string>& words);
std::string vec2str(const std::vector<int> &input);
std::vector<std::string> split_labels(const std::vector<int> &labels);
std::vector<std::string> split_str(const std::string &s,
const std::string &delim);
private:
void _init_char_list();
void _init_char_map();
void* _language_model;
bool _is_character_based;
size_t _max_order;
unsigned int _SPACE;
int _SPACE_ID;
std::vector<std::string> _char_list;
std::unordered_map<char, int> _char_map;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册