diff --git a/deep_speech_2/README.md b/deep_speech_2/README.md
index 0caa617eb101cd0806d60c5e52958309de0ebc2d..543af0ad108acce896a944ff0da3775262d9c886 100644
--- a/deep_speech_2/README.md
+++ b/deep_speech_2/README.md
@@ -498,13 +498,13 @@ Language Model | Training Data | Token-based | Size | Descriptions
 
 Test Set | LibriSpeech Model | BaiduEN8K Model
 :--------------------- | ---------------: | -------------------:
-LibriSpeech Test-Clean | 8.06 | 6.63
-LibriSpeech Test-Other | 24.25 | 16.59
-VoxForge American-Canadian | 13.48 |   7.46
-VoxForge Commonwealth | 22.37 | 16.23
-VoxForge European | 32.64 | 20.47
-VoxForge Indian | 58.48 | 28.15
-Baidu Internal Testset  |   48.93 |   8.92
+LibriSpeech Test-Clean | 7.77 | 6.63
+LibriSpeech Test-Other | 23.25 | 16.59
+VoxForge American-Canadian | 12.52 |   7.46
+VoxForge Commonwealth | 21.08 | 16.23
+VoxForge European | 31.21 | 20.47
+VoxForge Indian | 56.79 | 28.15
+Baidu Internal Testset  |   47.73 |   8.92
 
 #### Benchmark Results for Mandarin Model (Character Error Rate)
 
diff --git a/deep_speech_2/decoders/swig/ctc_beam_search_decoder.cpp b/deep_speech_2/decoders/swig/ctc_beam_search_decoder.cpp
index 624784b05e215782f2264cc6ae4db7eed5b28cae..4a63af26af5e2da135f75386581fbe61f56c7fc7 100644
--- a/deep_speech_2/decoders/swig/ctc_beam_search_decoder.cpp
+++ b/deep_speech_2/decoders/swig/ctc_beam_search_decoder.cpp
@@ -110,17 +110,17 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
       // language model scoring
       if (ext_scorer != nullptr &&
           (c == space_id || ext_scorer->is_character_based())) {
-        PathTrie *prefix_toscore = nullptr;
+        PathTrie *prefix_to_score = nullptr;
         // skip scoring the space
         if (ext_scorer->is_character_based()) {
-          prefix_toscore = prefix_new;
+          prefix_to_score = prefix_new;
         } else {
-          prefix_toscore = prefix;
+          prefix_to_score = prefix;
         }
 
-        double score = 0.0;
+        float score = 0.0;
         std::vector<std::string> ngram;
-        ngram = ext_scorer->make_ngram(prefix_toscore);
+        ngram = ext_scorer->make_ngram(prefix_to_score);
         score = ext_scorer->get_log_cond_prob(ngram) * ext_scorer->alpha;
         log_p += score;
         log_p += ext_scorer->beta;
@@ -131,6 +131,7 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
     }  // end of loop over prefix
   }    // end of loop over vocabulary
 
+  prefixes.clear();
   // update log probs
   root.iterate_to_vec(prefixes);
 
@@ -147,6 +148,23 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
     }
   }  // end of loop over time
 
+  // score the last word of each prefix that doesn't end with space
+  if (ext_scorer != nullptr && !ext_scorer->is_character_based()) {
+    for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
+      auto prefix = prefixes[i];
+      if (!prefix->is_empty() && prefix->character != space_id) {
+        float score = 0.0;
+        std::vector<std::string> ngram = ext_scorer->make_ngram(prefix);
+        score = ext_scorer->get_log_cond_prob(ngram) * ext_scorer->alpha;
+        score += ext_scorer->beta;
+        prefix->score += score;
+      }
+    }
+  }
+
+  size_t num_prefixes = std::min(prefixes.size(), beam_size);
+  std::sort(prefixes.begin(), prefixes.begin() + num_prefixes, prefix_compare);
+
   // compute aproximate ctc score as the return score, without affecting the
   // return order of decoding result. To delete when decoder gets stable.
   for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
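For readers outside the decoder codebase, the main behavioral change above — giving each beam entry's trailing, space-unterminated word its language-model score once the time loop ends, and then re-sorting the top `beam_size` prefixes — can be seen in isolation in the sketch below. `Prefix`, `StubScorer`, and its `log_cond_prob` are toy stand-ins, not the project's `PathTrie`/`Scorer` API, and the `alpha`/`beta` values are arbitrary.

```cpp
#include <algorithm>
#include <cstdio>
#include <string>
#include <vector>

// Toy stand-in for a beam entry; the real decoder uses PathTrie nodes.
struct Prefix {
  std::string text;  // decoded characters so far
  float score;       // combined acoustic + LM log score
  bool ends_with_space() const { return !text.empty() && text.back() == ' '; }
};

// Toy stand-in for the external scorer; log_cond_prob is a fake LM.
struct StubScorer {
  float alpha = 2.5f;  // LM weight (illustrative value)
  float beta = 0.3f;   // word-insertion bonus (illustrative value)
  // Fake log P(word | context): shorter words score higher here, just so
  // the example produces a visible re-ranking.
  float log_cond_prob(const std::string &word) const {
    return -0.5f * static_cast<float>(word.size());
  }
};

int main() {
  StubScorer scorer;
  std::vector<Prefix> beam = {
      {"the cat ", -4.0f},    // ends with space: last word already scored
      {"the catalog", -3.8f}  // trailing word never reached a space
  };

  // Mirror of the patch's post-loop pass: prefixes that do not end in a
  // space still carry an unscored final word, so fold its LM score in once
  // decoding finishes, then re-sort the beam.
  for (auto &p : beam) {
    if (!p.text.empty() && !p.ends_with_space()) {
      auto pos = p.text.find_last_of(' ');
      std::string last_word =
          (pos == std::string::npos) ? p.text : p.text.substr(pos + 1);
      p.score += scorer.log_cond_prob(last_word) * scorer.alpha + scorer.beta;
    }
  }
  std::sort(beam.begin(), beam.end(),
            [](const Prefix &a, const Prefix &b) { return a.score > b.score; });

  for (const auto &p : beam) {
    std::printf("%8.3f  \"%s\"\n", p.score, p.text.c_str());
  }
  return 0;
}
```

In the patch itself, word-based scorers only fire when a space is appended (`c == space_id`), so without the final pass a hypothesis ending mid-word carries an unscored last word and prefixes ending exactly on a space are systematically favored; the `std::sort` over the top `num_prefixes` then restores the beam ordering under the corrected scores. The added `prefixes.clear()` looks intended to let `root.iterate_to_vec` rebuild the vector from the trie each step rather than append to stale contents.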