diff --git a/deepspeech/decoders/swig/scorer.cpp b/deepspeech/decoders/swig/scorer.cpp index a25382b15cd73e2743d28b8e3e93d167d976fe2e..ebb9e448d6bc457cd6a51cb47aa7f7ad19694e36 100644 --- a/deepspeech/decoders/swig/scorer.cpp +++ b/deepspeech/decoders/swig/scorer.cpp @@ -26,6 +26,7 @@ #include "decoder_utils.h" using namespace lm::ngram; +const std::string kSPACE = ""; Scorer::Scorer(double alpha, double beta, @@ -165,7 +166,7 @@ void Scorer::set_char_map(const std::vector& char_list) { // Set the char map for the FST for spelling correction for (size_t i = 0; i < char_list_.size(); i++) { - if (char_list_[i] == " ") { + if (char_list_[i] == kSPACE) { SPACE_ID_ = i; } // The initial state of FST is state 0, hence the index of chars in diff --git a/deepspeech/exps/deepspeech2/model.py b/deepspeech/exps/deepspeech2/model.py index 79a676345fdcb7544bec1511861f5e7accc17928..702a05760ca912be449250a230f88d3833a4edf0 100644 --- a/deepspeech/exps/deepspeech2/model.py +++ b/deepspeech/exps/deepspeech2/model.py @@ -27,6 +27,7 @@ from paddle import inference from paddle.io import DataLoader from yacs.config import CfgNode +from deepspeech.frontend.featurizer.text_featurizer import TextFeaturizer from deepspeech.io.collator import SpeechCollator from deepspeech.io.dataset import ManifestDataset from deepspeech.io.sampler import SortagradBatchSampler @@ -271,6 +272,8 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): def __init__(self, config, args): super().__init__(config, args) + self._text_featurizer = TextFeaturizer( + unit_type=config.collator.unit_type, vocab_filepath=None) def ordid2token(self, texts, texts_len): """ ord() id to chr() chr """ @@ -299,6 +302,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): result_transcripts = self.compute_result_transcripts(audio, audio_len, vocab_list, cfg) + for utt, target, result in zip(utts, target_transcripts, result_transcripts): errors, len_ref = errors_func(target, result) @@ -335,6 +339,12 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): cutoff_prob=cfg.cutoff_prob, cutoff_top_n=cfg.cutoff_top_n, num_processes=cfg.num_proc_bsearch) + #replace the with ' ' + result_transcripts = [ + self._text_featurizer.detokenize(sentence) + for sentence in result_transcripts + ] + self.autolog.times.stamp() self.autolog.times.stamp() self.autolog.times.end() @@ -455,6 +465,11 @@ class DeepSpeech2ExportTester(DeepSpeech2Tester): output_probs, output_lens, vocab_list, cfg.decoding_method, cfg.lang_model_path, cfg.alpha, cfg.beta, cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n, cfg.num_proc_bsearch) + #replace the with ' ' + result_transcripts = [ + self._text_featurizer.detokenize(sentence) + for sentence in result_transcripts + ] return result_transcripts