提交 e6740af4 编写于 作者: Y Yibing Liu

adjust scorer's init & add logging for scorer & separate long functions

上级 15728d04
...@@ -176,7 +176,6 @@ Data augmentation has often been a highly effective technique to boost the deep ...@@ -176,7 +176,6 @@ Data augmentation has often been a highly effective technique to boost the deep
Six optional augmentation components are provided to be selected, configured and inserted into the processing pipeline. Six optional augmentation components are provided to be selected, configured and inserted into the processing pipeline.
### Inference
- Volume Perturbation - Volume Perturbation
- Speed Perturbation - Speed Perturbation
- Shifting Perturbation - Shifting Perturbation
......
...@@ -119,7 +119,7 @@ def ctc_beam_search_decoder(probs_seq, ...@@ -119,7 +119,7 @@ def ctc_beam_search_decoder(probs_seq,
cutoff_len += 1 cutoff_len += 1
if cum_prob >= cutoff_prob: if cum_prob >= cutoff_prob:
break break
cutoff_len = min(cutoff_top_n, cutoff_top_n) cutoff_len = min(cutoff_len, cutoff_top_n)
prob_idx = prob_idx[0:cutoff_len] prob_idx = prob_idx[0:cutoff_len]
for l in prefix_set_prev: for l in prefix_set_prev:
...@@ -228,8 +228,8 @@ def ctc_beam_search_decoder_batch(probs_split, ...@@ -228,8 +228,8 @@ def ctc_beam_search_decoder_batch(probs_split,
pool = multiprocessing.Pool(processes=num_processes) pool = multiprocessing.Pool(processes=num_processes)
results = [] results = []
for i, probs_list in enumerate(probs_split): for i, probs_list in enumerate(probs_split):
args = (probs_list, beam_size, vocabulary, blank_id, cutoff_prob, args = (probs_list, beam_size, vocabulary, cutoff_prob, cutoff_top_n,
cutoff_top_n, None, nproc) None, nproc)
results.append(pool.apply_async(ctc_beam_search_decoder, args)) results.append(pool.apply_async(ctc_beam_search_decoder, args))
pool.close() pool.close()
......
#include "ctc_decoders.h" #include "ctc_beam_search_decoder.h"
#include <algorithm> #include <algorithm>
#include <cmath> #include <cmath>
...@@ -9,59 +9,19 @@ ...@@ -9,59 +9,19 @@
#include "ThreadPool.h" #include "ThreadPool.h"
#include "fst/fstlib.h" #include "fst/fstlib.h"
#include "fst/log.h"
#include "decoder_utils.h" #include "decoder_utils.h"
#include "path_trie.h" #include "path_trie.h"
std::string ctc_greedy_decoder( using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;
const std::vector<std::vector<double>> &probs_seq,
const std::vector<std::string> &vocabulary) {
// dimension check
size_t num_time_steps = probs_seq.size();
for (size_t i = 0; i < num_time_steps; ++i) {
VALID_CHECK_EQ(probs_seq[i].size(),
vocabulary.size() + 1,
"The shape of probs_seq does not match with "
"the shape of the vocabulary");
}
size_t blank_id = vocabulary.size();
std::vector<size_t> max_idx_vec;
for (size_t i = 0; i < num_time_steps; ++i) {
double max_prob = 0.0;
size_t max_idx = 0;
for (size_t j = 0; j < probs_seq[i].size(); j++) {
if (max_prob < probs_seq[i][j]) {
max_idx = j;
max_prob = probs_seq[i][j];
}
}
max_idx_vec.push_back(max_idx);
}
std::vector<size_t> idx_vec;
for (size_t i = 0; i < max_idx_vec.size(); ++i) {
if ((i == 0) || ((i > 0) && max_idx_vec[i] != max_idx_vec[i - 1])) {
idx_vec.push_back(max_idx_vec[i]);
}
}
std::string best_path_result;
for (size_t i = 0; i < idx_vec.size(); ++i) {
if (idx_vec[i] != blank_id) {
best_path_result += vocabulary[idx_vec[i]];
}
}
return best_path_result;
}
std::vector<std::pair<double, std::string>> ctc_beam_search_decoder( std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
const std::vector<std::vector<double>> &probs_seq, const std::vector<std::vector<double>> &probs_seq,
const size_t beam_size, size_t beam_size,
std::vector<std::string> vocabulary, std::vector<std::string> vocabulary,
const double cutoff_prob, double cutoff_prob,
const size_t cutoff_top_n, size_t cutoff_top_n,
Scorer *ext_scorer) { Scorer *ext_scorer) {
// dimension check // dimension check
size_t num_time_steps = probs_seq.size(); size_t num_time_steps = probs_seq.size();
...@@ -80,7 +40,7 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder( ...@@ -80,7 +40,7 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
std::find(vocabulary.begin(), vocabulary.end(), " "); std::find(vocabulary.begin(), vocabulary.end(), " ");
int space_id = it - vocabulary.begin(); int space_id = it - vocabulary.begin();
// if no space in vocabulary // if no space in vocabulary
if (space_id >= vocabulary.size()) { if ((size_t)space_id >= vocabulary.size()) {
space_id = -2; space_id = -2;
} }
...@@ -90,30 +50,17 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder( ...@@ -90,30 +50,17 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
std::vector<PathTrie *> prefixes; std::vector<PathTrie *> prefixes;
prefixes.push_back(&root); prefixes.push_back(&root);
if (ext_scorer != nullptr) { if (ext_scorer != nullptr && !ext_scorer->is_character_based()) {
if (ext_scorer->is_char_map_empty()) { auto fst_dict = static_cast<fst::StdVectorFst *>(ext_scorer->dictionary);
ext_scorer->set_char_map(vocabulary); fst::StdVectorFst *dict_ptr = fst_dict->Copy(true);
} root.set_dictionary(dict_ptr);
if (!ext_scorer->is_character_based()) { auto matcher = std::make_shared<FSTMATCH>(*dict_ptr, fst::MATCH_INPUT);
if (ext_scorer->dictionary == nullptr) { root.set_matcher(matcher);
// fill dictionary for fst with space
ext_scorer->fill_dictionary(true);
}
auto fst_dict = static_cast<fst::StdVectorFst *>(ext_scorer->dictionary);
fst::StdVectorFst *dict_ptr = fst_dict->Copy(true);
root.set_dictionary(dict_ptr);
auto matcher = std::make_shared<FSTMATCH>(*dict_ptr, fst::MATCH_INPUT);
root.set_matcher(matcher);
}
} }
// prefix search over time // prefix search over time
for (size_t time_step = 0; time_step < num_time_steps; time_step++) { for (size_t time_step = 0; time_step < num_time_steps; ++time_step) {
std::vector<double> prob = probs_seq[time_step]; auto &prob = probs_seq[time_step];
std::vector<std::pair<int, double>> prob_idx;
for (size_t i = 0; i < prob.size(); ++i) {
prob_idx.push_back(std::pair<int, double>(i, prob[i]));
}
float min_cutoff = -NUM_FLT_INF; float min_cutoff = -NUM_FLT_INF;
bool full_beam = false; bool full_beam = false;
...@@ -121,43 +68,20 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder( ...@@ -121,43 +68,20 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
size_t num_prefixes = std::min(prefixes.size(), beam_size); size_t num_prefixes = std::min(prefixes.size(), beam_size);
std::sort( std::sort(
prefixes.begin(), prefixes.begin() + num_prefixes, prefix_compare); prefixes.begin(), prefixes.begin() + num_prefixes, prefix_compare);
min_cutoff = prefixes[num_prefixes - 1]->score + log(prob[blank_id]) - min_cutoff = prefixes[num_prefixes - 1]->score +
std::max(0.0, ext_scorer->beta); std::log(prob[blank_id]) - std::max(0.0, ext_scorer->beta);
full_beam = (num_prefixes == beam_size); full_beam = (num_prefixes == beam_size);
} }
// pruning of vacobulary std::vector<std::pair<size_t, float>> log_prob_idx =
size_t cutoff_len = prob.size(); get_pruned_log_probs(prob, cutoff_prob, cutoff_top_n);
if (cutoff_prob < 1.0 || cutoff_top_n < cutoff_len) {
std::sort(
prob_idx.begin(), prob_idx.end(), pair_comp_second_rev<int, double>);
if (cutoff_prob < 1.0) {
double cum_prob = 0.0;
cutoff_len = 0;
for (size_t i = 0; i < prob_idx.size(); ++i) {
cum_prob += prob_idx[i].second;
cutoff_len += 1;
if (cum_prob >= cutoff_prob) break;
}
}
cutoff_len = std::min(cutoff_len, cutoff_top_n);
prob_idx = std::vector<std::pair<int, double>>(
prob_idx.begin(), prob_idx.begin() + cutoff_len);
}
std::vector<std::pair<size_t, float>> log_prob_idx;
for (size_t i = 0; i < cutoff_len; ++i) {
log_prob_idx.push_back(std::pair<int, float>(
prob_idx[i].first, log(prob_idx[i].second + NUM_FLT_MIN)));
}
// loop over chars // loop over chars
for (size_t index = 0; index < log_prob_idx.size(); index++) { for (size_t index = 0; index < log_prob_idx.size(); index++) {
auto c = log_prob_idx[index].first; auto c = log_prob_idx[index].first;
float log_prob_c = log_prob_idx[index].second; auto log_prob_c = log_prob_idx[index].second;
for (size_t i = 0; i < prefixes.size() && i < beam_size; ++i) { for (size_t i = 0; i < prefixes.size() && i < beam_size; ++i) {
auto prefix = prefixes[i]; auto prefix = prefixes[i];
if (full_beam && log_prob_c + prefix->score < min_cutoff) { if (full_beam && log_prob_c + prefix->score < min_cutoff) {
break; break;
} }
...@@ -189,7 +113,6 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder( ...@@ -189,7 +113,6 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
if (ext_scorer != nullptr && if (ext_scorer != nullptr &&
(c == space_id || ext_scorer->is_character_based())) { (c == space_id || ext_scorer->is_character_based())) {
PathTrie *prefix_toscore = nullptr; PathTrie *prefix_toscore = nullptr;
// skip scoring the space // skip scoring the space
if (ext_scorer->is_character_based()) { if (ext_scorer->is_character_based()) {
prefix_toscore = prefix_new; prefix_toscore = prefix_new;
...@@ -201,7 +124,6 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder( ...@@ -201,7 +124,6 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
std::vector<std::string> ngram; std::vector<std::string> ngram;
ngram = ext_scorer->make_ngram(prefix_toscore); ngram = ext_scorer->make_ngram(prefix_toscore);
score = ext_scorer->get_log_cond_prob(ngram) * ext_scorer->alpha; score = ext_scorer->get_log_cond_prob(ngram) * ext_scorer->alpha;
log_p += score; log_p += score;
log_p += ext_scorer->beta; log_p += ext_scorer->beta;
} }
...@@ -221,57 +143,33 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder( ...@@ -221,57 +143,33 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
prefixes.begin() + beam_size, prefixes.begin() + beam_size,
prefixes.end(), prefixes.end(),
prefix_compare); prefix_compare);
for (size_t i = beam_size; i < prefixes.size(); ++i) { for (size_t i = beam_size; i < prefixes.size(); ++i) {
prefixes[i]->remove(); prefixes[i]->remove();
} }
} }
} // end of loop over time } // end of loop over time
// compute aproximate ctc score as the return score // compute aproximate ctc score as the return score, without affecting the
// return order of decoding result. To delete when decoder gets stable.
for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) { for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
double approx_ctc = prefixes[i]->score; double approx_ctc = prefixes[i]->score;
if (ext_scorer != nullptr) { if (ext_scorer != nullptr) {
std::vector<int> output; std::vector<int> output;
prefixes[i]->get_path_vec(output); prefixes[i]->get_path_vec(output);
size_t prefix_length = output.size(); auto prefix_length = output.size();
auto words = ext_scorer->split_labels(output); auto words = ext_scorer->split_labels(output);
// remove word insert // remove word insert
approx_ctc = approx_ctc - prefix_length * ext_scorer->beta; approx_ctc = approx_ctc - prefix_length * ext_scorer->beta;
// remove language model weight: // remove language model weight:
approx_ctc -= (ext_scorer->get_sent_log_prob(words)) * ext_scorer->alpha; approx_ctc -= (ext_scorer->get_sent_log_prob(words)) * ext_scorer->alpha;
} }
prefixes[i]->approx_ctc = approx_ctc; prefixes[i]->approx_ctc = approx_ctc;
} }
// allow for the post processing return get_beam_search_result(prefixes, vocabulary, beam_size);
std::vector<PathTrie *> space_prefixes;
if (space_prefixes.empty()) {
for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
space_prefixes.push_back(prefixes[i]);
}
}
std::sort(space_prefixes.begin(), space_prefixes.end(), prefix_compare);
std::vector<std::pair<double, std::string>> output_vecs;
for (size_t i = 0; i < beam_size && i < space_prefixes.size(); ++i) {
std::vector<int> output;
space_prefixes[i]->get_path_vec(output);
// convert index to string
std::string output_str;
for (size_t j = 0; j < output.size(); j++) {
output_str += vocabulary[output[j]];
}
std::pair<double, std::string> output_pair(-space_prefixes[i]->approx_ctc,
output_str);
output_vecs.emplace_back(output_pair);
}
return output_vecs;
} }
std::vector<std::vector<std::pair<double, std::string>>> std::vector<std::vector<std::pair<double, std::string>>>
ctc_beam_search_decoder_batch( ctc_beam_search_decoder_batch(
const std::vector<std::vector<std::vector<double>>> &probs_split, const std::vector<std::vector<std::vector<double>>> &probs_split,
...@@ -287,18 +185,6 @@ ctc_beam_search_decoder_batch( ...@@ -287,18 +185,6 @@ ctc_beam_search_decoder_batch(
// number of samples // number of samples
size_t batch_size = probs_split.size(); size_t batch_size = probs_split.size();
// scorer filling up
if (ext_scorer != nullptr) {
if (ext_scorer->is_char_map_empty()) {
ext_scorer->set_char_map(vocabulary);
}
if (!ext_scorer->is_character_based() &&
ext_scorer->dictionary == nullptr) {
// init dictionary
ext_scorer->fill_dictionary(true);
}
}
// enqueue the tasks of decoding // enqueue the tasks of decoding
std::vector<std::future<std::vector<std::pair<double, std::string>>>> res; std::vector<std::future<std::vector<std::pair<double, std::string>>>> res;
for (size_t i = 0; i < batch_size; ++i) { for (size_t i = 0; i < batch_size; ++i) {
......
...@@ -7,19 +7,6 @@ ...@@ -7,19 +7,6 @@
#include "scorer.h" #include "scorer.h"
/* CTC Best Path Decoder
*
* Parameters:
* probs_seq: 2-D vector that each element is a vector of probabilities
* over vocabulary of one time step.
* vocabulary: A vector of vocabulary.
* Return:
* The decoding result in string
*/
std::string ctc_greedy_decoder(
const std::vector<std::vector<double>> &probs_seq,
const std::vector<std::string> &vocabulary);
/* CTC Beam Search Decoder /* CTC Beam Search Decoder
* Parameters: * Parameters:
...@@ -38,11 +25,11 @@ std::string ctc_greedy_decoder( ...@@ -38,11 +25,11 @@ std::string ctc_greedy_decoder(
*/ */
std::vector<std::pair<double, std::string>> ctc_beam_search_decoder( std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
const std::vector<std::vector<double>> &probs_seq, const std::vector<std::vector<double>> &probs_seq,
const size_t beam_size, size_t beam_size,
std::vector<std::string> vocabulary, std::vector<std::string> vocabulary,
const double cutoff_prob = 1.0, double cutoff_prob = 1.0,
const size_t cutoff_top_n = 40, size_t cutoff_top_n = 40,
Scorer *ext_scorer = NULL); Scorer *ext_scorer = nullptr);
/* CTC Beam Search Decoder for batch data /* CTC Beam Search Decoder for batch data
...@@ -65,11 +52,11 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder( ...@@ -65,11 +52,11 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
std::vector<std::vector<std::pair<double, std::string>>> std::vector<std::vector<std::pair<double, std::string>>>
ctc_beam_search_decoder_batch( ctc_beam_search_decoder_batch(
const std::vector<std::vector<std::vector<double>>> &probs_split, const std::vector<std::vector<std::vector<double>>> &probs_split,
const size_t beam_size, size_t beam_size,
const std::vector<std::string> &vocabulary, const std::vector<std::string> &vocabulary,
const size_t num_processes, size_t num_processes,
double cutoff_prob = 1.0, double cutoff_prob = 1.0,
const size_t cutoff_top_n = 40, size_t cutoff_top_n = 40,
Scorer *ext_scorer = NULL); Scorer *ext_scorer = nullptr);
#endif // CTC_BEAM_SEARCH_DECODER_H_ #endif // CTC_BEAM_SEARCH_DECODER_H_
#include "ctc_greedy_decoder.h"
#include "decoder_utils.h"
std::string ctc_greedy_decoder(
const std::vector<std::vector<double>> &probs_seq,
const std::vector<std::string> &vocabulary) {
// dimension check
size_t num_time_steps = probs_seq.size();
for (size_t i = 0; i < num_time_steps; ++i) {
VALID_CHECK_EQ(probs_seq[i].size(),
vocabulary.size() + 1,
"The shape of probs_seq does not match with "
"the shape of the vocabulary");
}
size_t blank_id = vocabulary.size();
std::vector<size_t> max_idx_vec(num_time_steps, 0);
std::vector<size_t> idx_vec;
for (size_t i = 0; i < num_time_steps; ++i) {
double max_prob = 0.0;
size_t max_idx = 0;
const std::vector<double> &probs_step = probs_seq[i];
for (size_t j = 0; j < probs_step.size(); ++j) {
if (max_prob < probs_step[j]) {
max_idx = j;
max_prob = probs_step[j];
}
}
// id with maximum probability in current step
max_idx_vec[i] = max_idx;
// deduplicate
if ((i == 0) || ((i > 0) && max_idx_vec[i] != max_idx_vec[i - 1])) {
idx_vec.push_back(max_idx_vec[i]);
}
}
std::string best_path_result;
for (size_t i = 0; i < idx_vec.size(); ++i) {
if (idx_vec[i] != blank_id) {
best_path_result += vocabulary[idx_vec[i]];
}
}
return best_path_result;
}
#ifndef CTC_GREEDY_DECODER_H
#define CTC_GREEDY_DECODER_H
#include <string>
#include <vector>
/* CTC Greedy (Best Path) Decoder
*
* Parameters:
* probs_seq: 2-D vector that each element is a vector of probabilities
* over vocabulary of one time step.
* vocabulary: A vector of vocabulary.
* Return:
* The decoding result in string
*/
std::string ctc_greedy_decoder(
const std::vector<std::vector<double>> &probs_seq,
const std::vector<std::string> &vocabulary);
#endif // CTC_GREEDY_DECODER_H
...@@ -4,6 +4,71 @@ ...@@ -4,6 +4,71 @@
#include <cmath> #include <cmath>
#include <limits> #include <limits>
std::vector<std::pair<size_t, float>> get_pruned_log_probs(
const std::vector<double> &prob_step,
double cutoff_prob,
size_t cutoff_top_n) {
std::vector<std::pair<int, double>> prob_idx;
for (size_t i = 0; i < prob_step.size(); ++i) {
prob_idx.push_back(std::pair<int, double>(i, prob_step[i]));
}
// pruning of vacobulary
size_t cutoff_len = prob_step.size();
if (cutoff_prob < 1.0 || cutoff_top_n < cutoff_len) {
std::sort(
prob_idx.begin(), prob_idx.end(), pair_comp_second_rev<int, double>);
if (cutoff_prob < 1.0) {
double cum_prob = 0.0;
cutoff_len = 0;
for (size_t i = 0; i < prob_idx.size(); ++i) {
cum_prob += prob_idx[i].second;
cutoff_len += 1;
if (cum_prob >= cutoff_prob) break;
}
}
cutoff_len = std::min(cutoff_len, cutoff_top_n);
prob_idx = std::vector<std::pair<int, double>>(
prob_idx.begin(), prob_idx.begin() + cutoff_len);
}
std::vector<std::pair<size_t, float>> log_prob_idx;
for (size_t i = 0; i < cutoff_len; ++i) {
log_prob_idx.push_back(std::pair<int, float>(
prob_idx[i].first, log(prob_idx[i].second + NUM_FLT_MIN)));
}
return log_prob_idx;
}
std::vector<std::pair<double, std::string>> get_beam_search_result(
const std::vector<PathTrie *> &prefixes,
const std::vector<std::string> &vocabulary,
size_t beam_size) {
// allow for the post processing
std::vector<PathTrie *> space_prefixes;
if (space_prefixes.empty()) {
for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
space_prefixes.push_back(prefixes[i]);
}
}
std::sort(space_prefixes.begin(), space_prefixes.end(), prefix_compare);
std::vector<std::pair<double, std::string>> output_vecs;
for (size_t i = 0; i < beam_size && i < space_prefixes.size(); ++i) {
std::vector<int> output;
space_prefixes[i]->get_path_vec(output);
// convert index to string
std::string output_str;
for (size_t j = 0; j < output.size(); j++) {
output_str += vocabulary[output[j]];
}
std::pair<double, std::string> output_pair(-space_prefixes[i]->approx_ctc,
output_str);
output_vecs.emplace_back(output_pair);
}
return output_vecs;
}
size_t get_utf8_str_len(const std::string &str) { size_t get_utf8_str_len(const std::string &str) {
size_t str_len = 0; size_t str_len = 0;
for (char c : str) { for (char c : str) {
......
...@@ -3,25 +3,26 @@ ...@@ -3,25 +3,26 @@
#include <utility> #include <utility>
#include "path_trie.h" #include "path_trie.h"
#include "fst/log.h"
const float NUM_FLT_INF = std::numeric_limits<float>::max(); const float NUM_FLT_INF = std::numeric_limits<float>::max();
const float NUM_FLT_MIN = std::numeric_limits<float>::min(); const float NUM_FLT_MIN = std::numeric_limits<float>::min();
// check if __A == _B // inline function for validation check
#define VALID_CHECK_EQ(__A, __B, __ERR) \ inline void check(
if ((__A) != (__B)) { \ bool x, const char *expr, const char *file, int line, const char *err) {
std::ostringstream str; \ if (!x) {
str << (__A) << " != " << (__B) << ", "; \ std::cout << "[" << file << ":" << line << "] ";
throw std::runtime_error(str.str() + __ERR); \ LOG(FATAL) << "\"" << expr << "\" check failed. " << err;
} }
}
#define VALID_CHECK(x, info) \
check(static_cast<bool>(x), #x, __FILE__, __LINE__, info)
#define VALID_CHECK_EQ(x, y, info) VALID_CHECK((x) == (y), info)
#define VALID_CHECK_GT(x, y, info) VALID_CHECK((x) > (y), info)
#define VALID_CHECK_LT(x, y, info) VALID_CHECK((x) < (y), info)
// check if __A > __B
#define VALID_CHECK_GT(__A, __B, __ERR) \
if ((__A) <= (__B)) { \
std::ostringstream str; \
str << (__A) << " <= " << (__B) << ", "; \
throw std::runtime_error(str.str() + __ERR); \
}
// Function template for comparing two pairs // Function template for comparing two pairs
template <typename T1, typename T2> template <typename T1, typename T2>
...@@ -47,6 +48,18 @@ T log_sum_exp(const T &x, const T &y) { ...@@ -47,6 +48,18 @@ T log_sum_exp(const T &x, const T &y) {
return std::log(std::exp(x - xmax) + std::exp(y - xmax)) + xmax; return std::log(std::exp(x - xmax) + std::exp(y - xmax)) + xmax;
} }
// Get pruned probability vector for each time step's beam search
std::vector<std::pair<size_t, float>> get_pruned_log_probs(
const std::vector<double> &prob_step,
double cutoff_prob,
size_t cutoff_top_n);
// Get beam search result from prefixes in trie tree
std::vector<std::pair<double, std::string>> get_beam_search_result(
const std::vector<PathTrie *> &prefixes,
const std::vector<std::string> &vocabulary,
size_t beam_size);
// Functor for prefix comparsion // Functor for prefix comparsion
bool prefix_compare(const PathTrie *x, const PathTrie *y); bool prefix_compare(const PathTrie *x, const PathTrie *y);
......
%module swig_decoders %module swig_decoders
%{ %{
#include "scorer.h" #include "scorer.h"
#include "ctc_decoders.h" #include "ctc_greedy_decoder.h"
#include "ctc_beam_search_decoder.h"
#include "decoder_utils.h" #include "decoder_utils.h"
%} %}
...@@ -28,4 +29,5 @@ namespace std { ...@@ -28,4 +29,5 @@ namespace std {
%template(DoubleStringPairCompFirstRev) pair_comp_first_rev<double, std::string>; %template(DoubleStringPairCompFirstRev) pair_comp_first_rev<double, std::string>;
%include "scorer.h" %include "scorer.h"
%include "ctc_decoders.h" %include "ctc_greedy_decoder.h"
%include "ctc_beam_search_decoder.h"
#ifndef PATH_TRIE_H #ifndef PATH_TRIE_H
#define PATH_TRIE_H #define PATH_TRIE_H
#pragma once
#include <fst/fstlib.h>
#include <algorithm> #include <algorithm>
#include <limits> #include <limits>
#include <memory> #include <memory>
#include <utility> #include <utility>
#include <vector> #include <vector>
using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>; #include "fst/fstlib.h"
/* Trie tree for prefix storing and manipulating, with a dictionary in /* Trie tree for prefix storing and manipulating, with a dictionary in
* finite-state transducer for spelling correction. * finite-state transducer for spelling correction.
...@@ -35,7 +34,7 @@ public: ...@@ -35,7 +34,7 @@ public:
// set dictionary for FST // set dictionary for FST
void set_dictionary(fst::StdVectorFst* dictionary); void set_dictionary(fst::StdVectorFst* dictionary);
void set_matcher(std::shared_ptr<FSTMATCH> matcher); void set_matcher(std::shared_ptr<fst::SortedMatcher<fst::StdVectorFst>>);
bool is_empty() { return _ROOT == character; } bool is_empty() { return _ROOT == character; }
...@@ -62,7 +61,7 @@ private: ...@@ -62,7 +61,7 @@ private:
fst::StdVectorFst* _dictionary; fst::StdVectorFst* _dictionary;
fst::StdVectorFst::StateId _dictionary_state; fst::StdVectorFst::StateId _dictionary_state;
// true if finding ars in FST // true if finding ars in FST
std::shared_ptr<FSTMATCH> _matcher; std::shared_ptr<fst::SortedMatcher<fst::StdVectorFst>> _matcher;
}; };
#endif // PATH_TRIE_H #endif // PATH_TRIE_H
...@@ -13,29 +13,47 @@ ...@@ -13,29 +13,47 @@
using namespace lm::ngram; using namespace lm::ngram;
Scorer::Scorer(double alpha, double beta, const std::string& lm_path) { Scorer::Scorer(double alpha,
double beta,
const std::string& lm_path,
const std::vector<std::string>& vocab_list) {
this->alpha = alpha; this->alpha = alpha;
this->beta = beta; this->beta = beta;
_is_character_based = true; _is_character_based = true;
_language_model = nullptr; _language_model = nullptr;
dictionary = nullptr; dictionary = nullptr;
_max_order = 0; _max_order = 0;
_dict_size = 0;
_SPACE_ID = -1; _SPACE_ID = -1;
// load language model
load_LM(lm_path.c_str()); setup(lm_path, vocab_list);
} }
Scorer::~Scorer() { Scorer::~Scorer() {
if (_language_model != nullptr) if (_language_model != nullptr) {
delete static_cast<lm::base::Model*>(_language_model); delete static_cast<lm::base::Model*>(_language_model);
if (dictionary != nullptr) delete static_cast<fst::StdVectorFst*>(dictionary); }
if (dictionary != nullptr) {
delete static_cast<fst::StdVectorFst*>(dictionary);
}
} }
void Scorer::load_LM(const char* filename) { void Scorer::setup(const std::string& lm_path,
if (access(filename, F_OK) != 0) { const std::vector<std::string>& vocab_list) {
std::cerr << "Invalid language model file !!!" << std::endl; // load language model
exit(1); load_lm(lm_path);
// set char map for scorer
set_char_map(vocab_list);
// fill the dictionary for FST
if (!is_character_based()) {
fill_dictionary(true);
} }
}
void Scorer::load_lm(const std::string& lm_path) {
const char* filename = lm_path.c_str();
VALID_CHECK_EQ(access(filename, F_OK), 0, "Invalid language model path");
RetriveStrEnumerateVocab enumerate; RetriveStrEnumerateVocab enumerate;
lm::ngram::Config config; lm::ngram::Config config;
config.enumerate_vocab = &enumerate; config.enumerate_vocab = &enumerate;
...@@ -180,14 +198,14 @@ void Scorer::fill_dictionary(bool add_space) { ...@@ -180,14 +198,14 @@ void Scorer::fill_dictionary(bool add_space) {
} }
// For each unigram convert to ints and put in trie // For each unigram convert to ints and put in trie
int vocab_size = 0; int dict_size = 0;
for (const auto& word : _vocabulary) { for (const auto& word : _vocabulary) {
bool added = add_word_to_dictionary( bool added = add_word_to_dictionary(
word, char_map, add_space, _SPACE_ID, &dictionary); word, char_map, add_space, _SPACE_ID, &dictionary);
vocab_size += added ? 1 : 0; dict_size += added ? 1 : 0;
} }
std::cerr << "Vocab Size " << vocab_size << std::endl; _dict_size = dict_size;
/* Simplify FST /* Simplify FST
......
...@@ -40,31 +40,32 @@ public: ...@@ -40,31 +40,32 @@ public:
*/ */
class Scorer { class Scorer {
public: public:
Scorer(double alpha, double beta, const std::string &lm_path); Scorer(double alpha,
double beta,
const std::string &lm_path,
const std::vector<std::string> &vocabulary);
~Scorer(); ~Scorer();
double get_log_cond_prob(const std::vector<std::string> &words); double get_log_cond_prob(const std::vector<std::string> &words);
double get_sent_log_prob(const std::vector<std::string> &words); double get_sent_log_prob(const std::vector<std::string> &words);
size_t get_max_order() { return _max_order; } size_t get_max_order() const { return _max_order; }
bool is_char_map_empty() { return _char_map.size() == 0; } size_t get_dict_size() const { return _dict_size; }
bool is_character_based() { return _is_character_based; } bool is_char_map_empty() const { return _char_map.size() == 0; }
bool is_character_based() const { return _is_character_based; }
// reset params alpha & beta // reset params alpha & beta
void reset_params(float alpha, float beta); void reset_params(float alpha, float beta);
// make ngram // make ngram for a given prefix
std::vector<std::string> make_ngram(PathTrie *prefix); std::vector<std::string> make_ngram(PathTrie *prefix);
// fill dictionary for fst // trransform the labels in index to the vector of words (word based lm) or
void fill_dictionary(bool add_space); // the vector of characters (character based lm)
// set char map
void set_char_map(const std::vector<std::string> &char_list);
std::vector<std::string> split_labels(const std::vector<int> &labels); std::vector<std::string> split_labels(const std::vector<int> &labels);
// expose to decoder // expose to decoder
...@@ -75,7 +76,16 @@ public: ...@@ -75,7 +76,16 @@ public:
void *dictionary; void *dictionary;
protected: protected:
void load_LM(const char *filename); void setup(const std::string &lm_path,
const std::vector<std::string> &vocab_list);
void load_lm(const std::string &lm_path);
// fill dictionary for fst
void fill_dictionary(bool add_space);
// set char map
void set_char_map(const std::vector<std::string> &char_list);
double get_log_prob(const std::vector<std::string> &words); double get_log_prob(const std::vector<std::string> &words);
...@@ -85,6 +95,7 @@ private: ...@@ -85,6 +95,7 @@ private:
void *_language_model; void *_language_model;
bool _is_character_based; bool _is_character_based;
size_t _max_order; size_t _max_order;
size_t _dict_size;
int _SPACE_ID; int _SPACE_ID;
std::vector<std::string> _char_list; std::vector<std::string> _char_list;
......
...@@ -70,8 +70,11 @@ FILES = glob.glob('kenlm/util/*.cc') \ ...@@ -70,8 +70,11 @@ FILES = glob.glob('kenlm/util/*.cc') \
FILES += glob.glob('openfst-1.6.3/src/lib/*.cc') FILES += glob.glob('openfst-1.6.3/src/lib/*.cc')
# FILES + glob.glob('glog/src/*.cc')
FILES = [ FILES = [
fn for fn in FILES if not (fn.endswith('main.cc') or fn.endswith('test.cc')) fn for fn in FILES
if not (fn.endswith('main.cc') or fn.endswith('test.cc') or fn.endswith(
'unittest.cc'))
] ]
LIBS = ['stdc++'] LIBS = ['stdc++']
...@@ -99,7 +102,13 @@ decoders_module = [ ...@@ -99,7 +102,13 @@ decoders_module = [
name='_swig_decoders', name='_swig_decoders',
sources=FILES + glob.glob('*.cxx') + glob.glob('*.cpp'), sources=FILES + glob.glob('*.cxx') + glob.glob('*.cpp'),
language='c++', language='c++',
include_dirs=['.', 'kenlm', 'openfst-1.6.3/src/include', 'ThreadPool'], include_dirs=[
'.',
'kenlm',
'openfst-1.6.3/src/include',
'ThreadPool',
#'glog/src'
],
libraries=LIBS, libraries=LIBS,
extra_compile_args=ARGS) extra_compile_args=ARGS)
] ]
......
#!/bin/bash #!/usr/bin/env bash
if [ ! -d kenlm ]; then if [ ! -d kenlm ]; then
git clone https://github.com/luotao1/kenlm.git git clone https://github.com/luotao1/kenlm.git
......
...@@ -13,14 +13,14 @@ class Scorer(swig_decoders.Scorer): ...@@ -13,14 +13,14 @@ class Scorer(swig_decoders.Scorer):
language model when alpha = 0. language model when alpha = 0.
:type alpha: float :type alpha: float
:param beta: Parameter associated with word count. Don't use word :param beta: Parameter associated with word count. Don't use word
count when beta = 0. count when beta = 0.
:type beta: float :type beta: float
:model_path: Path to load language model. :model_path: Path to load language model.
:type model_path: basestring :type model_path: basestring
""" """
def __init__(self, alpha, beta, model_path): def __init__(self, alpha, beta, model_path, vocabulary):
swig_decoders.Scorer.__init__(self, alpha, beta, model_path) swig_decoders.Scorer.__init__(self, alpha, beta, model_path, vocabulary)
def ctc_greedy_decoder(probs_seq, vocabulary): def ctc_greedy_decoder(probs_seq, vocabulary):
...@@ -58,12 +58,12 @@ def ctc_beam_search_decoder(probs_seq, ...@@ -58,12 +58,12 @@ def ctc_beam_search_decoder(probs_seq,
default 1.0, no pruning. default 1.0, no pruning.
:type cutoff_prob: float :type cutoff_prob: float
:param cutoff_top_n: Cutoff number in pruning, only top cutoff_top_n :param cutoff_top_n: Cutoff number in pruning, only top cutoff_top_n
characters with highest probs in vocabulary will be characters with highest probs in vocabulary will be
used in beam search, default 40. used in beam search, default 40.
:type cutoff_top_n: int :type cutoff_top_n: int
:param ext_scoring_func: External scoring function for :param ext_scoring_func: External scoring function for
partially decoded sentence, e.g. word count partially decoded sentence, e.g. word count
or language model. or language model.
:type external_scoring_func: callable :type external_scoring_func: callable
:return: List of tuples of log probability and sentence as decoding :return: List of tuples of log probability and sentence as decoding
results, in descending order of the probability. results, in descending order of the probability.
...@@ -96,14 +96,14 @@ def ctc_beam_search_decoder_batch(probs_split, ...@@ -96,14 +96,14 @@ def ctc_beam_search_decoder_batch(probs_split,
default 1.0, no pruning. default 1.0, no pruning.
:type cutoff_prob: float :type cutoff_prob: float
:param cutoff_top_n: Cutoff number in pruning, only top cutoff_top_n :param cutoff_top_n: Cutoff number in pruning, only top cutoff_top_n
characters with highest probs in vocabulary will be characters with highest probs in vocabulary will be
used in beam search, default 40. used in beam search, default 40.
:type cutoff_top_n: int :type cutoff_top_n: int
:param num_processes: Number of parallel processes. :param num_processes: Number of parallel processes.
:type num_processes: int :type num_processes: int
:param ext_scoring_func: External scoring function for :param ext_scoring_func: External scoring function for
partially decoded sentence, e.g. word count partially decoded sentence, e.g. word count
or language model. or language model.
:type external_scoring_function: callable :type external_scoring_function: callable
:return: List of tuples of log probability and sentence as decoding :return: List of tuples of log probability and sentence as decoding
results, in descending order of the probability. results, in descending order of the probability.
......
...@@ -21,9 +21,9 @@ python -u infer.py \ ...@@ -21,9 +21,9 @@ python -u infer.py \
--num_conv_layers=2 \ --num_conv_layers=2 \
--num_rnn_layers=3 \ --num_rnn_layers=3 \
--rnn_layer_size=2048 \ --rnn_layer_size=2048 \
--alpha=0.36 \ --alpha=2.15 \
--beta=0.25 \ --beta=0.35 \
--cutoff_prob=0.99 \ --cutoff_prob=1.0 \
--use_gru=False \ --use_gru=False \
--use_gpu=True \ --use_gpu=True \
--share_rnn_weights=True \ --share_rnn_weights=True \
......
...@@ -30,9 +30,9 @@ python -u infer.py \ ...@@ -30,9 +30,9 @@ python -u infer.py \
--num_conv_layers=2 \ --num_conv_layers=2 \
--num_rnn_layers=3 \ --num_rnn_layers=3 \
--rnn_layer_size=2048 \ --rnn_layer_size=2048 \
--alpha=0.36 \ --alpha=2.15 \
--beta=0.25 \ --beta=0.35 \
--cutoff_prob=0.99 \ --cutoff_prob=1.0 \
--use_gru=False \ --use_gru=False \
--use_gpu=True \ --use_gpu=True \
--share_rnn_weights=True \ --share_rnn_weights=True \
......
...@@ -22,9 +22,9 @@ python -u test.py \ ...@@ -22,9 +22,9 @@ python -u test.py \
--num_conv_layers=2 \ --num_conv_layers=2 \
--num_rnn_layers=3 \ --num_rnn_layers=3 \
--rnn_layer_size=2048 \ --rnn_layer_size=2048 \
--alpha=0.36 \ --alpha=2.15 \
--beta=0.25 \ --beta=0.35 \
--cutoff_prob=0.99 \ --cutoff_prob=1.0 \
--use_gru=False \ --use_gru=False \
--use_gpu=True \ --use_gpu=True \
--share_rnn_weights=True \ --share_rnn_weights=True \
......
...@@ -31,9 +31,9 @@ python -u test.py \ ...@@ -31,9 +31,9 @@ python -u test.py \
--num_conv_layers=2 \ --num_conv_layers=2 \
--num_rnn_layers=3 \ --num_rnn_layers=3 \
--rnn_layer_size=2048 \ --rnn_layer_size=2048 \
--alpha=0.36 \ --alpha=2.15 \
--beta=0.25 \ --beta=0.35 \
--cutoff_prob=0.99 \ --cutoff_prob=1.0 \
--use_gru=False \ --use_gru=False \
--use_gpu=True \ --use_gpu=True \
--share_rnn_weights=True \ --share_rnn_weights=True \
......
...@@ -112,6 +112,7 @@ def infer(): ...@@ -112,6 +112,7 @@ def infer():
print("Current error rate [%s] = %f" % print("Current error rate [%s] = %f" %
(args.error_rate_type, error_rate_func(target, result))) (args.error_rate_type, error_rate_func(target, result)))
ds2_model.logger.info("finish inference")
def main(): def main():
print_arguments(args) print_arguments(args)
......
...@@ -6,6 +6,7 @@ from __future__ import print_function ...@@ -6,6 +6,7 @@ from __future__ import print_function
import sys import sys
import os import os
import time import time
import logging
import gzip import gzip
import paddle.v2 as paddle import paddle.v2 as paddle
from decoders.swig_wrapper import Scorer from decoders.swig_wrapper import Scorer
...@@ -13,6 +14,9 @@ from decoders.swig_wrapper import ctc_greedy_decoder ...@@ -13,6 +14,9 @@ from decoders.swig_wrapper import ctc_greedy_decoder
from decoders.swig_wrapper import ctc_beam_search_decoder_batch from decoders.swig_wrapper import ctc_beam_search_decoder_batch
from model_utils.network import deep_speech_v2_network from model_utils.network import deep_speech_v2_network
logging.basicConfig(
format='[%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s')
class DeepSpeech2Model(object): class DeepSpeech2Model(object):
"""DeepSpeech2Model class. """DeepSpeech2Model class.
...@@ -43,6 +47,8 @@ class DeepSpeech2Model(object): ...@@ -43,6 +47,8 @@ class DeepSpeech2Model(object):
self._inferer = None self._inferer = None
self._loss_inferer = None self._loss_inferer = None
self._ext_scorer = None self._ext_scorer = None
self.logger = logging.getLogger("")
self.logger.setLevel(level=logging.INFO)
def train(self, def train(self,
train_batch_reader, train_batch_reader,
...@@ -204,16 +210,25 @@ class DeepSpeech2Model(object): ...@@ -204,16 +210,25 @@ class DeepSpeech2Model(object):
elif decoding_method == "ctc_beam_search": elif decoding_method == "ctc_beam_search":
# initialize external scorer # initialize external scorer
if self._ext_scorer == None: if self._ext_scorer == None:
self._ext_scorer = Scorer(beam_alpha, beam_beta,
language_model_path)
self._loaded_lm_path = language_model_path self._loaded_lm_path = language_model_path
self._ext_scorer.set_char_map(vocab_list) self.logger.info("begin to initialize the external scorer "
if (not self._ext_scorer.is_character_based()): "for decoding")
self._ext_scorer.fill_dictionary(True) self._ext_scorer = Scorer(beam_alpha, beam_beta,
language_model_path, vocab_list)
lm_char_based = self._ext_scorer.is_character_based()
lm_max_order = self._ext_scorer.get_max_order()
lm_dict_size = self._ext_scorer.get_dict_size()
self.logger.info("language model: "
"is_character_based = %d," % lm_char_based +
" max_order = %d," % lm_max_order +
" dict_size = %d" % lm_dict_size)
self.logger.info("end initializing scorer. Start decoding ...")
else: else:
self._ext_scorer.reset_params(beam_alpha, beam_beta) self._ext_scorer.reset_params(beam_alpha, beam_beta)
assert self._loaded_lm_path == language_model_path assert self._loaded_lm_path == language_model_path
# beam search decode # beam search decode
num_processes = min(num_processes, len(probs_split))
beam_search_results = ctc_beam_search_decoder_batch( beam_search_results = ctc_beam_search_decoder_batch(
probs_split=probs_split, probs_split=probs_split,
vocabulary=vocab_list, vocabulary=vocab_list,
......
...@@ -115,6 +115,7 @@ def evaluate(): ...@@ -115,6 +115,7 @@ def evaluate():
print("Final error rate [%s] (%d/%d) = %f" % print("Final error rate [%s] (%d/%d) = %f" %
(args.error_rate_type, num_ins, num_ins, error_sum / num_ins)) (args.error_rate_type, num_ins, num_ins, error_sum / num_ins))
ds2_model.logger.info("finish evaluation")
def main(): def main():
print_arguments(args) print_arguments(args)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册