From a24d0138d9c300024d040c735df1421d32e36ebb Mon Sep 17 00:00:00 2001
From: Yibing Liu
Date: Sun, 17 Sep 2017 19:05:04 +0800
Subject: [PATCH] adjust scorer's init & add logging for scorer & separate long functions

---
 README.md                                     |   1 -
 ...r_deprecated.py => decoders_deprecated.py} |   6 +-
 ...rer_deprecated.py => scorer_deprecated.py} |   0
 ...coders.cpp => ctc_beam_search_decoder.cpp} | 164 +++---------
 ...c_decoders.h => ctc_beam_search_decoder.h} |  29 +---
 decoders/swig/ctc_greedy_decoder.cpp          |  45 +++++
 decoders/swig/ctc_greedy_decoder.h            |  20 +++
 decoders/swig/decoder_utils.cpp               |  65 +++++++
 decoders/swig/decoder_utils.h                 |  39 +++--
 decoders/swig/decoders.i                      |   6 +-
 decoders/swig/path_trie.h                     |   9 +-
 decoders/swig/scorer.cpp                      |  42 +++--
 decoders/swig/scorer.h                        |  35 ++--
 decoders/swig/setup.py                        |  13 +-
 decoders/swig/setup.sh                        |   2 +-
 decoders/swig_wrapper.py                      |  22 +--
 examples/tiny/run_infer.sh                    |   6 +-
 examples/tiny/run_infer_golden.sh             |   6 +-
 examples/tiny/run_test.sh                     |   6 +-
 examples/tiny/run_test_golden.sh              |   6 +-
 infer.py                                      |   1 +
 model_utils/model.py                          |  25 ++-
 test.py                                       |   1 +
 23 files changed, 310 insertions(+), 239 deletions(-)
 rename decoders/{decoder_deprecated.py => decoders_deprecated.py} (98%)
 rename decoders/{lm_scorer_deprecated.py => scorer_deprecated.py} (100%)
 rename decoders/swig/{ctc_decoders.cpp => ctc_beam_search_decoder.cpp} (55%)
 rename decoders/swig/{ctc_decoders.h => ctc_beam_search_decoder.h} (75%)
 create mode 100644 decoders/swig/ctc_greedy_decoder.cpp
 create mode 100644 decoders/swig/ctc_greedy_decoder.h

diff --git a/README.md b/README.md
index 75879971..9d9d4c77 100644
--- a/README.md
+++ b/README.md
@@ -176,7 +176,6 @@ Data augmentation has often been a highly effective technique to boost the deep
 
 Six optional augmentation components are provided to be selected, configured and inserted into the processing pipeline.
-### Inference
 - Volume Perturbation
 - Speed Perturbation
 - Shifting Perturbation
diff --git a/decoders/decoder_deprecated.py b/decoders/decoders_deprecated.py
similarity index 98%
rename from decoders/decoder_deprecated.py
rename to decoders/decoders_deprecated.py
index 64743163..17b28b0d 100644
--- a/decoders/decoder_deprecated.py
+++ b/decoders/decoders_deprecated.py
@@ -119,7 +119,7 @@ def ctc_beam_search_decoder(probs_seq,
                 cutoff_len += 1
                 if cum_prob >= cutoff_prob:
                     break
-        cutoff_len = min(cutoff_top_n, cutoff_top_n)
+        cutoff_len = min(cutoff_len, cutoff_top_n)
         prob_idx = prob_idx[0:cutoff_len]
 
     for l in prefix_set_prev:
@@ -228,8 +228,8 @@ def ctc_beam_search_decoder_batch(probs_split,
     pool = multiprocessing.Pool(processes=num_processes)
     results = []
     for i, probs_list in enumerate(probs_split):
-        args = (probs_list, beam_size, vocabulary, blank_id, cutoff_prob,
-                cutoff_top_n, None, nproc)
+        args = (probs_list, beam_size, vocabulary, cutoff_prob, cutoff_top_n,
+                None, nproc)
         results.append(pool.apply_async(ctc_beam_search_decoder, args))
 
     pool.close()
diff --git a/decoders/lm_scorer_deprecated.py b/decoders/scorer_deprecated.py
similarity index 100%
rename from decoders/lm_scorer_deprecated.py
rename to decoders/scorer_deprecated.py
diff --git a/decoders/swig/ctc_decoders.cpp b/decoders/swig/ctc_beam_search_decoder.cpp
similarity index 55%
rename from decoders/swig/ctc_decoders.cpp
rename to decoders/swig/ctc_beam_search_decoder.cpp
index 35425fbc..36d16987 100644
--- a/decoders/swig/ctc_decoders.cpp
+++ b/decoders/swig/ctc_beam_search_decoder.cpp
@@ -1,4 +1,4 @@
-#include "ctc_decoders.h"
+#include "ctc_beam_search_decoder.h"
 
 #include <algorithm>
 #include <cmath>
@@ -9,59 +9,19 @@
 #include "ThreadPool.h"
 #include "fst/fstlib.h"
+#include "fst/log.h"
 
 #include "decoder_utils.h"
 #include "path_trie.h"
 
-std::string ctc_greedy_decoder(
-    const std::vector<std::vector<double>> &probs_seq,
-    const std::vector<std::string> &vocabulary) {
-  // dimension check
-  size_t num_time_steps = probs_seq.size();
-  for (size_t i = 0; i < num_time_steps; ++i) {
-    VALID_CHECK_EQ(probs_seq[i].size(),
-                   vocabulary.size() + 1,
-                   "The shape of probs_seq does not match with "
-                   "the shape of the vocabulary");
-  }
-
-  size_t blank_id = vocabulary.size();
-
-  std::vector<size_t> max_idx_vec;
-  for (size_t i = 0; i < num_time_steps; ++i) {
-    double max_prob = 0.0;
-    size_t max_idx = 0;
-    for (size_t j = 0; j < probs_seq[i].size(); j++) {
-      if (max_prob < probs_seq[i][j]) {
-        max_idx = j;
-        max_prob = probs_seq[i][j];
-      }
-    }
-    max_idx_vec.push_back(max_idx);
-  }
-
-  std::vector<size_t> idx_vec;
-  for (size_t i = 0; i < max_idx_vec.size(); ++i) {
-    if ((i == 0) || ((i > 0) && max_idx_vec[i] != max_idx_vec[i - 1])) {
-      idx_vec.push_back(max_idx_vec[i]);
-    }
-  }
-
-  std::string best_path_result;
-  for (size_t i = 0; i < idx_vec.size(); ++i) {
-    if (idx_vec[i] != blank_id) {
-      best_path_result += vocabulary[idx_vec[i]];
-    }
-  }
-  return best_path_result;
-}
+using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;
 
 std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
     const std::vector<std::vector<double>> &probs_seq,
-    const size_t beam_size,
+    size_t beam_size,
     std::vector<std::string> vocabulary,
-    const double cutoff_prob,
-    const size_t cutoff_top_n,
+    double cutoff_prob,
+    size_t cutoff_top_n,
     Scorer *ext_scorer) {
   // dimension check
   size_t num_time_steps = probs_seq.size();
@@ -80,7 +40,7 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
       std::find(vocabulary.begin(), vocabulary.end(), " ");
   int space_id = it - vocabulary.begin();
   // if no space in vocabulary
-  if (space_id >= vocabulary.size()) {
+  if ((size_t)space_id >= vocabulary.size()) {
     space_id = -2;
   }
@@ -90,30 +50,17 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
   std::vector<PathTrie *> prefixes;
   prefixes.push_back(&root);
 
-  if (ext_scorer != nullptr) {
-    if (ext_scorer->is_char_map_empty()) {
-      ext_scorer->set_char_map(vocabulary);
-    }
-    if (!ext_scorer->is_character_based()) {
-      if (ext_scorer->dictionary == nullptr) {
-        // fill dictionary for fst with space
-        ext_scorer->fill_dictionary(true);
-      }
-      auto fst_dict = static_cast<fst::StdVectorFst *>(ext_scorer->dictionary);
-      fst::StdVectorFst *dict_ptr = fst_dict->Copy(true);
-      root.set_dictionary(dict_ptr);
-      auto matcher = std::make_shared<FSTMATCH>(*dict_ptr, fst::MATCH_INPUT);
-      root.set_matcher(matcher);
-    }
+  if (ext_scorer != nullptr && !ext_scorer->is_character_based()) {
+    auto fst_dict = static_cast<fst::StdVectorFst *>(ext_scorer->dictionary);
+    fst::StdVectorFst *dict_ptr = fst_dict->Copy(true);
+    root.set_dictionary(dict_ptr);
+    auto matcher = std::make_shared<FSTMATCH>(*dict_ptr, fst::MATCH_INPUT);
+    root.set_matcher(matcher);
   }
 
   // prefix search over time
-  for (size_t time_step = 0; time_step < num_time_steps; time_step++) {
-    std::vector<double> prob = probs_seq[time_step];
-    std::vector<std::pair<int, double>> prob_idx;
-    for (size_t i = 0; i < prob.size(); ++i) {
-      prob_idx.push_back(std::pair<int, double>(i, prob[i]));
-    }
+  for (size_t time_step = 0; time_step < num_time_steps; ++time_step) {
+    auto &prob = probs_seq[time_step];
 
     float min_cutoff = -NUM_FLT_INF;
     bool full_beam = false;
@@ -121,43 +68,20 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
       size_t num_prefixes = std::min(prefixes.size(), beam_size);
       std::sort(
           prefixes.begin(), prefixes.begin() + num_prefixes, prefix_compare);
-      min_cutoff = prefixes[num_prefixes - 1]->score + log(prob[blank_id]) -
-                   std::max(0.0, ext_scorer->beta);
+      min_cutoff = prefixes[num_prefixes - 1]->score +
+                   std::log(prob[blank_id]) - std::max(0.0, ext_scorer->beta);
       full_beam = (num_prefixes == beam_size);
     }
 
-    // pruning of vacobulary
-    size_t cutoff_len = prob.size();
-    if (cutoff_prob < 1.0 || cutoff_top_n < cutoff_len) {
-      std::sort(
-          prob_idx.begin(), prob_idx.end(), pair_comp_second_rev<int, double>);
-      if (cutoff_prob < 1.0) {
-        double cum_prob = 0.0;
-        cutoff_len = 0;
-        for (size_t i = 0; i < prob_idx.size(); ++i) {
-          cum_prob += prob_idx[i].second;
-          cutoff_len += 1;
-          if (cum_prob >= cutoff_prob) break;
-        }
-      }
-      cutoff_len = std::min(cutoff_len, cutoff_top_n);
-      prob_idx = std::vector<std::pair<int, double>>(
-          prob_idx.begin(), prob_idx.begin() + cutoff_len);
-    }
-    std::vector<std::pair<size_t, float>> log_prob_idx;
-    for (size_t i = 0; i < cutoff_len; ++i) {
-      log_prob_idx.push_back(std::pair<size_t, float>(
-          prob_idx[i].first, log(prob_idx[i].second + NUM_FLT_MIN)));
-    }
-
+    std::vector<std::pair<size_t, float>> log_prob_idx =
+        get_pruned_log_probs(prob, cutoff_prob, cutoff_top_n);
     // loop over chars
     for (size_t index = 0; index < log_prob_idx.size(); index++) {
       auto c = log_prob_idx[index].first;
-      float log_prob_c = log_prob_idx[index].second;
+      auto log_prob_c = log_prob_idx[index].second;
 
       for (size_t i = 0; i < prefixes.size() && i < beam_size; ++i) {
         auto prefix = prefixes[i];
-
         if (full_beam && log_prob_c + prefix->score < min_cutoff) {
           break;
         }
@@ -189,7 +113,6 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
         if (ext_scorer != nullptr &&
             (c == space_id || ext_scorer->is_character_based())) {
           PathTrie *prefix_toscore = nullptr;
-
           // skip scoring the space
           if (ext_scorer->is_character_based()) {
             prefix_toscore = prefix_new;
@@ -201,7 +124,6 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
           std::vector<std::string> ngram;
           ngram = ext_scorer->make_ngram(prefix_toscore);
           score = ext_scorer->get_log_cond_prob(ngram) * ext_scorer->alpha;
-
           log_p += score;
           log_p += ext_scorer->beta;
         }
@@ -221,57 +143,33 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
       prefixes.begin() + beam_size, prefixes.end(), prefix_compare);
-
       for (size_t i = beam_size; i < prefixes.size(); ++i) {
         prefixes[i]->remove();
       }
     }
   }  // end of loop over time
 
-  // compute aproximate ctc score as the return score
+  // compute approximate ctc score as the return score, without affecting the
+  // return order of decoding result. To delete when decoder gets stable.
   for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
     double approx_ctc = prefixes[i]->score;
-
     if (ext_scorer != nullptr) {
       std::vector<int> output;
       prefixes[i]->get_path_vec(output);
-      size_t prefix_length = output.size();
+      auto prefix_length = output.size();
       auto words = ext_scorer->split_labels(output);
       // remove word insert
       approx_ctc = approx_ctc - prefix_length * ext_scorer->beta;
       // remove language model weight:
       approx_ctc -= (ext_scorer->get_sent_log_prob(words)) * ext_scorer->alpha;
     }
-
     prefixes[i]->approx_ctc = approx_ctc;
   }
 
-  // allow for the post processing
-  std::vector<PathTrie *> space_prefixes;
-  if (space_prefixes.empty()) {
-    for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
-      space_prefixes.push_back(prefixes[i]);
-    }
-  }
-
-  std::sort(space_prefixes.begin(), space_prefixes.end(), prefix_compare);
-  std::vector<std::pair<double, std::string>> output_vecs;
-  for (size_t i = 0; i < beam_size && i < space_prefixes.size(); ++i) {
-    std::vector<int> output;
-    space_prefixes[i]->get_path_vec(output);
-    // convert index to string
-    std::string output_str;
-    for (size_t j = 0; j < output.size(); j++) {
-      output_str += vocabulary[output[j]];
-    }
-    std::pair<double, std::string> output_pair(-space_prefixes[i]->approx_ctc,
-                                               output_str);
-    output_vecs.emplace_back(output_pair);
-  }
-
-  return output_vecs;
+  return get_beam_search_result(prefixes, vocabulary, beam_size);
 }
 
+
 std::vector<std::vector<std::pair<double, std::string>>>
 ctc_beam_search_decoder_batch(
     const std::vector<std::vector<std::vector<double>>> &probs_split,
@@ -287,18 +185,6 @@ ctc_beam_search_decoder_batch(
   // number of samples
   size_t batch_size = probs_split.size();
 
-  // scorer filling up
-  if (ext_scorer != nullptr) {
-    if (ext_scorer->is_char_map_empty()) {
-      ext_scorer->set_char_map(vocabulary);
-    }
-    if (!ext_scorer->is_character_based() &&
-        ext_scorer->dictionary == nullptr) {
-      // init dictionary
-      ext_scorer->fill_dictionary(true);
-    }
-  }
-
   // enqueue the tasks of decoding
   std::vector<std::future<std::vector<std::pair<double, std::string>>>> res;
   for (size_t i = 0; i < batch_size; ++i) {
diff --git a/decoders/swig/ctc_decoders.h b/decoders/swig/ctc_beam_search_decoder.h
similarity index 75%
rename from decoders/swig/ctc_decoders.h
rename to decoders/swig/ctc_beam_search_decoder.h
index 6384c8a8..c800384e 100644
--- a/decoders/swig/ctc_decoders.h
+++ b/decoders/swig/ctc_beam_search_decoder.h
@@ -7,19 +7,6 @@
 
 #include "scorer.h"
 
-/* CTC Best Path Decoder
- *
- * Parameters:
- *     probs_seq: 2-D vector that each element is a vector of probabilities
- *                over vocabulary of one time step.
- *     vocabulary: A vector of vocabulary.
- * Return:
- *     The decoding result in string
- */
-std::string ctc_greedy_decoder(
-    const std::vector<std::vector<double>> &probs_seq,
-    const std::vector<std::string> &vocabulary);
-
 /* CTC Beam Search Decoder
 
  * Parameters:
@@ -38,11 +25,11 @@ std::string ctc_greedy_decoder(
  */
 std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
     const std::vector<std::vector<double>> &probs_seq,
-    const size_t beam_size,
+    size_t beam_size,
     std::vector<std::string> vocabulary,
-    const double cutoff_prob = 1.0,
-    const size_t cutoff_top_n = 40,
-    Scorer *ext_scorer = NULL);
+    double cutoff_prob = 1.0,
+    size_t cutoff_top_n = 40,
+    Scorer *ext_scorer = nullptr);
 
 /* CTC Beam Search Decoder for batch data
 
@@ -65,11 +52,11 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
 std::vector<std::vector<std::pair<double, std::string>>>
 ctc_beam_search_decoder_batch(
     const std::vector<std::vector<std::vector<double>>> &probs_split,
-    const size_t beam_size,
+    size_t beam_size,
     const std::vector<std::string> &vocabulary,
-    const size_t num_processes,
+    size_t num_processes,
     double cutoff_prob = 1.0,
-    const size_t cutoff_top_n = 40,
-    Scorer *ext_scorer = NULL);
+    size_t cutoff_top_n = 40,
+    Scorer *ext_scorer = nullptr);
 
 #endif  // CTC_BEAM_SEARCH_DECODER_H_
diff --git a/decoders/swig/ctc_greedy_decoder.cpp b/decoders/swig/ctc_greedy_decoder.cpp
new file mode 100644
index 00000000..c4c94539
--- /dev/null
+++ b/decoders/swig/ctc_greedy_decoder.cpp
@@ -0,0 +1,45 @@
+#include "ctc_greedy_decoder.h"
+#include "decoder_utils.h"
+
+std::string ctc_greedy_decoder(
+    const std::vector<std::vector<double>> &probs_seq,
+    const std::vector<std::string> &vocabulary) {
+  // dimension check
+  size_t num_time_steps = probs_seq.size();
+  for (size_t i = 0; i < num_time_steps; ++i) {
+    VALID_CHECK_EQ(probs_seq[i].size(),
+                   vocabulary.size() + 1,
+                   "The shape of probs_seq does not match with "
+                   "the shape of the vocabulary");
+  }
+
+  size_t blank_id = vocabulary.size();
+
+  std::vector<size_t> max_idx_vec(num_time_steps, 0);
+  std::vector<size_t> idx_vec;
+  for (size_t i = 0; i < num_time_steps; ++i) {
+    double max_prob = 0.0;
+    size_t max_idx = 0;
+    const std::vector<double> &probs_step = probs_seq[i];
+    for (size_t j = 0; j < probs_step.size(); ++j) {
+      if (max_prob < probs_step[j]) {
+        max_idx = j;
+        max_prob = probs_step[j];
+      }
+    }
+    // id with maximum probability in current step
+    max_idx_vec[i] = max_idx;
+    // deduplicate
+    if ((i == 0) || ((i > 0) && max_idx_vec[i] != max_idx_vec[i - 1])) {
+      idx_vec.push_back(max_idx_vec[i]);
+    }
+  }
+
+  std::string best_path_result;
+  for (size_t i = 0; i < idx_vec.size(); ++i) {
+    if (idx_vec[i] != blank_id) {
+      best_path_result += vocabulary[idx_vec[i]];
+    }
+  }
+  return best_path_result;
+}
diff --git a/decoders/swig/ctc_greedy_decoder.h b/decoders/swig/ctc_greedy_decoder.h
new file mode 100644
index 00000000..043742f2
--- /dev/null
+++ b/decoders/swig/ctc_greedy_decoder.h
@@ -0,0 +1,20 @@
+#ifndef CTC_GREEDY_DECODER_H
+#define CTC_GREEDY_DECODER_H
+
+#include <string>
+#include <vector>
+
+/* CTC Greedy (Best Path) Decoder
+ *
+ * Parameters:
+ *     probs_seq: 2-D vector that each element is a vector of probabilities
+ *                over vocabulary of one time step.
+ *     vocabulary: A vector of vocabulary.
+ * Return:
+ *     The decoding result in string
+ */
+std::string ctc_greedy_decoder(
+    const std::vector<std::vector<double>> &probs_seq,
+    const std::vector<std::string> &vocabulary);
+
+#endif  // CTC_GREEDY_DECODER_H
diff --git a/decoders/swig/decoder_utils.cpp b/decoders/swig/decoder_utils.cpp
index 989b067e..665fcc22 100644
--- a/decoders/swig/decoder_utils.cpp
+++ b/decoders/swig/decoder_utils.cpp
@@ -4,6 +4,71 @@
 #include <cmath>
 #include <limits>
 
+std::vector<std::pair<size_t, float>> get_pruned_log_probs(
+    const std::vector<double> &prob_step,
+    double cutoff_prob,
+    size_t cutoff_top_n) {
+  std::vector<std::pair<int, double>> prob_idx;
+  for (size_t i = 0; i < prob_step.size(); ++i) {
+    prob_idx.push_back(std::pair<int, double>(i, prob_step[i]));
+  }
+  // pruning of vocabulary
+  size_t cutoff_len = prob_step.size();
+  if (cutoff_prob < 1.0 || cutoff_top_n < cutoff_len) {
+    std::sort(
+        prob_idx.begin(), prob_idx.end(), pair_comp_second_rev<int, double>);
+    if (cutoff_prob < 1.0) {
+      double cum_prob = 0.0;
+      cutoff_len = 0;
+      for (size_t i = 0; i < prob_idx.size(); ++i) {
+        cum_prob += prob_idx[i].second;
+        cutoff_len += 1;
+        if (cum_prob >= cutoff_prob) break;
+      }
+    }
+    cutoff_len = std::min(cutoff_len, cutoff_top_n);
+    prob_idx = std::vector<std::pair<int, double>>(
+        prob_idx.begin(), prob_idx.begin() + cutoff_len);
+  }
+  std::vector<std::pair<size_t, float>> log_prob_idx;
+  for (size_t i = 0; i < cutoff_len; ++i) {
+    log_prob_idx.push_back(std::pair<size_t, float>(
+        prob_idx[i].first, log(prob_idx[i].second + NUM_FLT_MIN)));
+  }
+  return log_prob_idx;
+}
+
+
+std::vector<std::pair<double, std::string>> get_beam_search_result(
+    const std::vector<PathTrie *> &prefixes,
+    const std::vector<std::string> &vocabulary,
+    size_t beam_size) {
+  // allow for the post processing
+  std::vector<PathTrie *> space_prefixes;
+  if (space_prefixes.empty()) {
+    for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
+      space_prefixes.push_back(prefixes[i]);
+    }
+  }
+
+  std::sort(space_prefixes.begin(), space_prefixes.end(), prefix_compare);
+  std::vector<std::pair<double, std::string>> output_vecs;
+  for (size_t i = 0; i < beam_size && i < space_prefixes.size(); ++i) {
+    std::vector<int> output;
+    space_prefixes[i]->get_path_vec(output);
+    // convert index to string
+    std::string output_str;
+    for (size_t j = 0; j < output.size(); j++) {
+      output_str += vocabulary[output[j]];
+    }
+    std::pair<double, std::string> output_pair(-space_prefixes[i]->approx_ctc,
+                                               output_str);
+    output_vecs.emplace_back(output_pair);
+  }
+
+  return output_vecs;
+}
+
 size_t get_utf8_str_len(const std::string &str) {
   size_t str_len = 0;
   for (char c : str) {
diff --git a/decoders/swig/decoder_utils.h b/decoders/swig/decoder_utils.h
index 015646dd..932ffb12 100644
--- a/decoders/swig/decoder_utils.h
+++ b/decoders/swig/decoder_utils.h
@@ -3,25 +3,26 @@
 
 #include <utility>
 #include "path_trie.h"
+#include "fst/log.h"
 
 const float NUM_FLT_INF = std::numeric_limits<float>::max();
 const float NUM_FLT_MIN = std::numeric_limits<float>::min();
 
-// check if __A == _B
-#define VALID_CHECK_EQ(__A, __B, __ERR)              \
-  if ((__A) != (__B)) {                              \
-    std::ostringstream str;                          \
-    str << (__A) << " != " << (__B) << ", ";         \
-    throw std::runtime_error(str.str() + __ERR);     \
+// inline function for validation check
+inline void check(
+    bool x, const char *expr, const char *file, int line, const char *err) {
+  if (!x) {
+    std::cout << "[" << file << ":" << line << "] ";
+    LOG(FATAL) << "\"" << expr << "\" check failed. " << err;
   }
+}
+
+#define VALID_CHECK(x, info) \
+  check(static_cast<bool>(x), #x, __FILE__, __LINE__, info)
+#define VALID_CHECK_EQ(x, y, info) VALID_CHECK((x) == (y), info)
+#define VALID_CHECK_GT(x, y, info) VALID_CHECK((x) > (y), info)
+#define VALID_CHECK_LT(x, y, info) VALID_CHECK((x) < (y), info)
 
-// check if __A > __B
-#define VALID_CHECK_GT(__A, __B, __ERR)              \
-  if ((__A) <= (__B)) {                              \
-    std::ostringstream str;                          \
-    str << (__A) << " <= " << (__B) << ", ";         \
-    throw std::runtime_error(str.str() + __ERR);     \
-  }
 
 // Function template for comparing two pairs
 template <typename T1, typename T2>
@@ -47,6 +48,18 @@ T log_sum_exp(const T &x, const T &y) {
   return std::log(std::exp(x - xmax) + std::exp(y - xmax)) + xmax;
 }
 
+// Get pruned probability vector for each time step's beam search
+std::vector<std::pair<size_t, float>> get_pruned_log_probs(
+    const std::vector<double> &prob_step,
+    double cutoff_prob,
+    size_t cutoff_top_n);
+
+// Get beam search result from prefixes in trie tree
+std::vector<std::pair<double, std::string>> get_beam_search_result(
+    const std::vector<PathTrie *> &prefixes,
+    const std::vector<std::string> &vocabulary,
+    size_t beam_size);
+
 // Functor for prefix comparsion
 bool prefix_compare(const PathTrie *x, const PathTrie *y);
 
diff --git a/decoders/swig/decoders.i b/decoders/swig/decoders.i
index 8059199d..4227d4a3 100644
--- a/decoders/swig/decoders.i
+++ b/decoders/swig/decoders.i
@@ -1,7 +1,8 @@
 %module swig_decoders
 %{
 #include "scorer.h"
-#include "ctc_decoders.h"
+#include "ctc_greedy_decoder.h"
+#include "ctc_beam_search_decoder.h"
 #include "decoder_utils.h"
 %}
 
@@ -28,4 +29,5 @@ namespace std {
 %template(DoubleStringPairCompFirstRev) pair_comp_first_rev<double, std::string>;
 
 %include "scorer.h"
-%include "ctc_decoders.h"
+%include "ctc_greedy_decoder.h"
+%include "ctc_beam_search_decoder.h"
diff --git a/decoders/swig/path_trie.h b/decoders/swig/path_trie.h
index ddeccd91..b4f5bc4b 100644
--- a/decoders/swig/path_trie.h
+++ b/decoders/swig/path_trie.h
@@ -1,14 +1,13 @@
 #ifndef PATH_TRIE_H
 #define PATH_TRIE_H
-#pragma once
-#include <fst/fstlib.h>
+
 #include <algorithm>
 #include <limits>
 #include <memory>
 #include <utility>
 #include <vector>
 
-using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;
+#include "fst/fstlib.h"
 
 /* Trie tree for prefix storing and manipulating, with a dictionary in
  * finite-state transducer for spelling correction.
@@ -35,7 +34,7 @@ public:
   // set dictionary for FST
   void set_dictionary(fst::StdVectorFst* dictionary);
 
-  void set_matcher(std::shared_ptr<FSTMATCH> matcher);
+  void set_matcher(std::shared_ptr<fst::SortedMatcher<fst::StdVectorFst>>);
 
   bool is_empty() { return _ROOT == character; }
 
@@ -62,7 +61,7 @@ private:
   fst::StdVectorFst* _dictionary;
   fst::StdVectorFst::StateId _dictionary_state;
   // true if finding ars in FST
-  std::shared_ptr<FSTMATCH> _matcher;
+  std::shared_ptr<fst::SortedMatcher<fst::StdVectorFst>> _matcher;
 };
 
 #endif  // PATH_TRIE_H
diff --git a/decoders/swig/scorer.cpp b/decoders/swig/scorer.cpp
index 75919c3c..6b280344 100644
--- a/decoders/swig/scorer.cpp
+++ b/decoders/swig/scorer.cpp
@@ -13,29 +13,47 @@
 
 using namespace lm::ngram;
 
-Scorer::Scorer(double alpha, double beta, const std::string& lm_path) {
+Scorer::Scorer(double alpha,
+               double beta,
+               const std::string& lm_path,
+               const std::vector<std::string>& vocab_list) {
   this->alpha = alpha;
   this->beta = beta;
   _is_character_based = true;
   _language_model = nullptr;
   dictionary = nullptr;
   _max_order = 0;
+  _dict_size = 0;
   _SPACE_ID = -1;
-  // load language model
-  load_LM(lm_path.c_str());
+
+  setup(lm_path, vocab_list);
 }
 
 Scorer::~Scorer() {
-  if (_language_model != nullptr)
+  if (_language_model != nullptr) {
     delete static_cast<lm::base::Model*>(_language_model);
-  if (dictionary != nullptr) delete static_cast<fst::StdVectorFst*>(dictionary);
+  }
+  if (dictionary != nullptr) {
+    delete static_cast<fst::StdVectorFst*>(dictionary);
+  }
 }
 
-void Scorer::load_LM(const char* filename) {
-  if (access(filename, F_OK) != 0) {
-    std::cerr << "Invalid language model file !!!" << std::endl;
-    exit(1);
+void Scorer::setup(const std::string& lm_path,
+                   const std::vector<std::string>& vocab_list) {
+  // load language model
+  load_lm(lm_path);
+  // set char map for scorer
+  set_char_map(vocab_list);
+  // fill the dictionary for FST
+  if (!is_character_based()) {
+    fill_dictionary(true);
   }
+}
+
+void Scorer::load_lm(const std::string& lm_path) {
+  const char* filename = lm_path.c_str();
+  VALID_CHECK_EQ(access(filename, F_OK), 0, "Invalid language model path");
+
   RetriveStrEnumerateVocab enumerate;
   lm::ngram::Config config;
   config.enumerate_vocab = &enumerate;
@@ -180,14 +198,14 @@ void Scorer::fill_dictionary(bool add_space) {
   }
 
   // For each unigram convert to ints and put in trie
-  int vocab_size = 0;
+  int dict_size = 0;
   for (const auto& word : _vocabulary) {
     bool added = add_word_to_dictionary(
         word, char_map, add_space, _SPACE_ID, &dictionary);
-    vocab_size += added ? 1 : 0;
+    dict_size += added ? 1 : 0;
   }
-  std::cerr << "Vocab Size " << vocab_size << std::endl;
+  _dict_size = dict_size;
 
   /* Simplify FST
diff --git a/decoders/swig/scorer.h b/decoders/swig/scorer.h
index 1b4857e3..72544da7 100644
--- a/decoders/swig/scorer.h
+++ b/decoders/swig/scorer.h
@@ -40,31 +40,32 @@
  */
 class Scorer {
 public:
-  Scorer(double alpha, double beta, const std::string &lm_path);
+  Scorer(double alpha,
+         double beta,
+         const std::string &lm_path,
+         const std::vector<std::string> &vocabulary);
   ~Scorer();
 
   double get_log_cond_prob(const std::vector<std::string> &words);
 
   double get_sent_log_prob(const std::vector<std::string> &words);
 
-  size_t get_max_order() { return _max_order; }
+  size_t get_max_order() const { return _max_order; }
 
-  bool is_char_map_empty() { return _char_map.size() == 0; }
+  size_t get_dict_size() const { return _dict_size; }
 
-  bool is_character_based() { return _is_character_based; }
+  bool is_char_map_empty() const { return _char_map.size() == 0; }
+
+  bool is_character_based() const { return _is_character_based; }
 
   // reset params alpha & beta
   void reset_params(float alpha, float beta);
 
-  // make ngram
+  // make ngram for a given prefix
   std::vector<std::string> make_ngram(PathTrie *prefix);
 
-  // fill dictionary for fst
-  void fill_dictionary(bool add_space);
-
-  // set char map
-  void set_char_map(const std::vector<std::string> &char_list);
-
+  // transform the labels in index to the vector of words (word based lm) or
+  // the vector of characters (character based lm)
   std::vector<std::string> split_labels(const std::vector<int> &labels);
 
   // expose to decoder
@@ -75,7 +76,16 @@ public:
   void *dictionary;
 
 protected:
-  void load_LM(const char *filename);
+  void setup(const std::string &lm_path,
+             const std::vector<std::string> &vocab_list);
+
+  void load_lm(const std::string &lm_path);
+
+  // fill dictionary for fst
+  void fill_dictionary(bool add_space);
+
+  // set char map
+  void set_char_map(const std::vector<std::string> &char_list);
 
   double get_log_prob(const std::vector<std::string> &words);
 
@@ -85,6 +95,7 @@ private:
   void *_language_model;
   bool _is_character_based;
   size_t _max_order;
+  size_t _dict_size;
   int _SPACE_ID;
   std::vector<std::string> _char_list;
diff --git a/decoders/swig/setup.py b/decoders/swig/setup.py
index 7a4b7e02..8af9ff30 100644
--- a/decoders/swig/setup.py
+++ b/decoders/swig/setup.py
@@ -70,8 +70,11 @@ FILES = glob.glob('kenlm/util/*.cc') \
 
 FILES += glob.glob('openfst-1.6.3/src/lib/*.cc')
 
+# FILES + glob.glob('glog/src/*.cc')
+
 FILES = [
-    fn for fn in FILES if not (fn.endswith('main.cc') or fn.endswith('test.cc'))
+    fn for fn in FILES
+    if not (fn.endswith('main.cc') or fn.endswith('test.cc') or fn.endswith(
+        'unittest.cc'))
 ]
 
 LIBS = ['stdc++']
@@ -99,7 +102,13 @@ decoders_module = [
         name='_swig_decoders',
         sources=FILES + glob.glob('*.cxx') + glob.glob('*.cpp'),
         language='c++',
-        include_dirs=['.', 'kenlm', 'openfst-1.6.3/src/include', 'ThreadPool'],
+        include_dirs=[
+            '.',
+            'kenlm',
+            'openfst-1.6.3/src/include',
+            'ThreadPool',
+            #'glog/src'
+        ],
         libraries=LIBS,
         extra_compile_args=ARGS)
 ]
diff --git a/decoders/swig/setup.sh b/decoders/swig/setup.sh
index 069f51d6..78ae2b20 100644
--- a/decoders/swig/setup.sh
+++ b/decoders/swig/setup.sh
@@ -1,4 +1,4 @@
-#!/bin/bash
+#!/usr/bin/env bash
 
 if [ ! -d kenlm ]; then
     git clone https://github.com/luotao1/kenlm.git
diff --git a/decoders/swig_wrapper.py b/decoders/swig_wrapper.py
index 54ed249f..5ebcd133 100644
--- a/decoders/swig_wrapper.py
+++ b/decoders/swig_wrapper.py
@@ -13,14 +13,14 @@ class Scorer(swig_decoders.Scorer):
                   language model when alpha = 0.
     :type alpha: float
     :param beta: Parameter associated with word count. Don't use word
-                 count when beta = 0.
+                count when beta = 0.
     :type beta: float
     :model_path: Path to load language model.
     :type model_path: basestring
     """
 
-    def __init__(self, alpha, beta, model_path):
-        swig_decoders.Scorer.__init__(self, alpha, beta, model_path)
+    def __init__(self, alpha, beta, model_path, vocabulary):
+        swig_decoders.Scorer.__init__(self, alpha, beta, model_path, vocabulary)
 
 
 def ctc_greedy_decoder(probs_seq, vocabulary):
@@ -58,12 +58,12 @@ def ctc_beam_search_decoder(probs_seq,
                         default 1.0, no pruning.
     :type cutoff_prob: float
     :param cutoff_top_n: Cutoff number in pruning, only top cutoff_top_n
-                         characters with highest probs in vocabulary will be
-                         used in beam search, default 40.
+                        characters with highest probs in vocabulary will be
+                        used in beam search, default 40.
     :type cutoff_top_n: int
     :param ext_scoring_func: External scoring function for
-                             partially decoded sentence, e.g. word count
-                             or language model.
+                            partially decoded sentence, e.g. word count
+                            or language model.
     :type external_scoring_func: callable
     :return: List of tuples of log probability and sentence as decoding
              results, in descending order of the probability.
@@ -96,14 +96,14 @@ def ctc_beam_search_decoder_batch(probs_split,
                         default 1.0, no pruning.
     :type cutoff_prob: float
     :param cutoff_top_n: Cutoff number in pruning, only top cutoff_top_n
-                         characters with highest probs in vocabulary will be
-                         used in beam search, default 40.
+                        characters with highest probs in vocabulary will be
+                        used in beam search, default 40.
     :type cutoff_top_n: int
     :param num_processes: Number of parallel processes.
    :type num_processes: int
     :param ext_scoring_func: External scoring function for
-                             partially decoded sentence, e.g. word count
-                             or language model.
+                            partially decoded sentence, e.g. word count
+                            or language model.
     :type external_scoring_function: callable
     :return: List of tuples of log probability and sentence as decoding
             results, in descending order of the probability.
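
Usage note on the wrapper changes above: `Scorer` now receives the vocabulary at construction time and finishes its own setup (char map plus FST dictionary) internally, so callers no longer invoke `set_char_map` or `fill_dictionary` by hand. Below is a minimal sketch of the new calling convention; it assumes the `swig_decoders` extension is built, and the vocabulary, probabilities, and language model path are toy placeholders.

```python
# Sketch only: toy values throughout; assumes the swig_decoders extension
# from decoders/swig has been compiled and installed.
from decoders.swig_wrapper import Scorer, ctc_beam_search_decoder_batch

vocab_list = ["'", " ", "a", "b", "c", "d"]   # toy vocabulary
# One utterance, two time steps, len(vocab_list) + 1 probabilities per step
# (the extra slot is the CTC blank).
probs_split = [[[0.1, 0.1, 0.2, 0.2, 0.1, 0.2, 0.1],
                [0.1, 0.1, 0.1, 0.3, 0.1, 0.2, 0.1]]]

# The constructor now loads the LM, sets the char map, and fills the FST
# dictionary in one step.
scorer = Scorer(alpha=2.15, beta=0.35,
                model_path='path/to/lm.binary',   # placeholder path
                vocabulary=vocab_list)

results = ctc_beam_search_decoder_batch(
    probs_split=probs_split,
    vocabulary=vocab_list,
    beam_size=500,
    num_processes=1,
    cutoff_prob=1.0,
    cutoff_top_n=40,
    ext_scoring_func=scorer)
print(results[0][0])   # best (score, transcription) pair for utterance 0
```

Folding the setup into the constructor is what lets the C++ decoder entry points earlier in the patch drop their per-call scorer bookkeeping.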
diff --git a/examples/tiny/run_infer.sh b/examples/tiny/run_infer.sh
index 1d33bfbb..1e90f608 100644
--- a/examples/tiny/run_infer.sh
+++ b/examples/tiny/run_infer.sh
@@ -21,9 +21,9 @@ python -u infer.py \
 --num_conv_layers=2 \
 --num_rnn_layers=3 \
 --rnn_layer_size=2048 \
---alpha=0.36 \
---beta=0.25 \
---cutoff_prob=0.99 \
+--alpha=2.15 \
+--beta=0.35 \
+--cutoff_prob=1.0 \
 --use_gru=False \
 --use_gpu=True \
 --share_rnn_weights=True \
diff --git a/examples/tiny/run_infer_golden.sh b/examples/tiny/run_infer_golden.sh
index 32e9d862..40bb3033 100644
--- a/examples/tiny/run_infer_golden.sh
+++ b/examples/tiny/run_infer_golden.sh
@@ -30,9 +30,9 @@ python -u infer.py \
 --num_conv_layers=2 \
 --num_rnn_layers=3 \
 --rnn_layer_size=2048 \
---alpha=0.36 \
---beta=0.25 \
---cutoff_prob=0.99 \
+--alpha=2.15 \
+--beta=0.35 \
+--cutoff_prob=1.0 \
 --use_gru=False \
 --use_gpu=True \
 --share_rnn_weights=True \
diff --git a/examples/tiny/run_test.sh b/examples/tiny/run_test.sh
index f9c3cc11..868a045f 100644
--- a/examples/tiny/run_test.sh
+++ b/examples/tiny/run_test.sh
@@ -22,9 +22,9 @@ python -u test.py \
 --num_conv_layers=2 \
 --num_rnn_layers=3 \
 --rnn_layer_size=2048 \
---alpha=0.36 \
---beta=0.25 \
---cutoff_prob=0.99 \
+--alpha=2.15 \
+--beta=0.35 \
+--cutoff_prob=1.0 \
 --use_gru=False \
 --use_gpu=True \
 --share_rnn_weights=True \
diff --git a/examples/tiny/run_test_golden.sh b/examples/tiny/run_test_golden.sh
index 080c3c06..1a4731dd 100644
--- a/examples/tiny/run_test_golden.sh
+++ b/examples/tiny/run_test_golden.sh
@@ -31,9 +31,9 @@ python -u test.py \
 --num_conv_layers=2 \
 --num_rnn_layers=3 \
 --rnn_layer_size=2048 \
---alpha=0.36 \
---beta=0.25 \
---cutoff_prob=0.99 \
+--alpha=2.15 \
+--beta=0.35 \
+--cutoff_prob=1.0 \
 --use_gru=False \
 --use_gpu=True \
 --share_rnn_weights=True \
diff --git a/infer.py b/infer.py
index 1064fd25..e635f6d0 100644
--- a/infer.py
+++ b/infer.py
@@ -112,6 +112,7 @@ def infer():
         print("Current error rate [%s] = %f" %
               (args.error_rate_type, error_rate_func(target, result)))
 
+    ds2_model.logger.info("finish inference")
 
 def main():
     print_arguments(args)
diff --git a/model_utils/model.py b/model_utils/model.py
index 4f5021a6..66b161c3 100644
--- a/model_utils/model.py
+++ b/model_utils/model.py
@@ -6,6 +6,7 @@ from __future__ import print_function
 import sys
 import os
 import time
+import logging
 import gzip
 import paddle.v2 as paddle
 from decoders.swig_wrapper import Scorer
@@ -13,6 +14,9 @@ from decoders.swig_wrapper import ctc_greedy_decoder
 from decoders.swig_wrapper import ctc_beam_search_decoder_batch
 from model_utils.network import deep_speech_v2_network
 
+logging.basicConfig(
+    format='[%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s')
+
 
 class DeepSpeech2Model(object):
     """DeepSpeech2Model class.
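
The model.py hunk above installs a module-level log format, and the next hunk gives DeepSpeech2Model a root logger at INFO level. A standalone sketch of the combined behavior (the sample output line is illustrative):

```python
# Standalone sketch of the logging pattern wired into model_utils/model.py.
import logging

logging.basicConfig(
    format='[%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s')

logger = logging.getLogger("")       # root logger, as used by the model class
logger.setLevel(level=logging.INFO)  # INFO and above are emitted

logger.info("begin to initialize the external scorer for decoding")
# Prints something like:
# [INFO 2017-09-17 19:05:04,123 example.py:11] begin to initialize the ...
```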
@@ -43,6 +47,8 @@ class DeepSpeech2Model(object):
         self._inferer = None
         self._loss_inferer = None
         self._ext_scorer = None
+        self.logger = logging.getLogger("")
+        self.logger.setLevel(level=logging.INFO)
 
     def train(self,
               train_batch_reader,
@@ -204,16 +210,25 @@ class DeepSpeech2Model(object):
         elif decoding_method == "ctc_beam_search":
             # initialize external scorer
             if self._ext_scorer == None:
-                self._ext_scorer = Scorer(beam_alpha, beam_beta,
-                                          language_model_path)
                 self._loaded_lm_path = language_model_path
-                self._ext_scorer.set_char_map(vocab_list)
-                if (not self._ext_scorer.is_character_based()):
-                    self._ext_scorer.fill_dictionary(True)
+                self.logger.info("begin to initialize the external scorer "
+                                 "for decoding")
+                self._ext_scorer = Scorer(beam_alpha, beam_beta,
+                                          language_model_path, vocab_list)
+
+                lm_char_based = self._ext_scorer.is_character_based()
+                lm_max_order = self._ext_scorer.get_max_order()
+                lm_dict_size = self._ext_scorer.get_dict_size()
+                self.logger.info("language model: "
+                                 "is_character_based = %d," % lm_char_based +
+                                 " max_order = %d," % lm_max_order +
+                                 " dict_size = %d" % lm_dict_size)
+                self.logger.info("end initializing scorer. Start decoding ...")
             else:
                 self._ext_scorer.reset_params(beam_alpha, beam_beta)
                 assert self._loaded_lm_path == language_model_path
         # beam search decode
+        num_processes = min(num_processes, len(probs_split))
         beam_search_results = ctc_beam_search_decoder_batch(
             probs_split=probs_split,
             vocabulary=vocab_list,
diff --git a/test.py b/test.py
index c564bb85..40f0795a 100644
--- a/test.py
+++ b/test.py
@@ -115,6 +115,7 @@ def evaluate():
     print("Final error rate [%s] (%d/%d) = %f" %
           (args.error_rate_type, num_ins, num_ins, error_sum / num_ins))
 
+    ds2_model.logger.info("finish evaluation")
 
 def main():
     print_arguments(args)
-- 
GitLab
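
For readers tracing the refactor end to end: the per-time-step pruning that used to live inline in the beam search loop is now `get_pruned_log_probs` in decoder_utils.cpp. Restated as plain Python for illustration (names and the epsilon value are stand-ins; the C++ version also skips sorting when no pruning is requested):

```python
# Illustrative restatement of get_pruned_log_probs; not part of the patch.
import math

def pruned_log_probs(prob_step, cutoff_prob=1.0, cutoff_top_n=40):
    """Return the (index, log_prob) pairs kept for one time step."""
    prob_idx = sorted(enumerate(prob_step), key=lambda x: x[1], reverse=True)
    cutoff_len = len(prob_idx)
    if cutoff_prob < 1.0:
        # keep the smallest prefix whose cumulative mass reaches cutoff_prob
        cum_prob, cutoff_len = 0.0, 0
        for _, p in prob_idx:
            cum_prob += p
            cutoff_len += 1
            if cum_prob >= cutoff_prob:
                break
    cutoff_len = min(cutoff_len, cutoff_top_n)   # cap at the top-n candidates
    eps = 1e-38  # stand-in for NUM_FLT_MIN, guards against log(0)
    return [(i, math.log(p + eps)) for i, p in prob_idx[:cutoff_len]]

# With cutoff_prob=0.9, only the two most probable ids survive:
print(pruned_log_probs([0.6, 0.3, 0.05, 0.05], cutoff_prob=0.9))
```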