提交 f842c79a 编写于 作者: H Hui Zhang

format code

上级 e969a8ec
......@@ -47,7 +47,7 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
}
// assign blank id
//size_t blank_id = vocabulary.size();
// size_t blank_id = vocabulary.size();
size_t blank_id = 0;
// assign space id
......@@ -65,7 +65,8 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
prefixes.push_back(&root);
if (ext_scorer != nullptr && !ext_scorer->is_character_based()) {
auto fst_dict = static_cast<fst::StdVectorFst *>(ext_scorer->dictionary);
auto fst_dict =
static_cast<fst::StdVectorFst *>(ext_scorer->dictionary);
fst::StdVectorFst *dict_ptr = fst_dict->Copy(true);
root.set_dictionary(dict_ptr);
auto matcher = std::make_shared<FSTMATCH>(*dict_ptr, fst::MATCH_INPUT);
......@@ -80,10 +81,12 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
bool full_beam = false;
if (ext_scorer != nullptr) {
size_t num_prefixes = std::min(prefixes.size(), beam_size);
std::sort(
prefixes.begin(), prefixes.begin() + num_prefixes, prefix_compare);
std::sort(prefixes.begin(),
prefixes.begin() + num_prefixes,
prefix_compare);
min_cutoff = prefixes[num_prefixes - 1]->score +
std::log(prob[blank_id]) - std::max(0.0, ext_scorer->beta);
std::log(prob[blank_id]) -
std::max(0.0, ext_scorer->beta);
full_beam = (num_prefixes == beam_size);
}
......@@ -101,14 +104,15 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
}
// blank
if (c == blank_id) {
prefix->log_prob_b_cur =
log_sum_exp(prefix->log_prob_b_cur, log_prob_c + prefix->score);
prefix->log_prob_b_cur = log_sum_exp(
prefix->log_prob_b_cur, log_prob_c + prefix->score);
continue;
}
// repeated character
if (c == prefix->character) {
prefix->log_prob_nb_cur = log_sum_exp(
prefix->log_prob_nb_cur, log_prob_c + prefix->log_prob_nb_prev);
prefix->log_prob_nb_cur =
log_sum_exp(prefix->log_prob_nb_cur,
log_prob_c + prefix->log_prob_nb_prev);
}
// get new prefix
auto prefix_new = prefix->get_path_trie(c);
......@@ -137,7 +141,8 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
float score = 0.0;
std::vector<std::string> ngram;
ngram = ext_scorer->make_ngram(prefix_to_score);
score = ext_scorer->get_log_cond_prob(ngram) * ext_scorer->alpha;
score = ext_scorer->get_log_cond_prob(ngram) *
ext_scorer->alpha;
log_p += score;
log_p += ext_scorer->beta;
}
......@@ -171,7 +176,8 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
if (!prefix->is_empty() && prefix->character != space_id) {
float score = 0.0;
std::vector<std::string> ngram = ext_scorer->make_ngram(prefix);
score = ext_scorer->get_log_cond_prob(ngram) * ext_scorer->alpha;
score =
ext_scorer->get_log_cond_prob(ngram) * ext_scorer->alpha;
score += ext_scorer->beta;
prefix->score += score;
}
......@@ -179,7 +185,8 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
}
size_t num_prefixes = std::min(prefixes.size(), beam_size);
std::sort(prefixes.begin(), prefixes.begin() + num_prefixes, prefix_compare);
std::sort(
prefixes.begin(), prefixes.begin() + num_prefixes, prefix_compare);
// compute aproximate ctc score as the return score, without affecting the
// return order of decoding result. To delete when decoder gets stable.
......@@ -193,7 +200,8 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
// remove word insert
approx_ctc = approx_ctc - prefix_length * ext_scorer->beta;
// remove language model weight:
approx_ctc -= (ext_scorer->get_sent_log_prob(words)) * ext_scorer->alpha;
approx_ctc -=
(ext_scorer->get_sent_log_prob(words)) * ext_scorer->alpha;
}
prefixes[i]->approx_ctc = approx_ctc;
}
......
......@@ -29,15 +29,17 @@ std::vector<std::pair<size_t, float>> get_pruned_log_probs(
// pruning of vacobulary
size_t cutoff_len = prob_step.size();
if (cutoff_prob < 1.0 || cutoff_top_n < cutoff_len) {
std::sort(
prob_idx.begin(), prob_idx.end(), pair_comp_second_rev<int, double>);
std::sort(prob_idx.begin(),
prob_idx.end(),
pair_comp_second_rev<int, double>);
if (cutoff_prob < 1.0) {
double cum_prob = 0.0;
cutoff_len = 0;
for (size_t i = 0; i < prob_idx.size(); ++i) {
cum_prob += prob_idx[i].second;
cutoff_len += 1;
if (cum_prob >= cutoff_prob || cutoff_len >= cutoff_top_n) break;
if (cum_prob >= cutoff_prob || cutoff_len >= cutoff_top_n)
break;
}
}
prob_idx = std::vector<std::pair<int, double>>(
......@@ -74,8 +76,8 @@ std::vector<std::pair<double, std::string>> get_beam_search_result(
for (size_t j = 0; j < output.size(); j++) {
output_str += vocabulary[output[j]];
}
std::pair<double, std::string> output_pair(-space_prefixes[i]->approx_ctc,
output_str);
std::pair<double, std::string> output_pair(
-space_prefixes[i]->approx_ctc, output_str);
output_vecs.emplace_back(output_pair);
}
......
......@@ -134,7 +134,8 @@ void PathTrie::remove() {
if (children_.size() == 0) {
auto child = parent->children_.begin();
for (child = parent->children_.begin(); child != parent->children_.end();
for (child = parent->children_.begin();
child != parent->children_.end();
++child) {
if (child->first == character) {
parent->children_.erase(child);
......
......@@ -27,7 +27,7 @@
* finite-state transducer for spelling correction.
*/
class PathTrie {
public:
public:
PathTrie();
~PathTrie();
......@@ -38,7 +38,8 @@ public:
PathTrie* get_path_vec(std::vector<int>& output);
// get the prefix in index from some stop node to current nodel
PathTrie* get_path_vec(std::vector<int>& output,
PathTrie* get_path_vec(
std::vector<int>& output,
int stop,
size_t max_steps = std::numeric_limits<size_t>::max());
......@@ -64,7 +65,7 @@ public:
int character;
PathTrie* parent;
private:
private:
int ROOT_;
bool exists_;
bool has_dictionary_;
......
......@@ -34,7 +34,7 @@ const std::string END_TOKEN = "</s>";
// Implement a callback to retrive the dictionary of language model.
class RetriveStrEnumerateVocab : public lm::EnumerateVocab {
public:
public:
RetriveStrEnumerateVocab() {}
void Add(lm::WordIndex index, const StringPiece &str) {
......@@ -53,7 +53,7 @@ public:
* scorer.get_sent_log_prob({ "WORD1", "WORD2", "WORD3" });
*/
class Scorer {
public:
public:
Scorer(double alpha,
double beta,
const std::string &lm_path,
......@@ -91,7 +91,7 @@ public:
// pointer to the dictionary of FST
void *dictionary;
protected:
protected:
// necessary setup: load language model, set char map, fill FST's dictionary
void setup(const std::string &lm_path,
const std::vector<std::string> &vocab_list);
......@@ -110,7 +110,7 @@ protected:
// translate the vector in index to string
std::string vec2str(const std::vector<int> &input);
private:
private:
void *language_model_;
bool is_character_based_;
size_t max_order_;
......
#!/usr/bin/env python3
import sys
import pstats
import cProfile
from io import StringIO
import getopt
import os
from os.path import dirname, join
import pstats
import sys
from io import StringIO
from os.path import dirname
from os.path import join
import mmseg
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册