提交 f842c79a 编写于 作者: H Hui Zhang

format code

上级 e969a8ec
...@@ -36,169 +36,177 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder( ...@@ -36,169 +36,177 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
double cutoff_prob, double cutoff_prob,
size_t cutoff_top_n, size_t cutoff_top_n,
Scorer *ext_scorer) { Scorer *ext_scorer) {
// dimension check // dimension check
size_t num_time_steps = probs_seq.size(); size_t num_time_steps = probs_seq.size();
for (size_t i = 0; i < num_time_steps; ++i) { for (size_t i = 0; i < num_time_steps; ++i) {
VALID_CHECK_EQ(probs_seq[i].size(), VALID_CHECK_EQ(probs_seq[i].size(),
// vocabulary.size() + 1, // vocabulary.size() + 1,
vocabulary.size(), vocabulary.size(),
"The shape of probs_seq does not match with " "The shape of probs_seq does not match with "
"the shape of the vocabulary"); "the shape of the vocabulary");
}
// assign blank id
//size_t blank_id = vocabulary.size();
size_t blank_id = 0;
// assign space id
auto it = std::find(vocabulary.begin(), vocabulary.end(), " ");
int space_id = it - vocabulary.begin();
// if no space in vocabulary
if ((size_t)space_id >= vocabulary.size()) {
space_id = -2;
}
// init prefixes' root
PathTrie root;
root.score = root.log_prob_b_prev = 0.0;
std::vector<PathTrie *> prefixes;
prefixes.push_back(&root);
if (ext_scorer != nullptr && !ext_scorer->is_character_based()) {
auto fst_dict = static_cast<fst::StdVectorFst *>(ext_scorer->dictionary);
fst::StdVectorFst *dict_ptr = fst_dict->Copy(true);
root.set_dictionary(dict_ptr);
auto matcher = std::make_shared<FSTMATCH>(*dict_ptr, fst::MATCH_INPUT);
root.set_matcher(matcher);
}
// prefix search over time
for (size_t time_step = 0; time_step < num_time_steps; ++time_step) {
auto &prob = probs_seq[time_step];
float min_cutoff = -NUM_FLT_INF;
bool full_beam = false;
if (ext_scorer != nullptr) {
size_t num_prefixes = std::min(prefixes.size(), beam_size);
std::sort(
prefixes.begin(), prefixes.begin() + num_prefixes, prefix_compare);
min_cutoff = prefixes[num_prefixes - 1]->score +
std::log(prob[blank_id]) - std::max(0.0, ext_scorer->beta);
full_beam = (num_prefixes == beam_size);
} }
std::vector<std::pair<size_t, float>> log_prob_idx = // assign blank id
get_pruned_log_probs(prob, cutoff_prob, cutoff_top_n); // size_t blank_id = vocabulary.size();
// loop over chars size_t blank_id = 0;
for (size_t index = 0; index < log_prob_idx.size(); index++) {
auto c = log_prob_idx[index].first; // assign space id
auto log_prob_c = log_prob_idx[index].second; auto it = std::find(vocabulary.begin(), vocabulary.end(), " ");
int space_id = it - vocabulary.begin();
for (size_t i = 0; i < prefixes.size() && i < beam_size; ++i) { // if no space in vocabulary
auto prefix = prefixes[i]; if ((size_t)space_id >= vocabulary.size()) {
if (full_beam && log_prob_c + prefix->score < min_cutoff) { space_id = -2;
break; }
}
// blank // init prefixes' root
if (c == blank_id) { PathTrie root;
prefix->log_prob_b_cur = root.score = root.log_prob_b_prev = 0.0;
log_sum_exp(prefix->log_prob_b_cur, log_prob_c + prefix->score); std::vector<PathTrie *> prefixes;
continue; prefixes.push_back(&root);
if (ext_scorer != nullptr && !ext_scorer->is_character_based()) {
auto fst_dict =
static_cast<fst::StdVectorFst *>(ext_scorer->dictionary);
fst::StdVectorFst *dict_ptr = fst_dict->Copy(true);
root.set_dictionary(dict_ptr);
auto matcher = std::make_shared<FSTMATCH>(*dict_ptr, fst::MATCH_INPUT);
root.set_matcher(matcher);
}
// prefix search over time
for (size_t time_step = 0; time_step < num_time_steps; ++time_step) {
auto &prob = probs_seq[time_step];
float min_cutoff = -NUM_FLT_INF;
bool full_beam = false;
if (ext_scorer != nullptr) {
size_t num_prefixes = std::min(prefixes.size(), beam_size);
std::sort(prefixes.begin(),
prefixes.begin() + num_prefixes,
prefix_compare);
min_cutoff = prefixes[num_prefixes - 1]->score +
std::log(prob[blank_id]) -
std::max(0.0, ext_scorer->beta);
full_beam = (num_prefixes == beam_size);
} }
// repeated character
if (c == prefix->character) { std::vector<std::pair<size_t, float>> log_prob_idx =
prefix->log_prob_nb_cur = log_sum_exp( get_pruned_log_probs(prob, cutoff_prob, cutoff_top_n);
prefix->log_prob_nb_cur, log_prob_c + prefix->log_prob_nb_prev); // loop over chars
for (size_t index = 0; index < log_prob_idx.size(); index++) {
auto c = log_prob_idx[index].first;
auto log_prob_c = log_prob_idx[index].second;
for (size_t i = 0; i < prefixes.size() && i < beam_size; ++i) {
auto prefix = prefixes[i];
if (full_beam && log_prob_c + prefix->score < min_cutoff) {
break;
}
// blank
if (c == blank_id) {
prefix->log_prob_b_cur = log_sum_exp(
prefix->log_prob_b_cur, log_prob_c + prefix->score);
continue;
}
// repeated character
if (c == prefix->character) {
prefix->log_prob_nb_cur =
log_sum_exp(prefix->log_prob_nb_cur,
log_prob_c + prefix->log_prob_nb_prev);
}
// get new prefix
auto prefix_new = prefix->get_path_trie(c);
if (prefix_new != nullptr) {
float log_p = -NUM_FLT_INF;
if (c == prefix->character &&
prefix->log_prob_b_prev > -NUM_FLT_INF) {
log_p = log_prob_c + prefix->log_prob_b_prev;
} else if (c != prefix->character) {
log_p = log_prob_c + prefix->score;
}
// language model scoring
if (ext_scorer != nullptr &&
(c == space_id || ext_scorer->is_character_based())) {
PathTrie *prefix_to_score = nullptr;
// skip scoring the space
if (ext_scorer->is_character_based()) {
prefix_to_score = prefix_new;
} else {
prefix_to_score = prefix;
}
float score = 0.0;
std::vector<std::string> ngram;
ngram = ext_scorer->make_ngram(prefix_to_score);
score = ext_scorer->get_log_cond_prob(ngram) *
ext_scorer->alpha;
log_p += score;
log_p += ext_scorer->beta;
}
prefix_new->log_prob_nb_cur =
log_sum_exp(prefix_new->log_prob_nb_cur, log_p);
}
} // end of loop over prefix
} // end of loop over vocabulary
prefixes.clear();
// update log probs
root.iterate_to_vec(prefixes);
// only preserve top beam_size prefixes
if (prefixes.size() >= beam_size) {
std::nth_element(prefixes.begin(),
prefixes.begin() + beam_size,
prefixes.end(),
prefix_compare);
for (size_t i = beam_size; i < prefixes.size(); ++i) {
prefixes[i]->remove();
}
} }
// get new prefix } // end of loop over time
auto prefix_new = prefix->get_path_trie(c);
// score the last word of each prefix that doesn't end with space
if (prefix_new != nullptr) { if (ext_scorer != nullptr && !ext_scorer->is_character_based()) {
float log_p = -NUM_FLT_INF; for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
auto prefix = prefixes[i];
if (c == prefix->character && if (!prefix->is_empty() && prefix->character != space_id) {
prefix->log_prob_b_prev > -NUM_FLT_INF) { float score = 0.0;
log_p = log_prob_c + prefix->log_prob_b_prev; std::vector<std::string> ngram = ext_scorer->make_ngram(prefix);
} else if (c != prefix->character) { score =
log_p = log_prob_c + prefix->score; ext_scorer->get_log_cond_prob(ngram) * ext_scorer->alpha;
} score += ext_scorer->beta;
prefix->score += score;
// language model scoring
if (ext_scorer != nullptr &&
(c == space_id || ext_scorer->is_character_based())) {
PathTrie *prefix_to_score = nullptr;
// skip scoring the space
if (ext_scorer->is_character_based()) {
prefix_to_score = prefix_new;
} else {
prefix_to_score = prefix;
} }
float score = 0.0;
std::vector<std::string> ngram;
ngram = ext_scorer->make_ngram(prefix_to_score);
score = ext_scorer->get_log_cond_prob(ngram) * ext_scorer->alpha;
log_p += score;
log_p += ext_scorer->beta;
}
prefix_new->log_prob_nb_cur =
log_sum_exp(prefix_new->log_prob_nb_cur, log_p);
} }
} // end of loop over prefix
} // end of loop over vocabulary
prefixes.clear();
// update log probs
root.iterate_to_vec(prefixes);
// only preserve top beam_size prefixes
if (prefixes.size() >= beam_size) {
std::nth_element(prefixes.begin(),
prefixes.begin() + beam_size,
prefixes.end(),
prefix_compare);
for (size_t i = beam_size; i < prefixes.size(); ++i) {
prefixes[i]->remove();
}
} }
} // end of loop over time
// score the last word of each prefix that doesn't end with space size_t num_prefixes = std::min(prefixes.size(), beam_size);
if (ext_scorer != nullptr && !ext_scorer->is_character_based()) { std::sort(
prefixes.begin(), prefixes.begin() + num_prefixes, prefix_compare);
// compute aproximate ctc score as the return score, without affecting the
// return order of decoding result. To delete when decoder gets stable.
for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) { for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
auto prefix = prefixes[i]; double approx_ctc = prefixes[i]->score;
if (!prefix->is_empty() && prefix->character != space_id) { if (ext_scorer != nullptr) {
float score = 0.0; std::vector<int> output;
std::vector<std::string> ngram = ext_scorer->make_ngram(prefix); prefixes[i]->get_path_vec(output);
score = ext_scorer->get_log_cond_prob(ngram) * ext_scorer->alpha; auto prefix_length = output.size();
score += ext_scorer->beta; auto words = ext_scorer->split_labels(output);
prefix->score += score; // remove word insert
} approx_ctc = approx_ctc - prefix_length * ext_scorer->beta;
} // remove language model weight:
} approx_ctc -=
(ext_scorer->get_sent_log_prob(words)) * ext_scorer->alpha;
size_t num_prefixes = std::min(prefixes.size(), beam_size); }
std::sort(prefixes.begin(), prefixes.begin() + num_prefixes, prefix_compare); prefixes[i]->approx_ctc = approx_ctc;
// compute aproximate ctc score as the return score, without affecting the
// return order of decoding result. To delete when decoder gets stable.
for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
double approx_ctc = prefixes[i]->score;
if (ext_scorer != nullptr) {
std::vector<int> output;
prefixes[i]->get_path_vec(output);
auto prefix_length = output.size();
auto words = ext_scorer->split_labels(output);
// remove word insert
approx_ctc = approx_ctc - prefix_length * ext_scorer->beta;
// remove language model weight:
approx_ctc -= (ext_scorer->get_sent_log_prob(words)) * ext_scorer->alpha;
} }
prefixes[i]->approx_ctc = approx_ctc;
}
return get_beam_search_result(prefixes, vocabulary, beam_size); return get_beam_search_result(prefixes, vocabulary, beam_size);
} }
...@@ -211,28 +219,28 @@ ctc_beam_search_decoder_batch( ...@@ -211,28 +219,28 @@ ctc_beam_search_decoder_batch(
double cutoff_prob, double cutoff_prob,
size_t cutoff_top_n, size_t cutoff_top_n,
Scorer *ext_scorer) { Scorer *ext_scorer) {
VALID_CHECK_GT(num_processes, 0, "num_processes must be nonnegative!"); VALID_CHECK_GT(num_processes, 0, "num_processes must be nonnegative!");
// thread pool // thread pool
ThreadPool pool(num_processes); ThreadPool pool(num_processes);
// number of samples // number of samples
size_t batch_size = probs_split.size(); size_t batch_size = probs_split.size();
// enqueue the tasks of decoding // enqueue the tasks of decoding
std::vector<std::future<std::vector<std::pair<double, std::string>>>> res; std::vector<std::future<std::vector<std::pair<double, std::string>>>> res;
for (size_t i = 0; i < batch_size; ++i) { for (size_t i = 0; i < batch_size; ++i) {
res.emplace_back(pool.enqueue(ctc_beam_search_decoder, res.emplace_back(pool.enqueue(ctc_beam_search_decoder,
probs_split[i], probs_split[i],
vocabulary, vocabulary,
beam_size, beam_size,
cutoff_prob, cutoff_prob,
cutoff_top_n, cutoff_top_n,
ext_scorer)); ext_scorer));
} }
// get decoding results // get decoding results
std::vector<std::vector<std::pair<double, std::string>>> batch_results; std::vector<std::vector<std::pair<double, std::string>>> batch_results;
for (size_t i = 0; i < batch_size; ++i) { for (size_t i = 0; i < batch_size; ++i) {
batch_results.emplace_back(res[i].get()); batch_results.emplace_back(res[i].get());
} }
return batch_results; return batch_results;
} }
...@@ -18,42 +18,42 @@ ...@@ -18,42 +18,42 @@
std::string ctc_greedy_decoder( std::string ctc_greedy_decoder(
const std::vector<std::vector<double>> &probs_seq, const std::vector<std::vector<double>> &probs_seq,
const std::vector<std::string> &vocabulary) { const std::vector<std::string> &vocabulary) {
// dimension check // dimension check
size_t num_time_steps = probs_seq.size(); size_t num_time_steps = probs_seq.size();
for (size_t i = 0; i < num_time_steps; ++i) { for (size_t i = 0; i < num_time_steps; ++i) {
VALID_CHECK_EQ(probs_seq[i].size(), VALID_CHECK_EQ(probs_seq[i].size(),
vocabulary.size() + 1, vocabulary.size() + 1,
"The shape of probs_seq does not match with " "The shape of probs_seq does not match with "
"the shape of the vocabulary"); "the shape of the vocabulary");
} }
size_t blank_id = vocabulary.size(); size_t blank_id = vocabulary.size();
std::vector<size_t> max_idx_vec(num_time_steps, 0); std::vector<size_t> max_idx_vec(num_time_steps, 0);
std::vector<size_t> idx_vec; std::vector<size_t> idx_vec;
for (size_t i = 0; i < num_time_steps; ++i) { for (size_t i = 0; i < num_time_steps; ++i) {
double max_prob = 0.0; double max_prob = 0.0;
size_t max_idx = 0; size_t max_idx = 0;
const std::vector<double> &probs_step = probs_seq[i]; const std::vector<double> &probs_step = probs_seq[i];
for (size_t j = 0; j < probs_step.size(); ++j) { for (size_t j = 0; j < probs_step.size(); ++j) {
if (max_prob < probs_step[j]) { if (max_prob < probs_step[j]) {
max_idx = j; max_idx = j;
max_prob = probs_step[j]; max_prob = probs_step[j];
} }
} }
// id with maximum probability in current time step // id with maximum probability in current time step
max_idx_vec[i] = max_idx; max_idx_vec[i] = max_idx;
// deduplicate // deduplicate
if ((i == 0) || ((i > 0) && max_idx_vec[i] != max_idx_vec[i - 1])) { if ((i == 0) || ((i > 0) && max_idx_vec[i] != max_idx_vec[i - 1])) {
idx_vec.push_back(max_idx_vec[i]); idx_vec.push_back(max_idx_vec[i]);
}
} }
}
std::string best_path_result; std::string best_path_result;
for (size_t i = 0; i < idx_vec.size(); ++i) { for (size_t i = 0; i < idx_vec.size(); ++i) {
if (idx_vec[i] != blank_id) { if (idx_vec[i] != blank_id) {
best_path_result += vocabulary[idx_vec[i]]; best_path_result += vocabulary[idx_vec[i]];
}
} }
} return best_path_result;
return best_path_result;
} }
...@@ -22,33 +22,35 @@ std::vector<std::pair<size_t, float>> get_pruned_log_probs( ...@@ -22,33 +22,35 @@ std::vector<std::pair<size_t, float>> get_pruned_log_probs(
const std::vector<double> &prob_step, const std::vector<double> &prob_step,
double cutoff_prob, double cutoff_prob,
size_t cutoff_top_n) { size_t cutoff_top_n) {
std::vector<std::pair<int, double>> prob_idx; std::vector<std::pair<int, double>> prob_idx;
for (size_t i = 0; i < prob_step.size(); ++i) { for (size_t i = 0; i < prob_step.size(); ++i) {
prob_idx.push_back(std::pair<int, double>(i, prob_step[i])); prob_idx.push_back(std::pair<int, double>(i, prob_step[i]));
}
// pruning of vacobulary
size_t cutoff_len = prob_step.size();
if (cutoff_prob < 1.0 || cutoff_top_n < cutoff_len) {
std::sort(
prob_idx.begin(), prob_idx.end(), pair_comp_second_rev<int, double>);
if (cutoff_prob < 1.0) {
double cum_prob = 0.0;
cutoff_len = 0;
for (size_t i = 0; i < prob_idx.size(); ++i) {
cum_prob += prob_idx[i].second;
cutoff_len += 1;
if (cum_prob >= cutoff_prob || cutoff_len >= cutoff_top_n) break;
}
} }
prob_idx = std::vector<std::pair<int, double>>( // pruning of vacobulary
prob_idx.begin(), prob_idx.begin() + cutoff_len); size_t cutoff_len = prob_step.size();
} if (cutoff_prob < 1.0 || cutoff_top_n < cutoff_len) {
std::vector<std::pair<size_t, float>> log_prob_idx; std::sort(prob_idx.begin(),
for (size_t i = 0; i < cutoff_len; ++i) { prob_idx.end(),
log_prob_idx.push_back(std::pair<int, float>( pair_comp_second_rev<int, double>);
prob_idx[i].first, log(prob_idx[i].second + NUM_FLT_MIN))); if (cutoff_prob < 1.0) {
} double cum_prob = 0.0;
return log_prob_idx; cutoff_len = 0;
for (size_t i = 0; i < prob_idx.size(); ++i) {
cum_prob += prob_idx[i].second;
cutoff_len += 1;
if (cum_prob >= cutoff_prob || cutoff_len >= cutoff_top_n)
break;
}
}
prob_idx = std::vector<std::pair<int, double>>(
prob_idx.begin(), prob_idx.begin() + cutoff_len);
}
std::vector<std::pair<size_t, float>> log_prob_idx;
for (size_t i = 0; i < cutoff_len; ++i) {
log_prob_idx.push_back(std::pair<int, float>(
prob_idx[i].first, log(prob_idx[i].second + NUM_FLT_MIN)));
}
return log_prob_idx;
} }
...@@ -56,106 +58,106 @@ std::vector<std::pair<double, std::string>> get_beam_search_result( ...@@ -56,106 +58,106 @@ std::vector<std::pair<double, std::string>> get_beam_search_result(
const std::vector<PathTrie *> &prefixes, const std::vector<PathTrie *> &prefixes,
const std::vector<std::string> &vocabulary, const std::vector<std::string> &vocabulary,
size_t beam_size) { size_t beam_size) {
// allow for the post processing // allow for the post processing
std::vector<PathTrie *> space_prefixes; std::vector<PathTrie *> space_prefixes;
if (space_prefixes.empty()) { if (space_prefixes.empty()) {
for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) { for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
space_prefixes.push_back(prefixes[i]); space_prefixes.push_back(prefixes[i]);
}
} }
}
std::sort(space_prefixes.begin(), space_prefixes.end(), prefix_compare);
std::sort(space_prefixes.begin(), space_prefixes.end(), prefix_compare); std::vector<std::pair<double, std::string>> output_vecs;
std::vector<std::pair<double, std::string>> output_vecs; for (size_t i = 0; i < beam_size && i < space_prefixes.size(); ++i) {
for (size_t i = 0; i < beam_size && i < space_prefixes.size(); ++i) { std::vector<int> output;
std::vector<int> output; space_prefixes[i]->get_path_vec(output);
space_prefixes[i]->get_path_vec(output); // convert index to string
// convert index to string std::string output_str;
std::string output_str; for (size_t j = 0; j < output.size(); j++) {
for (size_t j = 0; j < output.size(); j++) { output_str += vocabulary[output[j]];
output_str += vocabulary[output[j]]; }
std::pair<double, std::string> output_pair(
-space_prefixes[i]->approx_ctc, output_str);
output_vecs.emplace_back(output_pair);
} }
std::pair<double, std::string> output_pair(-space_prefixes[i]->approx_ctc,
output_str);
output_vecs.emplace_back(output_pair);
}
return output_vecs; return output_vecs;
} }
size_t get_utf8_str_len(const std::string &str) { size_t get_utf8_str_len(const std::string &str) {
size_t str_len = 0; size_t str_len = 0;
for (char c : str) { for (char c : str) {
str_len += ((c & 0xc0) != 0x80); str_len += ((c & 0xc0) != 0x80);
} }
return str_len; return str_len;
} }
std::vector<std::string> split_utf8_str(const std::string &str) { std::vector<std::string> split_utf8_str(const std::string &str) {
std::vector<std::string> result; std::vector<std::string> result;
std::string out_str; std::string out_str;
for (char c : str) { for (char c : str) {
if ((c & 0xc0) != 0x80) // new UTF-8 character if ((c & 0xc0) != 0x80) // new UTF-8 character
{ {
if (!out_str.empty()) { if (!out_str.empty()) {
result.push_back(out_str); result.push_back(out_str);
out_str.clear(); out_str.clear();
} }
}
out_str.append(1, c);
} }
result.push_back(out_str);
out_str.append(1, c); return result;
}
result.push_back(out_str);
return result;
} }
std::vector<std::string> split_str(const std::string &s, std::vector<std::string> split_str(const std::string &s,
const std::string &delim) { const std::string &delim) {
std::vector<std::string> result; std::vector<std::string> result;
std::size_t start = 0, delim_len = delim.size(); std::size_t start = 0, delim_len = delim.size();
while (true) { while (true) {
std::size_t end = s.find(delim, start); std::size_t end = s.find(delim, start);
if (end == std::string::npos) { if (end == std::string::npos) {
if (start < s.size()) { if (start < s.size()) {
result.push_back(s.substr(start)); result.push_back(s.substr(start));
} }
break; break;
} }
if (end > start) { if (end > start) {
result.push_back(s.substr(start, end - start)); result.push_back(s.substr(start, end - start));
}
start = end + delim_len;
} }
start = end + delim_len; return result;
}
return result;
} }
bool prefix_compare(const PathTrie *x, const PathTrie *y) { bool prefix_compare(const PathTrie *x, const PathTrie *y) {
if (x->score == y->score) { if (x->score == y->score) {
if (x->character == y->character) { if (x->character == y->character) {
return false; return false;
} else {
return (x->character < y->character);
}
} else { } else {
return (x->character < y->character); return x->score > y->score;
} }
} else {
return x->score > y->score;
}
} }
void add_word_to_fst(const std::vector<int> &word, void add_word_to_fst(const std::vector<int> &word,
fst::StdVectorFst *dictionary) { fst::StdVectorFst *dictionary) {
if (dictionary->NumStates() == 0) { if (dictionary->NumStates() == 0) {
fst::StdVectorFst::StateId start = dictionary->AddState(); fst::StdVectorFst::StateId start = dictionary->AddState();
assert(start == 0); assert(start == 0);
dictionary->SetStart(start); dictionary->SetStart(start);
} }
fst::StdVectorFst::StateId src = dictionary->Start(); fst::StdVectorFst::StateId src = dictionary->Start();
fst::StdVectorFst::StateId dst; fst::StdVectorFst::StateId dst;
for (auto c : word) { for (auto c : word) {
dst = dictionary->AddState(); dst = dictionary->AddState();
dictionary->AddArc(src, fst::StdArc(c, c, 0, dst)); dictionary->AddArc(src, fst::StdArc(c, c, 0, dst));
src = dst; src = dst;
} }
dictionary->SetFinal(dst, fst::StdArc::Weight::One()); dictionary->SetFinal(dst, fst::StdArc::Weight::One());
} }
bool add_word_to_dictionary( bool add_word_to_dictionary(
...@@ -164,27 +166,27 @@ bool add_word_to_dictionary( ...@@ -164,27 +166,27 @@ bool add_word_to_dictionary(
bool add_space, bool add_space,
int SPACE_ID, int SPACE_ID,
fst::StdVectorFst *dictionary) { fst::StdVectorFst *dictionary) {
auto characters = split_utf8_str(word); auto characters = split_utf8_str(word);
std::vector<int> int_word; std::vector<int> int_word;
for (auto &c : characters) { for (auto &c : characters) {
if (c == " ") { if (c == " ") {
int_word.push_back(SPACE_ID); int_word.push_back(SPACE_ID);
} else { } else {
auto int_c = char_map.find(c); auto int_c = char_map.find(c);
if (int_c != char_map.end()) { if (int_c != char_map.end()) {
int_word.push_back(int_c->second); int_word.push_back(int_c->second);
} else { } else {
return false; // return without adding return false; // return without adding
} }
}
} }
}
if (add_space) { if (add_space) {
int_word.push_back(SPACE_ID); int_word.push_back(SPACE_ID);
} }
add_word_to_fst(int_word, dictionary); add_word_to_fst(int_word, dictionary);
return true; // return with successful adding return true; // return with successful adding
} }
...@@ -25,14 +25,14 @@ const float NUM_FLT_MIN = std::numeric_limits<float>::min(); ...@@ -25,14 +25,14 @@ const float NUM_FLT_MIN = std::numeric_limits<float>::min();
// inline function for validation check // inline function for validation check
inline void check( inline void check(
bool x, const char *expr, const char *file, int line, const char *err) { bool x, const char *expr, const char *file, int line, const char *err) {
if (!x) { if (!x) {
std::cout << "[" << file << ":" << line << "] "; std::cout << "[" << file << ":" << line << "] ";
LOG(FATAL) << "\"" << expr << "\" check failed. " << err; LOG(FATAL) << "\"" << expr << "\" check failed. " << err;
} }
} }
#define VALID_CHECK(x, info) \ #define VALID_CHECK(x, info) \
check(static_cast<bool>(x), #x, __FILE__, __LINE__, info) check(static_cast<bool>(x), #x, __FILE__, __LINE__, info)
#define VALID_CHECK_EQ(x, y, info) VALID_CHECK((x) == (y), info) #define VALID_CHECK_EQ(x, y, info) VALID_CHECK((x) == (y), info)
#define VALID_CHECK_GT(x, y, info) VALID_CHECK((x) > (y), info) #define VALID_CHECK_GT(x, y, info) VALID_CHECK((x) > (y), info)
#define VALID_CHECK_LT(x, y, info) VALID_CHECK((x) < (y), info) #define VALID_CHECK_LT(x, y, info) VALID_CHECK((x) < (y), info)
...@@ -42,24 +42,24 @@ inline void check( ...@@ -42,24 +42,24 @@ inline void check(
template <typename T1, typename T2> template <typename T1, typename T2>
bool pair_comp_first_rev(const std::pair<T1, T2> &a, bool pair_comp_first_rev(const std::pair<T1, T2> &a,
const std::pair<T1, T2> &b) { const std::pair<T1, T2> &b) {
return a.first > b.first; return a.first > b.first;
} }
// Function template for comparing two pairs // Function template for comparing two pairs
template <typename T1, typename T2> template <typename T1, typename T2>
bool pair_comp_second_rev(const std::pair<T1, T2> &a, bool pair_comp_second_rev(const std::pair<T1, T2> &a,
const std::pair<T1, T2> &b) { const std::pair<T1, T2> &b) {
return a.second > b.second; return a.second > b.second;
} }
// Return the sum of two probabilities in log scale // Return the sum of two probabilities in log scale
template <typename T> template <typename T>
T log_sum_exp(const T &x, const T &y) { T log_sum_exp(const T &x, const T &y) {
static T num_min = -std::numeric_limits<T>::max(); static T num_min = -std::numeric_limits<T>::max();
if (x <= num_min) return y; if (x <= num_min) return y;
if (y <= num_min) return x; if (y <= num_min) return x;
T xmax = std::max(x, y); T xmax = std::max(x, y);
return std::log(std::exp(x - xmax) + std::exp(y - xmax)) + xmax; return std::log(std::exp(x - xmax) + std::exp(y - xmax)) + xmax;
} }
// Get pruned probability vector for each time step's beam search // Get pruned probability vector for each time step's beam search
......
...@@ -23,140 +23,141 @@ ...@@ -23,140 +23,141 @@
#include "decoder_utils.h" #include "decoder_utils.h"
PathTrie::PathTrie() { PathTrie::PathTrie() {
log_prob_b_prev = -NUM_FLT_INF; log_prob_b_prev = -NUM_FLT_INF;
log_prob_nb_prev = -NUM_FLT_INF; log_prob_nb_prev = -NUM_FLT_INF;
log_prob_b_cur = -NUM_FLT_INF; log_prob_b_cur = -NUM_FLT_INF;
log_prob_nb_cur = -NUM_FLT_INF; log_prob_nb_cur = -NUM_FLT_INF;
score = -NUM_FLT_INF; score = -NUM_FLT_INF;
ROOT_ = -1; ROOT_ = -1;
character = ROOT_; character = ROOT_;
exists_ = true; exists_ = true;
parent = nullptr; parent = nullptr;
dictionary_ = nullptr; dictionary_ = nullptr;
dictionary_state_ = 0; dictionary_state_ = 0;
has_dictionary_ = false; has_dictionary_ = false;
matcher_ = nullptr; matcher_ = nullptr;
} }
PathTrie::~PathTrie() { PathTrie::~PathTrie() {
for (auto child : children_) { for (auto child : children_) {
delete child.second; delete child.second;
} }
} }
PathTrie* PathTrie::get_path_trie(int new_char, bool reset) { PathTrie* PathTrie::get_path_trie(int new_char, bool reset) {
auto child = children_.begin(); auto child = children_.begin();
for (child = children_.begin(); child != children_.end(); ++child) { for (child = children_.begin(); child != children_.end(); ++child) {
if (child->first == new_char) { if (child->first == new_char) {
break; break;
} }
}
if (child != children_.end()) {
if (!child->second->exists_) {
child->second->exists_ = true;
child->second->log_prob_b_prev = -NUM_FLT_INF;
child->second->log_prob_nb_prev = -NUM_FLT_INF;
child->second->log_prob_b_cur = -NUM_FLT_INF;
child->second->log_prob_nb_cur = -NUM_FLT_INF;
} }
return (child->second); if (child != children_.end()) {
} else { if (!child->second->exists_) {
if (has_dictionary_) { child->second->exists_ = true;
matcher_->SetState(dictionary_state_); child->second->log_prob_b_prev = -NUM_FLT_INF;
bool found = matcher_->Find(new_char + 1); child->second->log_prob_nb_prev = -NUM_FLT_INF;
if (!found) { child->second->log_prob_b_cur = -NUM_FLT_INF;
// Adding this character causes word outside dictionary child->second->log_prob_nb_cur = -NUM_FLT_INF;
auto FSTZERO = fst::TropicalWeight::Zero();
auto final_weight = dictionary_->Final(dictionary_state_);
bool is_final = (final_weight != FSTZERO);
if (is_final && reset) {
dictionary_state_ = dictionary_->Start();
} }
return nullptr; return (child->second);
} else {
PathTrie* new_path = new PathTrie;
new_path->character = new_char;
new_path->parent = this;
new_path->dictionary_ = dictionary_;
new_path->dictionary_state_ = matcher_->Value().nextstate;
new_path->has_dictionary_ = true;
new_path->matcher_ = matcher_;
children_.push_back(std::make_pair(new_char, new_path));
return new_path;
}
} else { } else {
PathTrie* new_path = new PathTrie; if (has_dictionary_) {
new_path->character = new_char; matcher_->SetState(dictionary_state_);
new_path->parent = this; bool found = matcher_->Find(new_char + 1);
children_.push_back(std::make_pair(new_char, new_path)); if (!found) {
return new_path; // Adding this character causes word outside dictionary
auto FSTZERO = fst::TropicalWeight::Zero();
auto final_weight = dictionary_->Final(dictionary_state_);
bool is_final = (final_weight != FSTZERO);
if (is_final && reset) {
dictionary_state_ = dictionary_->Start();
}
return nullptr;
} else {
PathTrie* new_path = new PathTrie;
new_path->character = new_char;
new_path->parent = this;
new_path->dictionary_ = dictionary_;
new_path->dictionary_state_ = matcher_->Value().nextstate;
new_path->has_dictionary_ = true;
new_path->matcher_ = matcher_;
children_.push_back(std::make_pair(new_char, new_path));
return new_path;
}
} else {
PathTrie* new_path = new PathTrie;
new_path->character = new_char;
new_path->parent = this;
children_.push_back(std::make_pair(new_char, new_path));
return new_path;
}
} }
}
} }
PathTrie* PathTrie::get_path_vec(std::vector<int>& output) { PathTrie* PathTrie::get_path_vec(std::vector<int>& output) {
return get_path_vec(output, ROOT_); return get_path_vec(output, ROOT_);
} }
PathTrie* PathTrie::get_path_vec(std::vector<int>& output, PathTrie* PathTrie::get_path_vec(std::vector<int>& output,
int stop, int stop,
size_t max_steps) { size_t max_steps) {
if (character == stop || character == ROOT_ || output.size() == max_steps) { if (character == stop || character == ROOT_ || output.size() == max_steps) {
std::reverse(output.begin(), output.end()); std::reverse(output.begin(), output.end());
return this; return this;
} else { } else {
output.push_back(character); output.push_back(character);
return parent->get_path_vec(output, stop, max_steps); return parent->get_path_vec(output, stop, max_steps);
} }
} }
void PathTrie::iterate_to_vec(std::vector<PathTrie*>& output) { void PathTrie::iterate_to_vec(std::vector<PathTrie*>& output) {
if (exists_) { if (exists_) {
log_prob_b_prev = log_prob_b_cur; log_prob_b_prev = log_prob_b_cur;
log_prob_nb_prev = log_prob_nb_cur; log_prob_nb_prev = log_prob_nb_cur;
log_prob_b_cur = -NUM_FLT_INF; log_prob_b_cur = -NUM_FLT_INF;
log_prob_nb_cur = -NUM_FLT_INF; log_prob_nb_cur = -NUM_FLT_INF;
score = log_sum_exp(log_prob_b_prev, log_prob_nb_prev); score = log_sum_exp(log_prob_b_prev, log_prob_nb_prev);
output.push_back(this); output.push_back(this);
} }
for (auto child : children_) { for (auto child : children_) {
child.second->iterate_to_vec(output); child.second->iterate_to_vec(output);
} }
} }
void PathTrie::remove() { void PathTrie::remove() {
exists_ = false; exists_ = false;
if (children_.size() == 0) { if (children_.size() == 0) {
auto child = parent->children_.begin(); auto child = parent->children_.begin();
for (child = parent->children_.begin(); child != parent->children_.end(); for (child = parent->children_.begin();
++child) { child != parent->children_.end();
if (child->first == character) { ++child) {
parent->children_.erase(child); if (child->first == character) {
break; parent->children_.erase(child);
} break;
} }
}
if (parent->children_.size() == 0 && !parent->exists_) { if (parent->children_.size() == 0 && !parent->exists_) {
parent->remove(); parent->remove();
} }
delete this; delete this;
} }
} }
void PathTrie::set_dictionary(fst::StdVectorFst* dictionary) { void PathTrie::set_dictionary(fst::StdVectorFst* dictionary) {
dictionary_ = dictionary; dictionary_ = dictionary;
dictionary_state_ = dictionary->Start(); dictionary_state_ = dictionary->Start();
has_dictionary_ = true; has_dictionary_ = true;
} }
using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>; using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;
void PathTrie::set_matcher(std::shared_ptr<FSTMATCH> matcher) { void PathTrie::set_matcher(std::shared_ptr<FSTMATCH> matcher) {
matcher_ = matcher; matcher_ = matcher;
} }
...@@ -27,55 +27,56 @@ ...@@ -27,55 +27,56 @@
* finite-state transducer for spelling correction. * finite-state transducer for spelling correction.
*/ */
class PathTrie {
public:
    PathTrie();
    ~PathTrie();

    // Get (or create, when `reset` is true) the child prefix obtained by
    // appending `new_char` to this prefix.
    PathTrie* get_path_trie(int new_char, bool reset = true);

    // Get the prefix, as character indices, from the root to this node.
    PathTrie* get_path_vec(std::vector<int>& output);

    // Get the prefix, as character indices, from some stop node to this
    // node, walking back at most `max_steps` characters.
    PathTrie* get_path_vec(
        std::vector<int>& output,
        int stop,
        size_t max_steps = std::numeric_limits<size_t>::max());

    // Roll the per-timestep log probabilities forward and collect all live
    // prefix nodes into `output`.
    void iterate_to_vec(std::vector<PathTrie*>& output);

    // Set the spelling-correction dictionary FST for this node.
    void set_dictionary(fst::StdVectorFst* dictionary);

    void set_matcher(std::shared_ptr<fst::SortedMatcher<fst::StdVectorFst>>);

    // True only for the root sentinel node.
    bool is_empty() { return ROOT_ == character; }

    // Remove the current path from the trie, pruning dead ancestors.
    void remove();

    float log_prob_b_prev;   // log P(prefix ends in blank), previous step
    float log_prob_nb_prev;  // log P(prefix ends in non-blank), previous step
    float log_prob_b_cur;    // log P(prefix ends in blank), current step
    float log_prob_nb_cur;   // log P(prefix ends in non-blank), current step
    float score;             // combined log probability of this prefix
    float approx_ctc;
    int character;           // character index of this node's edge
    PathTrie* parent;

private:
    int ROOT_;               // sentinel character value marking the root
    bool exists_;            // whether this node is a live prefix
    bool has_dictionary_;

    std::vector<std::pair<int, PathTrie*>> children_;

    // pointer to dictionary of FST (non-owning)
    fst::StdVectorFst* dictionary_;
    fst::StdVectorFst::StateId dictionary_state_;
    // matcher used to find arcs in the dictionary FST
    std::shared_ptr<fst::SortedMatcher<fst::StdVectorFst>> matcher_;
};

#endif  // PATH_TRIE_H
...@@ -31,214 +31,214 @@ Scorer::Scorer(double alpha, ...@@ -31,214 +31,214 @@ Scorer::Scorer(double alpha,
double beta, double beta,
const std::string& lm_path, const std::string& lm_path,
const std::vector<std::string>& vocab_list) { const std::vector<std::string>& vocab_list) {
this->alpha = alpha; this->alpha = alpha;
this->beta = beta; this->beta = beta;
dictionary = nullptr; dictionary = nullptr;
is_character_based_ = true; is_character_based_ = true;
language_model_ = nullptr; language_model_ = nullptr;
max_order_ = 0; max_order_ = 0;
dict_size_ = 0; dict_size_ = 0;
SPACE_ID_ = -1; SPACE_ID_ = -1;
setup(lm_path, vocab_list); setup(lm_path, vocab_list);
} }
Scorer::~Scorer() {
    // Both members are stored as void*, so cast back to the concrete types
    // before deleting.
    if (language_model_ != nullptr) {
        delete static_cast<lm::base::Model*>(language_model_);
    }
    if (dictionary != nullptr) {
        delete static_cast<fst::StdVectorFst*>(dictionary);
    }
}
void Scorer::setup(const std::string& lm_path, void Scorer::setup(const std::string& lm_path,
const std::vector<std::string>& vocab_list) { const std::vector<std::string>& vocab_list) {
// load language model // load language model
load_lm(lm_path); load_lm(lm_path);
// set char map for scorer // set char map for scorer
set_char_map(vocab_list); set_char_map(vocab_list);
// fill the dictionary for FST // fill the dictionary for FST
if (!is_character_based()) { if (!is_character_based()) {
fill_dictionary(true); fill_dictionary(true);
} }
} }
void Scorer::load_lm(const std::string& lm_path) { void Scorer::load_lm(const std::string& lm_path) {
const char* filename = lm_path.c_str(); const char* filename = lm_path.c_str();
VALID_CHECK_EQ(access(filename, F_OK), 0, "Invalid language model path"); VALID_CHECK_EQ(access(filename, F_OK), 0, "Invalid language model path");
RetriveStrEnumerateVocab enumerate; RetriveStrEnumerateVocab enumerate;
lm::ngram::Config config; lm::ngram::Config config;
config.enumerate_vocab = &enumerate; config.enumerate_vocab = &enumerate;
language_model_ = lm::ngram::LoadVirtual(filename, config); language_model_ = lm::ngram::LoadVirtual(filename, config);
max_order_ = static_cast<lm::base::Model*>(language_model_)->Order(); max_order_ = static_cast<lm::base::Model*>(language_model_)->Order();
vocabulary_ = enumerate.vocabulary; vocabulary_ = enumerate.vocabulary;
for (size_t i = 0; i < vocabulary_.size(); ++i) { for (size_t i = 0; i < vocabulary_.size(); ++i) {
if (is_character_based_ && vocabulary_[i] != UNK_TOKEN && if (is_character_based_ && vocabulary_[i] != UNK_TOKEN &&
vocabulary_[i] != START_TOKEN && vocabulary_[i] != END_TOKEN && vocabulary_[i] != START_TOKEN && vocabulary_[i] != END_TOKEN &&
get_utf8_str_len(enumerate.vocabulary[i]) > 1) { get_utf8_str_len(enumerate.vocabulary[i]) > 1) {
is_character_based_ = false; is_character_based_ = false;
}
} }
}
} }
double Scorer::get_log_cond_prob(const std::vector<std::string>& words) { double Scorer::get_log_cond_prob(const std::vector<std::string>& words) {
lm::base::Model* model = static_cast<lm::base::Model*>(language_model_); lm::base::Model* model = static_cast<lm::base::Model*>(language_model_);
double cond_prob; double cond_prob;
lm::ngram::State state, tmp_state, out_state; lm::ngram::State state, tmp_state, out_state;
// avoid to inserting <s> in begin // avoid to inserting <s> in begin
model->NullContextWrite(&state); model->NullContextWrite(&state);
for (size_t i = 0; i < words.size(); ++i) { for (size_t i = 0; i < words.size(); ++i) {
lm::WordIndex word_index = model->BaseVocabulary().Index(words[i]); lm::WordIndex word_index = model->BaseVocabulary().Index(words[i]);
// encounter OOV // encounter OOV
if (word_index == 0) { if (word_index == 0) {
return OOV_SCORE; return OOV_SCORE;
}
cond_prob = model->BaseScore(&state, word_index, &out_state);
tmp_state = state;
state = out_state;
out_state = tmp_state;
} }
cond_prob = model->BaseScore(&state, word_index, &out_state); // return log10 prob
tmp_state = state; return cond_prob;
state = out_state;
out_state = tmp_state;
}
// return log10 prob
return cond_prob;
} }
double Scorer::get_sent_log_prob(const std::vector<std::string>& words) { double Scorer::get_sent_log_prob(const std::vector<std::string>& words) {
std::vector<std::string> sentence; std::vector<std::string> sentence;
if (words.size() == 0) { if (words.size() == 0) {
for (size_t i = 0; i < max_order_; ++i) { for (size_t i = 0; i < max_order_; ++i) {
sentence.push_back(START_TOKEN); sentence.push_back(START_TOKEN);
} }
} else { } else {
for (size_t i = 0; i < max_order_ - 1; ++i) { for (size_t i = 0; i < max_order_ - 1; ++i) {
sentence.push_back(START_TOKEN); sentence.push_back(START_TOKEN);
}
sentence.insert(sentence.end(), words.begin(), words.end());
} }
sentence.insert(sentence.end(), words.begin(), words.end()); sentence.push_back(END_TOKEN);
} return get_log_prob(sentence);
sentence.push_back(END_TOKEN);
return get_log_prob(sentence);
} }
double Scorer::get_log_prob(const std::vector<std::string>& words) { double Scorer::get_log_prob(const std::vector<std::string>& words) {
assert(words.size() > max_order_); assert(words.size() > max_order_);
double score = 0.0; double score = 0.0;
for (size_t i = 0; i < words.size() - max_order_ + 1; ++i) { for (size_t i = 0; i < words.size() - max_order_ + 1; ++i) {
std::vector<std::string> ngram(words.begin() + i, std::vector<std::string> ngram(words.begin() + i,
words.begin() + i + max_order_); words.begin() + i + max_order_);
score += get_log_cond_prob(ngram); score += get_log_cond_prob(ngram);
} }
return score; return score;
} }
void Scorer::reset_params(float alpha, float beta) { void Scorer::reset_params(float alpha, float beta) {
this->alpha = alpha; this->alpha = alpha;
this->beta = beta; this->beta = beta;
} }
std::string Scorer::vec2str(const std::vector<int>& input) { std::string Scorer::vec2str(const std::vector<int>& input) {
std::string word; std::string word;
for (auto ind : input) { for (auto ind : input) {
word += char_list_[ind]; word += char_list_[ind];
} }
return word; return word;
} }
std::vector<std::string> Scorer::split_labels(const std::vector<int>& labels) { std::vector<std::string> Scorer::split_labels(const std::vector<int>& labels) {
if (labels.empty()) return {}; if (labels.empty()) return {};
std::string s = vec2str(labels); std::string s = vec2str(labels);
std::vector<std::string> words; std::vector<std::string> words;
if (is_character_based_) { if (is_character_based_) {
words = split_utf8_str(s); words = split_utf8_str(s);
} else { } else {
words = split_str(s, " "); words = split_str(s, " ");
} }
return words; return words;
} }
void Scorer::set_char_map(const std::vector<std::string>& char_list) { void Scorer::set_char_map(const std::vector<std::string>& char_list) {
char_list_ = char_list; char_list_ = char_list;
char_map_.clear(); char_map_.clear();
// Set the char map for the FST for spelling correction // Set the char map for the FST for spelling correction
for (size_t i = 0; i < char_list_.size(); i++) { for (size_t i = 0; i < char_list_.size(); i++) {
if (char_list_[i] == " ") { if (char_list_[i] == " ") {
SPACE_ID_ = i; SPACE_ID_ = i;
}
// The initial state of FST is state 0, hence the index of chars in
// the FST should start from 1 to avoid the conflict with the initial
// state, otherwise wrong decoding results would be given.
char_map_[char_list_[i]] = i + 1;
} }
// The initial state of FST is state 0, hence the index of chars in
// the FST should start from 1 to avoid the conflict with the initial
// state, otherwise wrong decoding results would be given.
char_map_[char_list_[i]] = i + 1;
}
} }
// Walk backwards from `prefix` and assemble up to max_order_ LM tokens
// (characters or space-delimited words), padding with <s> when the root
// is reached early. Tokens are returned oldest-first.
std::vector<std::string> Scorer::make_ngram(PathTrie* prefix) {
    std::vector<std::string> ngram;
    PathTrie* node = prefix;
    PathTrie* stop_node = nullptr;
    for (int order = 0; order < max_order_; order++) {
        std::vector<int> char_ids;
        if (is_character_based_) {
            // One character per token.
            stop_node = node->get_path_vec(char_ids, SPACE_ID_, 1);
            node = stop_node;
        } else {
            // Collect characters back to the previous space.
            stop_node = node->get_path_vec(char_ids, SPACE_ID_);
            node = stop_node->parent;  // skip over the space itself
        }
        // Reconstruct the token and prepend it (list is reversed below).
        ngram.push_back(vec2str(char_ids));

        if (stop_node->character == -1) {
            // Reached the root: pad with <s> to keep max_order_ tokens.
            for (int i = 0; i < max_order_ - order - 1; i++) {
                ngram.push_back(START_TOKEN);
            }
            break;
        }
    }
    std::reverse(ngram.begin(), ngram.end());
    return ngram;
}
void Scorer::fill_dictionary(bool add_space) { void Scorer::fill_dictionary(bool add_space) {
fst::StdVectorFst dictionary; fst::StdVectorFst dictionary;
// For each unigram convert to ints and put in trie // For each unigram convert to ints and put in trie
int dict_size = 0; int dict_size = 0;
for (const auto& word : vocabulary_) { for (const auto& word : vocabulary_) {
bool added = add_word_to_dictionary( bool added = add_word_to_dictionary(
word, char_map_, add_space, SPACE_ID_ + 1, &dictionary); word, char_map_, add_space, SPACE_ID_ + 1, &dictionary);
dict_size += added ? 1 : 0; dict_size += added ? 1 : 0;
} }
dict_size_ = dict_size; dict_size_ = dict_size;
/* Simplify FST /* Simplify FST
* This gets rid of "epsilon" transitions in the FST. * This gets rid of "epsilon" transitions in the FST.
* These are transitions that don't require a string input to be taken. * These are transitions that don't require a string input to be taken.
* Getting rid of them is necessary to make the FST determinisitc, but * Getting rid of them is necessary to make the FST determinisitc, but
* can greatly increase the size of the FST * can greatly increase the size of the FST
*/ */
fst::RmEpsilon(&dictionary); fst::RmEpsilon(&dictionary);
fst::StdVectorFst* new_dict = new fst::StdVectorFst; fst::StdVectorFst* new_dict = new fst::StdVectorFst;
/* This makes the FST deterministic, meaning for any string input there's /* This makes the FST deterministic, meaning for any string input there's
* only one possible state the FST could be in. It is assumed our * only one possible state the FST could be in. It is assumed our
* dictionary is deterministic when using it. * dictionary is deterministic when using it.
* (lest we'd have to check for multiple transitions at each state) * (lest we'd have to check for multiple transitions at each state)
*/ */
fst::Determinize(dictionary, new_dict); fst::Determinize(dictionary, new_dict);
/* Finds the simplest equivalent fst. This is unnecessary but decreases /* Finds the simplest equivalent fst. This is unnecessary but decreases
* memory usage of the dictionary * memory usage of the dictionary
*/ */
fst::Minimize(new_dict); fst::Minimize(new_dict);
this->dictionary = new_dict; this->dictionary = new_dict;
} }
...@@ -34,14 +34,14 @@ const std::string END_TOKEN = "</s>"; ...@@ -34,14 +34,14 @@ const std::string END_TOKEN = "</s>";
// Implement a callback to retrive the dictionary of language model. // Implement a callback to retrive the dictionary of language model.
class RetriveStrEnumerateVocab : public lm::EnumerateVocab {
public:
    RetriveStrEnumerateVocab() {}

    // Callback invoked once per vocabulary entry while the LM is loaded;
    // copies the token text into `vocabulary`.
    void Add(lm::WordIndex index, const StringPiece &str) {
        vocabulary.push_back(std::string(str.data(), str.length()));
    }

    // All tokens collected from the language model's vocabulary.
    std::vector<std::string> vocabulary;
};
/* External scorer to query score for n-gram or sentence, including language /* External scorer to query score for n-gram or sentence, including language
...@@ -53,74 +53,74 @@ public: ...@@ -53,74 +53,74 @@ public:
* scorer.get_sent_log_prob({ "WORD1", "WORD2", "WORD3" }); * scorer.get_sent_log_prob({ "WORD1", "WORD2", "WORD3" });
*/ */
class Scorer {
public:
    Scorer(double alpha,
           double beta,
           const std::string &lm_path,
           const std::vector<std::string> &vocabulary);
    ~Scorer();

    // log10 conditional probability of the last word in `words` given the
    // preceding words as context
    double get_log_cond_prob(const std::vector<std::string> &words);

    // total log probability of `words` wrapped with <s> padding and </s>
    double get_sent_log_prob(const std::vector<std::string> &words);

    // return the max order
    size_t get_max_order() const { return max_order_; }

    // return the dictionary size of language model
    size_t get_dict_size() const { return dict_size_; }

    // return true if the language model is character based
    bool is_character_based() const { return is_character_based_; }

    // reset params alpha & beta
    void reset_params(float alpha, float beta);

    // make ngram for a given prefix
    std::vector<std::string> make_ngram(PathTrie *prefix);

    // transform the labels in index to the vector of words (word based lm)
    // or the vector of characters (character based lm)
    std::vector<std::string> split_labels(const std::vector<int> &labels);

    // language model weight
    double alpha;
    // word insertion weight
    double beta;

    // pointer to the dictionary of FST (owned; freed in the destructor)
    void *dictionary;

protected:
    // necessary setup: load language model, set char map, fill FST's
    // dictionary
    void setup(const std::string &lm_path,
               const std::vector<std::string> &vocab_list);

    // load language model from given path
    void load_lm(const std::string &lm_path);

    // fill dictionary for FST
    void fill_dictionary(bool add_space);

    // set char map
    void set_char_map(const std::vector<std::string> &char_list);

    double get_log_prob(const std::vector<std::string> &words);

    // translate the vector in index to string
    std::string vec2str(const std::vector<int> &input);

private:
    void *language_model_;       // KenLM model (owned; freed in destructor)
    bool is_character_based_;
    size_t max_order_;           // n-gram order of the loaded model
    size_t dict_size_;           // number of words inserted into the FST

    int SPACE_ID_;               // index of " " in char_list_; -1 if absent
    std::vector<std::string> char_list_;
    std::unordered_map<std::string, int> char_map_;

    std::vector<std::string> vocabulary_;
};

#endif  // SCORER_H_
#!/usr/bin/env python3 #!/usr/bin/env python3
import sys
import pstats
import cProfile import cProfile
from io import StringIO
import getopt import getopt
import os import os
from os.path import dirname, join import pstats
import sys
from io import StringIO
from os.path import dirname
from os.path import join
import mmseg import mmseg
......
...@@ -94,7 +94,7 @@ ...@@ -94,7 +94,7 @@
* Update to the latest version of [Unihan Database](http://www.unicode.org/charts/unihan.html):
  > Date: 2016-06-01 07:01:48 GMT [JHJ]
  > Unicode version: 9.0.0
......
...@@ -19,7 +19,7 @@ ...@@ -19,7 +19,7 @@
[Unihan Database][unihan] 数据版本:
> Date: 2020-02-18 18:27:33 GMT [JHJ]
> Unicode version: 13.0.0
* `kTGHZ2013.txt`: [Unihan Database][unihan] 的 [kTGHZ2013](http://www.unicode.org/reports/tr38/#kTGHZ2013) 部分的拼音数据(来源于《通用规范汉字字典》的拼音数据)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册