提交 41e9e59d 编写于 作者: Y Yibing Liu

append some comments

上级 d75f27df
......@@ -14,8 +14,8 @@
#include "path_trie.h"
std::string ctc_greedy_decoder(
const std::vector<std::vector<double>>& probs_seq,
const std::vector<std::string>& vocabulary) {
const std::vector<std::vector<double>> &probs_seq,
const std::vector<std::string> &vocabulary) {
// dimension check
int num_time_steps = probs_seq.size();
for (int i = 0; i < num_time_steps; i++) {
......@@ -60,7 +60,7 @@ std::string ctc_greedy_decoder(
}
std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
const std::vector<std::vector<double>>& probs_seq,
const std::vector<std::vector<double>> &probs_seq,
int beam_size,
std::vector<std::string> vocabulary,
int blank_id,
......@@ -104,7 +104,7 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
}
if (!extscorer->is_character_based()) {
if (extscorer->dictionary == nullptr) {
// fill dictionary for fst
// fill dictionary for fst with space
extscorer->fill_dictionary(true);
}
auto fst_dict = static_cast<fst::StdVectorFst *>(extscorer->dictionary);
......@@ -282,9 +282,9 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
std::vector<std::vector<std::pair<double, std::string>>>
ctc_beam_search_decoder_batch(
const std::vector<std::vector<std::vector<double>>>& probs_split,
const std::vector<std::vector<std::vector<double>>> &probs_split,
int beam_size,
const std::vector<std::string>& vocabulary,
const std::vector<std::string> &vocabulary,
int blank_id,
int num_processes,
double cutoff_prob,
......@@ -304,8 +304,7 @@ ctc_beam_search_decoder_batch(
if (extscorer->is_char_map_empty()) {
extscorer->set_char_map(vocabulary);
}
if (!extscorer->is_character_based() &&
extscorer->dictionary == nullptr) {
if (!extscorer->is_character_based() && extscorer->dictionary == nullptr) {
// init dictionary
extscorer->fill_dictionary(true);
}
......
......@@ -14,12 +14,11 @@
* over vocabulary of one time step.
* vocabulary: A vector of vocabulary.
* Return:
* A vector that each element is a pair of score and decoding result,
* in desending order.
* The decoding result in string
*/
std::string ctc_greedy_decoder(
const std::vector<std::vector<double>>& probs_seq,
const std::vector<std::string>& vocabulary);
const std::vector<std::vector<double>> &probs_seq,
const std::vector<std::string> &vocabulary);
/* CTC Beam Search Decoder
......@@ -37,7 +36,7 @@ std::string ctc_greedy_decoder(
* in desending order.
*/
std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
const std::vector<std::vector<double>>& probs_seq,
const std::vector<std::vector<double>> &probs_seq,
int beam_size,
std::vector<std::string> vocabulary,
int blank_id,
......@@ -59,14 +58,14 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
* cutoff_top_n: Cutoff number for pruning.
* ext_scorer: External scorer to evaluate a prefix.
* Return:
* A 2-D vector that each element is a vector of decoding result for one
* sample.
* A 2-D vector that each element is a vector of beam search decoding
* result for one audio sample.
*/
std::vector<std::vector<std::pair<double, std::string>>>
ctc_beam_search_decoder_batch(
const std::vector<std::vector<std::vector<double>>>& probs_split,
const std::vector<std::vector<std::vector<double>>> &probs_split,
int beam_size,
const std::vector<std::string>& vocabulary,
const std::vector<std::string> &vocabulary,
int blank_id,
int num_processes,
double cutoff_prob = 1.0,
......
......@@ -4,7 +4,7 @@
#include <cmath>
#include <limits>
size_t get_utf8_str_len(const std::string& str) {
size_t get_utf8_str_len(const std::string &str) {
size_t str_len = 0;
for (char c : str) {
str_len += ((c & 0xc0) != 0x80);
......@@ -12,7 +12,7 @@ size_t get_utf8_str_len(const std::string& str) {
return str_len;
}
std::vector<std::string> split_utf8_str(const std::string& str) {
std::vector<std::string> split_utf8_str(const std::string &str) {
std::vector<std::string> result;
std::string out_str;
......@@ -31,8 +31,8 @@ std::vector<std::string> split_utf8_str(const std::string& str) {
return result;
}
std::vector<std::string> split_str(const std::string& s,
const std::string& delim) {
std::vector<std::string> split_str(const std::string &s,
const std::string &delim) {
std::vector<std::string> result;
std::size_t start = 0, delim_len = delim.size();
while (true) {
......@@ -51,7 +51,7 @@ std::vector<std::string> split_str(const std::string& s,
return result;
}
bool prefix_compare(const PathTrie* x, const PathTrie* y) {
bool prefix_compare(const PathTrie *x, const PathTrie *y) {
if (x->score == y->score) {
if (x->character == y->character) {
return false;
......@@ -63,8 +63,8 @@ bool prefix_compare(const PathTrie* x, const PathTrie* y) {
}
}
void add_word_to_fst(const std::vector<int>& word,
fst::StdVectorFst* dictionary) {
void add_word_to_fst(const std::vector<int> &word,
fst::StdVectorFst *dictionary) {
if (dictionary->NumStates() == 0) {
fst::StdVectorFst::StateId start = dictionary->AddState();
assert(start == 0);
......@@ -81,16 +81,16 @@ void add_word_to_fst(const std::vector<int>& word,
}
bool add_word_to_dictionary(
const std::string& word,
const std::unordered_map<std::string, int>& char_map,
const std::string &word,
const std::unordered_map<std::string, int> &char_map,
bool add_space,
int SPACE_ID,
fst::StdVectorFst* dictionary) {
fst::StdVectorFst *dictionary) {
auto characters = split_utf8_str(word);
std::vector<int> int_word;
for (auto& c : characters) {
for (auto &c : characters) {
if (c == " ") {
int_word.push_back(SPACE_ID);
} else {
......@@ -108,5 +108,5 @@ bool add_word_to_dictionary(
}
add_word_to_fst(int_word, dictionary);
return true;
return true; // return with successful adding
}
......@@ -14,12 +14,14 @@ bool pair_comp_first_rev(const std::pair<T1, T2> &a,
return a.first > b.first;
}
// Function template for comparing two pairs
template <typename T1, typename T2>
bool pair_comp_second_rev(const std::pair<T1, T2> &a,
const std::pair<T1, T2> &b) {
return a.second > b.second;
}
// Return the sum of two probabilities in log scale
template <typename T>
T log_sum_exp(const T &x, const T &y) {
static T num_min = -std::numeric_limits<T>::max();
......@@ -32,18 +34,21 @@ T log_sum_exp(const T &x, const T &y) {
// Functor for prefix comparsion
bool prefix_compare(const PathTrie *x, const PathTrie *y);
// Get length of utf8 encoding string
// See: http://stackoverflow.com/a/4063229
/* Get length of utf8 encoding string
* See: http://stackoverflow.com/a/4063229
*/
size_t get_utf8_str_len(const std::string &str);
// Split a string into a list of strings on a given string
// delimiter. NB: delimiters on beginning / end of string are
// trimmed. Eg, "FooBarFoo" split on "Foo" returns ["Bar"].
/* Split a string into a list of strings on a given string
* delimiter. NB: delimiters on beginning / end of string are
* trimmed. Eg, "FooBarFoo" split on "Foo" returns ["Bar"].
*/
std::vector<std::string> split_str(const std::string &s,
const std::string &delim);
// Splits string into vector of strings representing
// UTF-8 characters (not same as chars)
/* Splits string into vector of strings representing
* UTF-8 characters (not same as chars)
*/
std::vector<std::string> split_utf8_str(const std::string &str);
// Add a word in index to the dicionary of fst
......
......@@ -22,7 +22,7 @@ PathTrie::PathTrie() {
_dictionary = nullptr;
_dictionary_state = 0;
_has_dictionary = false;
_matcher = nullptr; // finds arcs in FST
_matcher = nullptr;
}
PathTrie::~PathTrie() {
......
......@@ -10,27 +10,36 @@
using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;
/* Trie tree for prefix storing and manipulating, with a dictionary in
* finite-state transducer for spelling correction.
*/
class PathTrie {
public:
PathTrie();
~PathTrie();
// get new prefix after appending new char
PathTrie* get_path_trie(int new_char, bool reset = true);
// get the prefix in index from root to current node
PathTrie* get_path_vec(std::vector<int>& output);
// get the prefix in index from some stop node to current nodel
PathTrie* get_path_vec(std::vector<int>& output,
int stop,
size_t max_steps = std::numeric_limits<size_t>::max());
// update log probs
void iterate_to_vec(std::vector<PathTrie*>& output);
// set dictionary for FST
void set_dictionary(fst::StdVectorFst* dictionary);
void set_matcher(std::shared_ptr<FSTMATCH> matcher);
bool is_empty() { return _ROOT == character; }
// remove current path from root
void remove();
float log_prob_b_prev;
......@@ -49,8 +58,10 @@ private:
std::vector<std::pair<int, PathTrie*>> _children;
// pointer to dictionary of FST
fst::StdVectorFst* _dictionary;
fst::StdVectorFst::StateId _dictionary_state;
// true if finding ars in FST
std::shared_ptr<FSTMATCH> _matcher;
};
......
......@@ -68,7 +68,7 @@ double Scorer::get_log_cond_prob(const std::vector<std::string>& words) {
state = out_state;
out_state = tmp_state;
}
// log10 prob
// return log10 prob
return cond_prob;
}
......@@ -189,23 +189,26 @@ void Scorer::fill_dictionary(bool add_space) {
std::cerr << "Vocab Size " << vocab_size << std::endl;
// Simplify FST
/* Simplify FST
// This gets rid of "epsilon" transitions in the FST.
// These are transitions that don't require a string input to be taken.
// Getting rid of them is necessary to make the FST determinisitc, but
// can greatly increase the size of the FST
* This gets rid of "epsilon" transitions in the FST.
* These are transitions that don't require a string input to be taken.
* Getting rid of them is necessary to make the FST determinisitc, but
* can greatly increase the size of the FST
*/
fst::RmEpsilon(&dictionary);
fst::StdVectorFst* new_dict = new fst::StdVectorFst;
// This makes the FST deterministic, meaning for any string input there's
// only one possible state the FST could be in. It is assumed our
// dictionary is deterministic when using it.
// (lest we'd have to check for multiple transitions at each state)
/* This makes the FST deterministic, meaning for any string input there's
* only one possible state the FST could be in. It is assumed our
* dictionary is deterministic when using it.
* (lest we'd have to check for multiple transitions at each state)
*/
fst::Determinize(dictionary, new_dict);
// Finds the simplest equivalent fst. This is unnecessary but decreases
// memory usage of the dictionary
/* Finds the simplest equivalent fst. This is unnecessary but decreases
* memory usage of the dictionary
*/
fst::Minimize(new_dict);
this->dictionary = new_dict;
}
......@@ -23,14 +23,15 @@ class RetriveStrEnumerateVocab : public lm::EnumerateVocab {
public:
RetriveStrEnumerateVocab() {}
void Add(lm::WordIndex index, const StringPiece& str) {
void Add(lm::WordIndex index, const StringPiece &str) {
vocabulary.push_back(std::string(str.data(), str.length()));
}
std::vector<std::string> vocabulary;
};
/* External scorer to query languange score for n-gram or sentence.
/* External scorer to query score for n-gram or sentence, including language
* model scoring and word insertion.
*
* Example:
* Scorer scorer(alpha, beta, "path_of_language_model");
......@@ -39,12 +40,12 @@ public:
*/
class Scorer {
public:
Scorer(double alpha, double beta, const std::string& lm_path);
Scorer(double alpha, double beta, const std::string &lm_path);
~Scorer();
double get_log_cond_prob(const std::vector<std::string>& words);
double get_log_cond_prob(const std::vector<std::string> &words);
double get_sent_log_prob(const std::vector<std::string>& words);
double get_sent_log_prob(const std::vector<std::string> &words);
size_t get_max_order() { return _max_order; }
......@@ -56,32 +57,32 @@ public:
void reset_params(float alpha, float beta);
// make ngram
std::vector<std::string> make_ngram(PathTrie* prefix);
std::vector<std::string> make_ngram(PathTrie *prefix);
// fill dictionary for fst
void fill_dictionary(bool add_space);
// set char map
void set_char_map(const std::vector<std::string>& char_list);
void set_char_map(const std::vector<std::string> &char_list);
std::vector<std::string> split_labels(const std::vector<int>& labels);
std::vector<std::string> split_labels(const std::vector<int> &labels);
// expose to decoder
double alpha;
double beta;
// fst dictionary
void* dictionary;
void *dictionary;
protected:
void load_LM(const char* filename);
void load_LM(const char *filename);
double get_log_prob(const std::vector<std::string>& words);
double get_log_prob(const std::vector<std::string> &words);
std::string vec2str(const std::vector<int>& input);
std::string vec2str(const std::vector<int> &input);
private:
void* _language_model;
void *_language_model;
bool _is_character_based;
size_t _max_order;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册