提交 a18e6a7e 编写于 作者: Y Yibing Liu

refine by following review comments

上级 e0ab51f4
...@@ -24,8 +24,6 @@ ...@@ -24,8 +24,6 @@
## Installation ## Installation
### Basic setup
Please make sure the above [prerequisites](#prerequisites) have been satisfied before moving on. Please make sure the above [prerequisites](#prerequisites) have been satisfied before moving on.
```bash ```bash
...@@ -34,16 +32,6 @@ cd models/deep_speech_2 ...@@ -34,16 +32,6 @@ cd models/deep_speech_2
sh setup.sh sh setup.sh
``` ```
### Decoders setup
```bash
cd decoders/swig
sh setup.sh
cd ../..
```
These commands will install the decoders that translate the ouptut probability vectors of DS2 model to text data, incuding CTC greedy decoder, CTC beam search decoder and its batch version. And a detailed usuage about them will be given in the following sections.
## Getting Started ## Getting Started
Several shell scripts provided in `./examples` will help us to quickly give it a try, for most major modules, including data preparation, model training, case inference and model evaluation, with a few public dataset (e.g. [LibriSpeech](http://www.openslr.org/12/), [Aishell](http://www.openslr.org/33)). Reading these examples will also help you to understand how to make it work with your own data. Several shell scripts provided in `./examples` will help us to quickly give it a try, for most major modules, including data preparation, model training, case inference and model evaluation, with a few public dataset (e.g. [LibriSpeech](http://www.openslr.org/12/), [Aishell](http://www.openslr.org/33)). Reading these examples will also help you to understand how to make it work with your own data.
...@@ -189,7 +177,6 @@ Data augmentation has often been a highly effective technique to boost the deep ...@@ -189,7 +177,6 @@ Data augmentation has often been a highly effective technique to boost the deep
Six optional augmentation components are provided to be selected, configured and inserted into the processing pipeline. Six optional augmentation components are provided to be selected, configured and inserted into the processing pipeline.
### Inference ### Inference
- Volume Perturbation - Volume Perturbation
- Speed Perturbation - Speed Perturbation
- Shifting Perturbation - Shifting Perturbation
......
...@@ -22,6 +22,8 @@ class TextFeaturizer(object): ...@@ -22,6 +22,8 @@ class TextFeaturizer(object):
def __init__(self, vocab_filepath): def __init__(self, vocab_filepath):
self._vocab_dict, self._vocab_list = self._load_vocabulary_from_file( self._vocab_dict, self._vocab_list = self._load_vocabulary_from_file(
vocab_filepath) vocab_filepath)
# from unicode to string
self._vocab_list = [chars.encode("utf-8") for chars in self._vocab_list]
def featurize(self, text): def featurize(self, text):
"""Convert text string to a list of token indices in char-level.Note """Convert text string to a list of token indices in char-level.Note
......
...@@ -17,41 +17,38 @@ std::string ctc_greedy_decoder( ...@@ -17,41 +17,38 @@ std::string ctc_greedy_decoder(
const std::vector<std::vector<double>> &probs_seq, const std::vector<std::vector<double>> &probs_seq,
const std::vector<std::string> &vocabulary) { const std::vector<std::string> &vocabulary) {
// dimension check // dimension check
int num_time_steps = probs_seq.size(); size_t num_time_steps = probs_seq.size();
for (int i = 0; i < num_time_steps; i++) { for (size_t i = 0; i < num_time_steps; i++) {
if (probs_seq[i].size() != vocabulary.size() + 1) { VALID_CHECK_EQ(probs_seq[i].size(),
std::cout << "The shape of probs_seq does not match" vocabulary.size() + 1,
<< " with the shape of the vocabulary!" << std::endl; "The shape of probs_seq does not match with "
exit(1); "the shape of the vocabulary");
}
} }
int blank_id = vocabulary.size(); size_t blank_id = vocabulary.size();
std::vector<int> max_idx_vec; std::vector<size_t> max_idx_vec;
for (size_t i = 0; i < num_time_steps; i++) {
double max_prob = 0.0; double max_prob = 0.0;
int max_idx = 0; size_t max_idx = 0;
for (int i = 0; i < num_time_steps; i++) { for (size_t j = 0; j < probs_seq[i].size(); j++) {
for (int j = 0; j < probs_seq[i].size(); j++) {
if (max_prob < probs_seq[i][j]) { if (max_prob < probs_seq[i][j]) {
max_idx = j; max_idx = j;
max_prob = probs_seq[i][j]; max_prob = probs_seq[i][j];
} }
} }
max_idx_vec.push_back(max_idx); max_idx_vec.push_back(max_idx);
max_prob = 0.0;
max_idx = 0;
} }
std::vector<int> idx_vec; std::vector<size_t> idx_vec;
for (int i = 0; i < max_idx_vec.size(); i++) { for (size_t i = 0; i < max_idx_vec.size(); i++) {
if ((i == 0) || ((i > 0) && max_idx_vec[i] != max_idx_vec[i - 1])) { if ((i == 0) || ((i > 0) && max_idx_vec[i] != max_idx_vec[i - 1])) {
idx_vec.push_back(max_idx_vec[i]); idx_vec.push_back(max_idx_vec[i]);
} }
} }
std::string best_path_result; std::string best_path_result;
for (int i = 0; i < idx_vec.size(); i++) { for (size_t i = 0; i < idx_vec.size(); i++) {
if (idx_vec[i] != blank_id) { if (idx_vec[i] != blank_id) {
best_path_result += vocabulary[idx_vec[i]]; best_path_result += vocabulary[idx_vec[i]];
} }
...@@ -61,29 +58,24 @@ std::string ctc_greedy_decoder( ...@@ -61,29 +58,24 @@ std::string ctc_greedy_decoder(
std::vector<std::pair<double, std::string>> ctc_beam_search_decoder( std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
const std::vector<std::vector<double>> &probs_seq, const std::vector<std::vector<double>> &probs_seq,
int beam_size, const size_t beam_size,
std::vector<std::string> vocabulary, std::vector<std::string> vocabulary,
int blank_id, const double cutoff_prob,
double cutoff_prob, const size_t cutoff_top_n,
int cutoff_top_n, Scorer *ext_scorer) {
Scorer *extscorer) {
// dimension check // dimension check
size_t num_time_steps = probs_seq.size(); size_t num_time_steps = probs_seq.size();
for (int i = 0; i < num_time_steps; i++) { for (size_t i = 0; i < num_time_steps; i++) {
if (probs_seq[i].size() != vocabulary.size() + 1) { VALID_CHECK_EQ(probs_seq[i].size(),
std::cout << " The shape of probs_seq does not match" vocabulary.size() + 1,
<< " with the shape of the vocabulary!" << std::endl; "The shape of probs_seq does not match with "
exit(1); "the shape of the vocabulary");
}
} }
// blank_id check // assign blank id
if (blank_id > vocabulary.size()) { size_t blank_id = vocabulary.size();
std::cout << " Invalid blank_id! " << std::endl;
exit(1);
}
// assign space ID // assign space id
std::vector<std::string>::iterator it = std::vector<std::string>::iterator it =
std::find(vocabulary.begin(), vocabulary.end(), " "); std::find(vocabulary.begin(), vocabulary.end(), " ");
int space_id = it - vocabulary.begin(); int space_id = it - vocabulary.begin();
...@@ -98,16 +90,16 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder( ...@@ -98,16 +90,16 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
std::vector<PathTrie *> prefixes; std::vector<PathTrie *> prefixes;
prefixes.push_back(&root); prefixes.push_back(&root);
if (extscorer != nullptr) { if (ext_scorer != nullptr) {
if (extscorer->is_char_map_empty()) { if (ext_scorer->is_char_map_empty()) {
extscorer->set_char_map(vocabulary); ext_scorer->set_char_map(vocabulary);
} }
if (!extscorer->is_character_based()) { if (!ext_scorer->is_character_based()) {
if (extscorer->dictionary == nullptr) { if (ext_scorer->dictionary == nullptr) {
// fill dictionary for fst with space // fill dictionary for fst with space
extscorer->fill_dictionary(true); ext_scorer->fill_dictionary(true);
} }
auto fst_dict = static_cast<fst::StdVectorFst *>(extscorer->dictionary); auto fst_dict = static_cast<fst::StdVectorFst *>(ext_scorer->dictionary);
fst::StdVectorFst *dict_ptr = fst_dict->Copy(true); fst::StdVectorFst *dict_ptr = fst_dict->Copy(true);
root.set_dictionary(dict_ptr); root.set_dictionary(dict_ptr);
auto matcher = std::make_shared<FSTMATCH>(*dict_ptr, fst::MATCH_INPUT); auto matcher = std::make_shared<FSTMATCH>(*dict_ptr, fst::MATCH_INPUT);
...@@ -116,33 +108,33 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder( ...@@ -116,33 +108,33 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
} }
// prefix search over time // prefix search over time
for (int time_step = 0; time_step < num_time_steps; time_step++) { for (size_t time_step = 0; time_step < num_time_steps; time_step++) {
std::vector<double> prob = probs_seq[time_step]; std::vector<double> prob = probs_seq[time_step];
std::vector<std::pair<int, double>> prob_idx; std::vector<std::pair<int, double>> prob_idx;
for (int i = 0; i < prob.size(); i++) { for (size_t i = 0; i < prob.size(); i++) {
prob_idx.push_back(std::pair<int, double>(i, prob[i])); prob_idx.push_back(std::pair<int, double>(i, prob[i]));
} }
float min_cutoff = -NUM_FLT_INF; float min_cutoff = -NUM_FLT_INF;
bool full_beam = false; bool full_beam = false;
if (extscorer != nullptr) { if (ext_scorer != nullptr) {
int num_prefixes = std::min((int)prefixes.size(), beam_size); size_t num_prefixes = std::min(prefixes.size(), beam_size);
std::sort( std::sort(
prefixes.begin(), prefixes.begin() + num_prefixes, prefix_compare); prefixes.begin(), prefixes.begin() + num_prefixes, prefix_compare);
min_cutoff = prefixes[num_prefixes - 1]->score + log(prob[blank_id]) - min_cutoff = prefixes[num_prefixes - 1]->score + log(prob[blank_id]) -
std::max(0.0, extscorer->beta); std::max(0.0, ext_scorer->beta);
full_beam = (num_prefixes == beam_size); full_beam = (num_prefixes == beam_size);
} }
// pruning of vacobulary // pruning of vacobulary
int cutoff_len = prob.size(); size_t cutoff_len = prob.size();
if (cutoff_prob < 1.0 || cutoff_top_n < prob.size()) { if (cutoff_prob < 1.0 || cutoff_top_n < prob.size()) {
std::sort( std::sort(
prob_idx.begin(), prob_idx.end(), pair_comp_second_rev<int, double>); prob_idx.begin(), prob_idx.end(), pair_comp_second_rev<int, double>);
if (cutoff_prob < 1.0) { if (cutoff_prob < 1.0) {
double cum_prob = 0.0; double cum_prob = 0.0;
cutoff_len = 0; cutoff_len = 0;
for (int i = 0; i < prob_idx.size(); i++) { for (size_t i = 0; i < prob_idx.size(); i++) {
cum_prob += prob_idx[i].second; cum_prob += prob_idx[i].second;
cutoff_len += 1; cutoff_len += 1;
if (cum_prob >= cutoff_prob) break; if (cum_prob >= cutoff_prob) break;
...@@ -152,18 +144,18 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder( ...@@ -152,18 +144,18 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
prob_idx = std::vector<std::pair<int, double>>( prob_idx = std::vector<std::pair<int, double>>(
prob_idx.begin(), prob_idx.begin() + cutoff_len); prob_idx.begin(), prob_idx.begin() + cutoff_len);
} }
std::vector<std::pair<int, float>> log_prob_idx; std::vector<std::pair<size_t, float>> log_prob_idx;
for (int i = 0; i < cutoff_len; i++) { for (size_t i = 0; i < cutoff_len; i++) {
log_prob_idx.push_back(std::pair<int, float>( log_prob_idx.push_back(std::pair<int, float>(
prob_idx[i].first, log(prob_idx[i].second + NUM_FLT_MIN))); prob_idx[i].first, log(prob_idx[i].second + NUM_FLT_MIN)));
} }
// loop over chars // loop over chars
for (int index = 0; index < log_prob_idx.size(); index++) { for (size_t index = 0; index < log_prob_idx.size(); index++) {
auto c = log_prob_idx[index].first; auto c = log_prob_idx[index].first;
float log_prob_c = log_prob_idx[index].second; float log_prob_c = log_prob_idx[index].second;
for (int i = 0; i < prefixes.size() && i < beam_size; i++) { for (size_t i = 0; i < prefixes.size() && i < beam_size; i++) {
auto prefix = prefixes[i]; auto prefix = prefixes[i];
if (full_beam && log_prob_c + prefix->score < min_cutoff) { if (full_beam && log_prob_c + prefix->score < min_cutoff) {
...@@ -194,12 +186,12 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder( ...@@ -194,12 +186,12 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
} }
// language model scoring // language model scoring
if (extscorer != nullptr && if (ext_scorer != nullptr &&
(c == space_id || extscorer->is_character_based())) { (c == space_id || ext_scorer->is_character_based())) {
PathTrie *prefix_toscore = nullptr; PathTrie *prefix_toscore = nullptr;
// skip scoring the space // skip scoring the space
if (extscorer->is_character_based()) { if (ext_scorer->is_character_based()) {
prefix_toscore = prefix_new; prefix_toscore = prefix_new;
} else { } else {
prefix_toscore = prefix; prefix_toscore = prefix;
...@@ -207,11 +199,11 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder( ...@@ -207,11 +199,11 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
double score = 0.0; double score = 0.0;
std::vector<std::string> ngram; std::vector<std::string> ngram;
ngram = extscorer->make_ngram(prefix_toscore); ngram = ext_scorer->make_ngram(prefix_toscore);
score = extscorer->get_log_cond_prob(ngram) * extscorer->alpha; score = ext_scorer->get_log_cond_prob(ngram) * ext_scorer->alpha;
log_p += score; log_p += score;
log_p += extscorer->beta; log_p += ext_scorer->beta;
} }
prefix_new->log_prob_nb_cur = prefix_new->log_prob_nb_cur =
log_sum_exp(prefix_new->log_prob_nb_cur, log_p); log_sum_exp(prefix_new->log_prob_nb_cur, log_p);
...@@ -240,15 +232,15 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder( ...@@ -240,15 +232,15 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
for (size_t i = 0; i < beam_size && i < prefixes.size(); i++) { for (size_t i = 0; i < beam_size && i < prefixes.size(); i++) {
double approx_ctc = prefixes[i]->score; double approx_ctc = prefixes[i]->score;
if (extscorer != nullptr) { if (ext_scorer != nullptr) {
std::vector<int> output; std::vector<int> output;
prefixes[i]->get_path_vec(output); prefixes[i]->get_path_vec(output);
size_t prefix_length = output.size(); size_t prefix_length = output.size();
auto words = extscorer->split_labels(output); auto words = ext_scorer->split_labels(output);
// remove word insert // remove word insert
approx_ctc = approx_ctc - prefix_length * extscorer->beta; approx_ctc = approx_ctc - prefix_length * ext_scorer->beta;
// remove language model weight: // remove language model weight:
approx_ctc -= (extscorer->get_sent_log_prob(words)) * extscorer->alpha; approx_ctc -= (ext_scorer->get_sent_log_prob(words)) * ext_scorer->alpha;
} }
prefixes[i]->approx_ctc = approx_ctc; prefixes[i]->approx_ctc = approx_ctc;
...@@ -269,7 +261,7 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder( ...@@ -269,7 +261,7 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
space_prefixes[i]->get_path_vec(output); space_prefixes[i]->get_path_vec(output);
// convert index to string // convert index to string
std::string output_str; std::string output_str;
for (int j = 0; j < output.size(); j++) { for (size_t j = 0; j < output.size(); j++) {
output_str += vocabulary[output[j]]; output_str += vocabulary[output[j]];
} }
std::pair<double, std::string> output_pair(-space_prefixes[i]->approx_ctc, std::pair<double, std::string> output_pair(-space_prefixes[i]->approx_ctc,
...@@ -283,49 +275,45 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder( ...@@ -283,49 +275,45 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
std::vector<std::vector<std::pair<double, std::string>>> std::vector<std::vector<std::pair<double, std::string>>>
ctc_beam_search_decoder_batch( ctc_beam_search_decoder_batch(
const std::vector<std::vector<std::vector<double>>> &probs_split, const std::vector<std::vector<std::vector<double>>> &probs_split,
int beam_size, const size_t beam_size,
const std::vector<std::string> &vocabulary, const std::vector<std::string> &vocabulary,
int blank_id, const size_t num_processes,
int num_processes, const double cutoff_prob,
double cutoff_prob, const size_t cutoff_top_n,
int cutoff_top_n, Scorer *ext_scorer) {
Scorer *extscorer) { VALID_CHECK_GT(num_processes, 0, "num_processes must be nonnegative!");
if (num_processes <= 0) {
std::cout << "num_processes must be nonnegative!" << std::endl;
exit(1);
}
// thread pool // thread pool
ThreadPool pool(num_processes); ThreadPool pool(num_processes);
// number of samples // number of samples
int batch_size = probs_split.size(); size_t batch_size = probs_split.size();
// scorer filling up // scorer filling up
if (extscorer != nullptr) { if (ext_scorer != nullptr) {
if (extscorer->is_char_map_empty()) { if (ext_scorer->is_char_map_empty()) {
extscorer->set_char_map(vocabulary); ext_scorer->set_char_map(vocabulary);
} }
if (!extscorer->is_character_based() && extscorer->dictionary == nullptr) { if (!ext_scorer->is_character_based() &&
ext_scorer->dictionary == nullptr) {
// init dictionary // init dictionary
extscorer->fill_dictionary(true); ext_scorer->fill_dictionary(true);
} }
} }
// enqueue the tasks of decoding // enqueue the tasks of decoding
std::vector<std::future<std::vector<std::pair<double, std::string>>>> res; std::vector<std::future<std::vector<std::pair<double, std::string>>>> res;
for (int i = 0; i < batch_size; i++) { for (size_t i = 0; i < batch_size; i++) {
res.emplace_back(pool.enqueue(ctc_beam_search_decoder, res.emplace_back(pool.enqueue(ctc_beam_search_decoder,
probs_split[i], probs_split[i],
beam_size, beam_size,
vocabulary, vocabulary,
blank_id,
cutoff_prob, cutoff_prob,
cutoff_top_n, cutoff_top_n,
extscorer)); ext_scorer));
} }
// get decoding results // get decoding results
std::vector<std::vector<std::pair<double, std::string>>> batch_results; std::vector<std::vector<std::pair<double, std::string>>> batch_results;
for (int i = 0; i < batch_size; i++) { for (size_t i = 0; i < batch_size; i++) {
batch_results.emplace_back(res[i].get()); batch_results.emplace_back(res[i].get());
} }
return batch_results; return batch_results;
......
...@@ -27,21 +27,21 @@ std::string ctc_greedy_decoder( ...@@ -27,21 +27,21 @@ std::string ctc_greedy_decoder(
* over vocabulary of one time step. * over vocabulary of one time step.
* beam_size: The width of beam search. * beam_size: The width of beam search.
* vocabulary: A vector of vocabulary. * vocabulary: A vector of vocabulary.
* blank_id: ID of blank.
* cutoff_prob: Cutoff probability for pruning. * cutoff_prob: Cutoff probability for pruning.
* cutoff_top_n: Cutoff number for pruning. * cutoff_top_n: Cutoff number for pruning.
* ext_scorer: External scorer to evaluate a prefix. * ext_scorer: External scorer to evaluate a prefix, which consists of
* n-gram language model scoring and word insertion term.
* Default null, decoding the input sample without scorer.
* Return: * Return:
* A vector that each element is a pair of score and decoding result, * A vector that each element is a pair of score and decoding result,
* in desending order. * in desending order.
*/ */
std::vector<std::pair<double, std::string>> ctc_beam_search_decoder( std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
const std::vector<std::vector<double>> &probs_seq, const std::vector<std::vector<double>> &probs_seq,
int beam_size, const size_t beam_size,
std::vector<std::string> vocabulary, std::vector<std::string> vocabulary,
int blank_id, const double cutoff_prob = 1.0,
double cutoff_prob = 1.0, const size_t cutoff_top_n = 40,
int cutoff_top_n = 40,
Scorer *ext_scorer = NULL); Scorer *ext_scorer = NULL);
/* CTC Beam Search Decoder for batch data /* CTC Beam Search Decoder for batch data
...@@ -52,11 +52,12 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder( ...@@ -52,11 +52,12 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
* . * .
* beam_size: The width of beam search. * beam_size: The width of beam search.
* vocabulary: A vector of vocabulary. * vocabulary: A vector of vocabulary.
* blank_id: ID of blank.
* num_processes: Number of threads for beam search. * num_processes: Number of threads for beam search.
* cutoff_prob: Cutoff probability for pruning. * cutoff_prob: Cutoff probability for pruning.
* cutoff_top_n: Cutoff number for pruning. * cutoff_top_n: Cutoff number for pruning.
* ext_scorer: External scorer to evaluate a prefix. * ext_scorer: External scorer to evaluate a prefix, which consists of
* n-gram language model scoring and word insertion term.
* Default null, decoding the input sample without scorer.
* Return: * Return:
* A 2-D vector that each element is a vector of beam search decoding * A 2-D vector that each element is a vector of beam search decoding
* result for one audio sample. * result for one audio sample.
...@@ -64,12 +65,11 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder( ...@@ -64,12 +65,11 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
std::vector<std::vector<std::pair<double, std::string>>> std::vector<std::vector<std::pair<double, std::string>>>
ctc_beam_search_decoder_batch( ctc_beam_search_decoder_batch(
const std::vector<std::vector<std::vector<double>>> &probs_split, const std::vector<std::vector<std::vector<double>>> &probs_split,
int beam_size, const size_t beam_size,
const std::vector<std::string> &vocabulary, const std::vector<std::string> &vocabulary,
int blank_id, const size_t num_processes,
int num_processes,
double cutoff_prob = 1.0, double cutoff_prob = 1.0,
int cutoff_top_n = 40, const size_t cutoff_top_n = 40,
Scorer *ext_scorer = NULL); Scorer *ext_scorer = NULL);
#endif // CTC_BEAM_SEARCH_DECODER_H_ #endif // CTC_BEAM_SEARCH_DECODER_H_
...@@ -7,6 +7,22 @@ ...@@ -7,6 +7,22 @@
const float NUM_FLT_INF = std::numeric_limits<float>::max(); const float NUM_FLT_INF = std::numeric_limits<float>::max();
const float NUM_FLT_MIN = std::numeric_limits<float>::min(); const float NUM_FLT_MIN = std::numeric_limits<float>::min();
// check if __A == _B
#define VALID_CHECK_EQ(__A, __B, __ERR) \
if ((__A) != (__B)) { \
std::ostringstream str; \
str << (__A) << " != " << (__B) << ", "; \
throw std::runtime_error(str.str() + __ERR); \
}
// check if __A > __B
#define VALID_CHECK_GT(__A, __B, __ERR) \
if ((__A) <= (__B)) { \
std::ostringstream str; \
str << (__A) << " <= " << (__B) << ", "; \
throw std::runtime_error(str.str() + __ERR); \
}
// Function template for comparing two pairs // Function template for comparing two pairs
template <typename T1, typename T2> template <typename T1, typename T2>
bool pair_comp_first_rev(const std::pair<T1, T2> &a, bool pair_comp_first_rev(const std::pair<T1, T2> &a,
......
...@@ -41,7 +41,6 @@ def ctc_greedy_decoder(probs_seq, vocabulary): ...@@ -41,7 +41,6 @@ def ctc_greedy_decoder(probs_seq, vocabulary):
def ctc_beam_search_decoder(probs_seq, def ctc_beam_search_decoder(probs_seq,
beam_size, beam_size,
vocabulary, vocabulary,
blank_id,
cutoff_prob=1.0, cutoff_prob=1.0,
cutoff_top_n=40, cutoff_top_n=40,
ext_scoring_func=None): ext_scoring_func=None):
...@@ -55,8 +54,6 @@ def ctc_beam_search_decoder(probs_seq, ...@@ -55,8 +54,6 @@ def ctc_beam_search_decoder(probs_seq,
:type beam_size: int :type beam_size: int
:param vocabulary: Vocabulary list. :param vocabulary: Vocabulary list.
:type vocabulary: list :type vocabulary: list
:param blank_id: ID of blank.
:type blank_id: int
:param cutoff_prob: Cutoff probability in pruning, :param cutoff_prob: Cutoff probability in pruning,
default 1.0, no pruning. default 1.0, no pruning.
:type cutoff_prob: float :type cutoff_prob: float
...@@ -72,15 +69,14 @@ def ctc_beam_search_decoder(probs_seq, ...@@ -72,15 +69,14 @@ def ctc_beam_search_decoder(probs_seq,
results, in descending order of the probability. results, in descending order of the probability.
:rtype: list :rtype: list
""" """
return swig_decoders.ctc_beam_search_decoder( return swig_decoders.ctc_beam_search_decoder(probs_seq.tolist(), beam_size,
probs_seq.tolist(), beam_size, vocabulary, blank_id, cutoff_prob, vocabulary, cutoff_prob,
cutoff_top_n, ext_scoring_func) cutoff_top_n, ext_scoring_func)
def ctc_beam_search_decoder_batch(probs_split, def ctc_beam_search_decoder_batch(probs_split,
beam_size, beam_size,
vocabulary, vocabulary,
blank_id,
num_processes, num_processes,
cutoff_prob=1.0, cutoff_prob=1.0,
cutoff_top_n=40, cutoff_top_n=40,
...@@ -94,8 +90,6 @@ def ctc_beam_search_decoder_batch(probs_split, ...@@ -94,8 +90,6 @@ def ctc_beam_search_decoder_batch(probs_split,
:type beam_size: int :type beam_size: int
:param vocabulary: Vocabulary list. :param vocabulary: Vocabulary list.
:type vocabulary: list :type vocabulary: list
:param blank_id: ID of blank.
:type blank_id: int
:param num_processes: Number of parallel processes. :param num_processes: Number of parallel processes.
:type num_processes: int :type num_processes: int
:param cutoff_prob: Cutoff probability in vocabulary pruning, :param cutoff_prob: Cutoff probability in vocabulary pruning,
...@@ -118,5 +112,5 @@ def ctc_beam_search_decoder_batch(probs_split, ...@@ -118,5 +112,5 @@ def ctc_beam_search_decoder_batch(probs_split,
probs_split = [probs_seq.tolist() for probs_seq in probs_split] probs_split = [probs_seq.tolist() for probs_seq in probs_split]
return swig_decoders.ctc_beam_search_decoder_batch( return swig_decoders.ctc_beam_search_decoder_batch(
probs_split, beam_size, vocabulary, blank_id, num_processes, probs_split, beam_size, vocabulary, num_processes, cutoff_prob,
cutoff_prob, cutoff_top_n, ext_scoring_func) cutoff_top_n, ext_scoring_func)
...@@ -31,13 +31,13 @@ python -u test.py \ ...@@ -31,13 +31,13 @@ python -u test.py \
--num_conv_layers=2 \ --num_conv_layers=2 \
--num_rnn_layers=3 \ --num_rnn_layers=3 \
--rnn_layer_size=2048 \ --rnn_layer_size=2048 \
--alpha=0.36 \ --alpha=2.15 \
--beta=0.25 \ --beta=0.35 \
--cutoff_prob=0.99 \ --cutoff_prob=1.0 \
--use_gru=False \ --use_gru=False \
--use_gpu=True \ --use_gpu=True \
--share_rnn_weights=True \ --share_rnn_weights=True \
--test_manifest='data/tiny/manifest.test-clean' \ --test_manifest='data/librispeech/manifest.test-clean' \
--mean_std_path='models/librispeech/mean_std.npz' \ --mean_std_path='models/librispeech/mean_std.npz' \
--vocab_path='models/librispeech/vocab.txt' \ --vocab_path='models/librispeech/vocab.txt' \
--model_path='models/librispeech/params.tar.gz' \ --model_path='models/librispeech/params.tar.gz' \
......
...@@ -21,9 +21,9 @@ add_arg('num_proc_bsearch', int, 12, "# of CPUs for beam search.") ...@@ -21,9 +21,9 @@ add_arg('num_proc_bsearch', int, 12, "# of CPUs for beam search.")
add_arg('num_conv_layers', int, 2, "# of convolution layers.") add_arg('num_conv_layers', int, 2, "# of convolution layers.")
add_arg('num_rnn_layers', int, 3, "# of recurrent layers.") add_arg('num_rnn_layers', int, 3, "# of recurrent layers.")
add_arg('rnn_layer_size', int, 2048, "# of recurrent cells per layer.") add_arg('rnn_layer_size', int, 2048, "# of recurrent cells per layer.")
add_arg('alpha', float, 0.36, "Coef of LM for beam search.") add_arg('alpha', float, 2.15, "Coef of LM for beam search.")
add_arg('beta', float, 0.25, "Coef of WC for beam search.") add_arg('beta', float, 0.35, "Coef of WC for beam search.")
add_arg('cutoff_prob', float, 0.99, "Cutoff probability for pruning.") add_arg('cutoff_prob', float, 1.0, "Cutoff probability for pruning.")
add_arg('use_gru', bool, False, "Use GRUs instead of simple RNNs.") add_arg('use_gru', bool, False, "Use GRUs instead of simple RNNs.")
add_arg('use_gpu', bool, True, "Use GPU or not.") add_arg('use_gpu', bool, True, "Use GPU or not.")
add_arg('share_rnn_weights',bool, True, "Share input-hidden weights across " add_arg('share_rnn_weights',bool, True, "Share input-hidden weights across "
...@@ -85,7 +85,6 @@ def infer(): ...@@ -85,7 +85,6 @@ def infer():
pretrained_model_path=args.model_path, pretrained_model_path=args.model_path,
share_rnn_weights=args.share_rnn_weights) share_rnn_weights=args.share_rnn_weights)
vocab_list = [chars.encode("utf-8") for chars in data_generator.vocab_list]
result_transcripts = ds2_model.infer_batch( result_transcripts = ds2_model.infer_batch(
infer_data=infer_data, infer_data=infer_data,
decoding_method=args.decoding_method, decoding_method=args.decoding_method,
...@@ -93,7 +92,7 @@ def infer(): ...@@ -93,7 +92,7 @@ def infer():
beam_beta=args.beta, beam_beta=args.beta,
beam_size=args.beam_size, beam_size=args.beam_size,
cutoff_prob=args.cutoff_prob, cutoff_prob=args.cutoff_prob,
vocab_list=vocab_list, vocab_list=data_generator.vocab_list,
language_model_path=args.lang_model_path, language_model_path=args.lang_model_path,
num_processes=args.num_proc_bsearch) num_processes=args.num_proc_bsearch)
......
...@@ -214,7 +214,6 @@ class DeepSpeech2Model(object): ...@@ -214,7 +214,6 @@ class DeepSpeech2Model(object):
probs_split=probs_split, probs_split=probs_split,
vocabulary=vocab_list, vocabulary=vocab_list,
beam_size=beam_size, beam_size=beam_size,
blank_id=len(vocab_list),
num_processes=num_processes, num_processes=num_processes,
ext_scoring_func=self._ext_scorer, ext_scoring_func=self._ext_scorer,
cutoff_prob=cutoff_prob) cutoff_prob=cutoff_prob)
......
...@@ -26,4 +26,13 @@ if [ $? != 0 ]; then ...@@ -26,4 +26,13 @@ if [ $? != 0 ]; then
rm libsndfile-1.0.28.tar.gz rm libsndfile-1.0.28.tar.gz
fi fi
# install decoders
python -c "import swig_decoders"
if [ $? != 0 ]; then
pushd decoders/swig > /dev/null
sh setup.sh
popd > /dev/null
fi
echo "Install all dependencies successfully." echo "Install all dependencies successfully."
...@@ -22,9 +22,9 @@ add_arg('num_proc_data', int, 12, "# of CPUs for data preprocessing.") ...@@ -22,9 +22,9 @@ add_arg('num_proc_data', int, 12, "# of CPUs for data preprocessing.")
add_arg('num_conv_layers', int, 2, "# of convolution layers.") add_arg('num_conv_layers', int, 2, "# of convolution layers.")
add_arg('num_rnn_layers', int, 3, "# of recurrent layers.") add_arg('num_rnn_layers', int, 3, "# of recurrent layers.")
add_arg('rnn_layer_size', int, 2048, "# of recurrent cells per layer.") add_arg('rnn_layer_size', int, 2048, "# of recurrent cells per layer.")
add_arg('alpha', float, 0.36, "Coef of LM for beam search.") add_arg('alpha', float, 2.15, "Coef of LM for beam search.")
add_arg('beta', float, 0.25, "Coef of WC for beam search.") add_arg('beta', float, 0.35, "Coef of WC for beam search.")
add_arg('cutoff_prob', float, 0.99, "Cutoff probability for pruning.") add_arg('cutoff_prob', float, 1.0, "Cutoff probability for pruning.")
add_arg('use_gru', bool, False, "Use GRUs instead of simple RNNs.") add_arg('use_gru', bool, False, "Use GRUs instead of simple RNNs.")
add_arg('use_gpu', bool, True, "Use GPU or not.") add_arg('use_gpu', bool, True, "Use GPU or not.")
add_arg('share_rnn_weights',bool, True, "Share input-hidden weights across " add_arg('share_rnn_weights',bool, True, "Share input-hidden weights across "
...@@ -85,7 +85,6 @@ def evaluate(): ...@@ -85,7 +85,6 @@ def evaluate():
pretrained_model_path=args.model_path, pretrained_model_path=args.model_path,
share_rnn_weights=args.share_rnn_weights) share_rnn_weights=args.share_rnn_weights)
vocab_list = [chars.encode("utf-8") for chars in data_generator.vocab_list]
error_rate_func = cer if args.error_rate_type == 'cer' else wer error_rate_func = cer if args.error_rate_type == 'cer' else wer
error_sum, num_ins = 0.0, 0 error_sum, num_ins = 0.0, 0
for infer_data in batch_reader(): for infer_data in batch_reader():
...@@ -96,7 +95,7 @@ def evaluate(): ...@@ -96,7 +95,7 @@ def evaluate():
beam_beta=args.beta, beam_beta=args.beta,
beam_size=args.beam_size, beam_size=args.beam_size,
cutoff_prob=args.cutoff_prob, cutoff_prob=args.cutoff_prob,
vocab_list=vocab_list, vocab_list=data_generator.vocab_list,
language_model_path=args.lang_model_path, language_model_path=args.lang_model_path,
num_processes=args.num_proc_bsearch) num_processes=args.num_proc_bsearch)
target_transcripts = [ target_transcripts = [
......
...@@ -13,7 +13,7 @@ download() { ...@@ -13,7 +13,7 @@ download() {
wget -c $URL -P `dirname "$TARGET"` wget -c $URL -P `dirname "$TARGET"`
md5_result=`md5sum $TARGET | awk -F[' '] '{print $1}'` md5_result=`md5sum $TARGET | awk -F[' '] '{print $1}'`
if [ $MD5 -ne $md5_result ]; then if [ ! $MD5 == $md5_result ]; then
echo "Fail to download the language model!" echo "Fail to download the language model!"
return 1 return 1
fi fi
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册