提交 f842c79a 编写于 作者: H Hui Zhang

format code

上级 e969a8ec
...@@ -47,7 +47,7 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder( ...@@ -47,7 +47,7 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
} }
// assign blank id // assign blank id
//size_t blank_id = vocabulary.size(); // size_t blank_id = vocabulary.size();
size_t blank_id = 0; size_t blank_id = 0;
// assign space id // assign space id
...@@ -65,7 +65,8 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder( ...@@ -65,7 +65,8 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
prefixes.push_back(&root); prefixes.push_back(&root);
if (ext_scorer != nullptr && !ext_scorer->is_character_based()) { if (ext_scorer != nullptr && !ext_scorer->is_character_based()) {
auto fst_dict = static_cast<fst::StdVectorFst *>(ext_scorer->dictionary); auto fst_dict =
static_cast<fst::StdVectorFst *>(ext_scorer->dictionary);
fst::StdVectorFst *dict_ptr = fst_dict->Copy(true); fst::StdVectorFst *dict_ptr = fst_dict->Copy(true);
root.set_dictionary(dict_ptr); root.set_dictionary(dict_ptr);
auto matcher = std::make_shared<FSTMATCH>(*dict_ptr, fst::MATCH_INPUT); auto matcher = std::make_shared<FSTMATCH>(*dict_ptr, fst::MATCH_INPUT);
...@@ -80,10 +81,12 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder( ...@@ -80,10 +81,12 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
bool full_beam = false; bool full_beam = false;
if (ext_scorer != nullptr) { if (ext_scorer != nullptr) {
size_t num_prefixes = std::min(prefixes.size(), beam_size); size_t num_prefixes = std::min(prefixes.size(), beam_size);
std::sort( std::sort(prefixes.begin(),
prefixes.begin(), prefixes.begin() + num_prefixes, prefix_compare); prefixes.begin() + num_prefixes,
prefix_compare);
min_cutoff = prefixes[num_prefixes - 1]->score + min_cutoff = prefixes[num_prefixes - 1]->score +
std::log(prob[blank_id]) - std::max(0.0, ext_scorer->beta); std::log(prob[blank_id]) -
std::max(0.0, ext_scorer->beta);
full_beam = (num_prefixes == beam_size); full_beam = (num_prefixes == beam_size);
} }
...@@ -101,14 +104,15 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder( ...@@ -101,14 +104,15 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
} }
// blank // blank
if (c == blank_id) { if (c == blank_id) {
prefix->log_prob_b_cur = prefix->log_prob_b_cur = log_sum_exp(
log_sum_exp(prefix->log_prob_b_cur, log_prob_c + prefix->score); prefix->log_prob_b_cur, log_prob_c + prefix->score);
continue; continue;
} }
// repeated character // repeated character
if (c == prefix->character) { if (c == prefix->character) {
prefix->log_prob_nb_cur = log_sum_exp( prefix->log_prob_nb_cur =
prefix->log_prob_nb_cur, log_prob_c + prefix->log_prob_nb_prev); log_sum_exp(prefix->log_prob_nb_cur,
log_prob_c + prefix->log_prob_nb_prev);
} }
// get new prefix // get new prefix
auto prefix_new = prefix->get_path_trie(c); auto prefix_new = prefix->get_path_trie(c);
...@@ -137,7 +141,8 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder( ...@@ -137,7 +141,8 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
float score = 0.0; float score = 0.0;
std::vector<std::string> ngram; std::vector<std::string> ngram;
ngram = ext_scorer->make_ngram(prefix_to_score); ngram = ext_scorer->make_ngram(prefix_to_score);
score = ext_scorer->get_log_cond_prob(ngram) * ext_scorer->alpha; score = ext_scorer->get_log_cond_prob(ngram) *
ext_scorer->alpha;
log_p += score; log_p += score;
log_p += ext_scorer->beta; log_p += ext_scorer->beta;
} }
...@@ -171,7 +176,8 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder( ...@@ -171,7 +176,8 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
if (!prefix->is_empty() && prefix->character != space_id) { if (!prefix->is_empty() && prefix->character != space_id) {
float score = 0.0; float score = 0.0;
std::vector<std::string> ngram = ext_scorer->make_ngram(prefix); std::vector<std::string> ngram = ext_scorer->make_ngram(prefix);
score = ext_scorer->get_log_cond_prob(ngram) * ext_scorer->alpha; score =
ext_scorer->get_log_cond_prob(ngram) * ext_scorer->alpha;
score += ext_scorer->beta; score += ext_scorer->beta;
prefix->score += score; prefix->score += score;
} }
...@@ -179,7 +185,8 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder( ...@@ -179,7 +185,8 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
} }
size_t num_prefixes = std::min(prefixes.size(), beam_size); size_t num_prefixes = std::min(prefixes.size(), beam_size);
std::sort(prefixes.begin(), prefixes.begin() + num_prefixes, prefix_compare); std::sort(
prefixes.begin(), prefixes.begin() + num_prefixes, prefix_compare);
// compute approximate ctc score as the return score, without affecting the // compute approximate ctc score as the return score, without affecting the
// return order of decoding result. To delete when decoder gets stable. // return order of decoding result. To delete when decoder gets stable.
...@@ -193,7 +200,8 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder( ...@@ -193,7 +200,8 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
// remove word insert // remove word insert
approx_ctc = approx_ctc - prefix_length * ext_scorer->beta; approx_ctc = approx_ctc - prefix_length * ext_scorer->beta;
// remove language model weight: // remove language model weight:
approx_ctc -= (ext_scorer->get_sent_log_prob(words)) * ext_scorer->alpha; approx_ctc -=
(ext_scorer->get_sent_log_prob(words)) * ext_scorer->alpha;
} }
prefixes[i]->approx_ctc = approx_ctc; prefixes[i]->approx_ctc = approx_ctc;
} }
......
...@@ -29,15 +29,17 @@ std::vector<std::pair<size_t, float>> get_pruned_log_probs( ...@@ -29,15 +29,17 @@ std::vector<std::pair<size_t, float>> get_pruned_log_probs(
// pruning of vocabulary // pruning of vocabulary
size_t cutoff_len = prob_step.size(); size_t cutoff_len = prob_step.size();
if (cutoff_prob < 1.0 || cutoff_top_n < cutoff_len) { if (cutoff_prob < 1.0 || cutoff_top_n < cutoff_len) {
std::sort( std::sort(prob_idx.begin(),
prob_idx.begin(), prob_idx.end(), pair_comp_second_rev<int, double>); prob_idx.end(),
pair_comp_second_rev<int, double>);
if (cutoff_prob < 1.0) { if (cutoff_prob < 1.0) {
double cum_prob = 0.0; double cum_prob = 0.0;
cutoff_len = 0; cutoff_len = 0;
for (size_t i = 0; i < prob_idx.size(); ++i) { for (size_t i = 0; i < prob_idx.size(); ++i) {
cum_prob += prob_idx[i].second; cum_prob += prob_idx[i].second;
cutoff_len += 1; cutoff_len += 1;
if (cum_prob >= cutoff_prob || cutoff_len >= cutoff_top_n) break; if (cum_prob >= cutoff_prob || cutoff_len >= cutoff_top_n)
break;
} }
} }
prob_idx = std::vector<std::pair<int, double>>( prob_idx = std::vector<std::pair<int, double>>(
...@@ -74,8 +76,8 @@ std::vector<std::pair<double, std::string>> get_beam_search_result( ...@@ -74,8 +76,8 @@ std::vector<std::pair<double, std::string>> get_beam_search_result(
for (size_t j = 0; j < output.size(); j++) { for (size_t j = 0; j < output.size(); j++) {
output_str += vocabulary[output[j]]; output_str += vocabulary[output[j]];
} }
std::pair<double, std::string> output_pair(-space_prefixes[i]->approx_ctc, std::pair<double, std::string> output_pair(
output_str); -space_prefixes[i]->approx_ctc, output_str);
output_vecs.emplace_back(output_pair); output_vecs.emplace_back(output_pair);
} }
......
...@@ -134,7 +134,8 @@ void PathTrie::remove() { ...@@ -134,7 +134,8 @@ void PathTrie::remove() {
if (children_.size() == 0) { if (children_.size() == 0) {
auto child = parent->children_.begin(); auto child = parent->children_.begin();
for (child = parent->children_.begin(); child != parent->children_.end(); for (child = parent->children_.begin();
child != parent->children_.end();
++child) { ++child) {
if (child->first == character) { if (child->first == character) {
parent->children_.erase(child); parent->children_.erase(child);
......
...@@ -27,7 +27,7 @@ ...@@ -27,7 +27,7 @@
* finite-state transducer for spelling correction. * finite-state transducer for spelling correction.
*/ */
class PathTrie { class PathTrie {
public: public:
PathTrie(); PathTrie();
~PathTrie(); ~PathTrie();
...@@ -38,7 +38,8 @@ public: ...@@ -38,7 +38,8 @@ public:
PathTrie* get_path_vec(std::vector<int>& output); PathTrie* get_path_vec(std::vector<int>& output);
// get the prefix in index from some stop node to current node // get the prefix in index from some stop node to current node
PathTrie* get_path_vec(std::vector<int>& output, PathTrie* get_path_vec(
std::vector<int>& output,
int stop, int stop,
size_t max_steps = std::numeric_limits<size_t>::max()); size_t max_steps = std::numeric_limits<size_t>::max());
...@@ -64,7 +65,7 @@ public: ...@@ -64,7 +65,7 @@ public:
int character; int character;
PathTrie* parent; PathTrie* parent;
private: private:
int ROOT_; int ROOT_;
bool exists_; bool exists_;
bool has_dictionary_; bool has_dictionary_;
......
...@@ -34,7 +34,7 @@ const std::string END_TOKEN = "</s>"; ...@@ -34,7 +34,7 @@ const std::string END_TOKEN = "</s>";
// Implement a callback to retrieve the dictionary of the language model. // Implement a callback to retrieve the dictionary of the language model.
class RetriveStrEnumerateVocab : public lm::EnumerateVocab { class RetriveStrEnumerateVocab : public lm::EnumerateVocab {
public: public:
RetriveStrEnumerateVocab() {} RetriveStrEnumerateVocab() {}
void Add(lm::WordIndex index, const StringPiece &str) { void Add(lm::WordIndex index, const StringPiece &str) {
...@@ -53,7 +53,7 @@ public: ...@@ -53,7 +53,7 @@ public:
* scorer.get_sent_log_prob({ "WORD1", "WORD2", "WORD3" }); * scorer.get_sent_log_prob({ "WORD1", "WORD2", "WORD3" });
*/ */
class Scorer { class Scorer {
public: public:
Scorer(double alpha, Scorer(double alpha,
double beta, double beta,
const std::string &lm_path, const std::string &lm_path,
...@@ -91,7 +91,7 @@ public: ...@@ -91,7 +91,7 @@ public:
// pointer to the dictionary of FST // pointer to the dictionary of FST
void *dictionary; void *dictionary;
protected: protected:
// necessary setup: load language model, set char map, fill FST's dictionary // necessary setup: load language model, set char map, fill FST's dictionary
void setup(const std::string &lm_path, void setup(const std::string &lm_path,
const std::vector<std::string> &vocab_list); const std::vector<std::string> &vocab_list);
...@@ -110,7 +110,7 @@ protected: ...@@ -110,7 +110,7 @@ protected:
// translate the vector in index to string // translate the vector in index to string
std::string vec2str(const std::vector<int> &input); std::string vec2str(const std::vector<int> &input);
private: private:
void *language_model_; void *language_model_;
bool is_character_based_; bool is_character_based_;
size_t max_order_; size_t max_order_;
......
#!/usr/bin/env python3 #!/usr/bin/env python3
import sys
import pstats
import cProfile import cProfile
from io import StringIO
import getopt import getopt
import os import os
from os.path import dirname, join import pstats
import sys
from io import StringIO
from os.path import dirname
from os.path import join
import mmseg import mmseg
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册