diff --git a/deep_speech_2/data_utils/data.py b/deep_speech_2/data_utils/data.py index 71ba2434f2875c007b011cef659a031c7d80a621..70ee6fbad976460cf2dcf191a9e98e324bd32d13 100644 --- a/deep_speech_2/data_utils/data.py +++ b/deep_speech_2/data_utils/data.py @@ -55,6 +55,10 @@ class DataGenerator(object): :type num_threads: int :param random_seed: Random seed. :type random_seed: int + :param keep_transcription_text: If set to True, transcription text will + be passed forward directly without + converting to index sequence. + :type keep_transcription_text: bool """ def __init__(self, @@ -69,7 +73,8 @@ class DataGenerator(object): specgram_type='linear', use_dB_normalization=True, num_threads=multiprocessing.cpu_count() // 2, - random_seed=0): + random_seed=0, + keep_transcription_text=False): self._max_duration = max_duration self._min_duration = min_duration self._normalizer = FeatureNormalizer(mean_std_filepath) @@ -84,6 +89,7 @@ class DataGenerator(object): use_dB_normalization=use_dB_normalization) self._num_threads = num_threads self._rng = random.Random(random_seed) + self._keep_transcription_text = keep_transcription_text self._epoch = 0 # for caching tar files info self._local_data = local() @@ -97,8 +103,8 @@ class DataGenerator(object): :type filename: basestring | file :param transcript: Transcription text. :type transcript: basestring - :return: Tuple of audio feature tensor and list of token ids for - transcription. + :return: Tuple of audio feature tensor and data of transcription part, + where transcription part could be token ids or text. :rtype: tuple of (2darray, list) """ if filename.startswith('tar:'): @@ -107,9 +113,10 @@ class DataGenerator(object): else: speech_segment = SpeechSegment.from_file(filename, transcript) self._augmentation_pipeline.transform_audio(speech_segment) - specgram, text_ids = self._speech_featurizer.featurize(speech_segment) + specgram, transcript_part = self._speech_featurizer.featurize( + speech_segment, self._keep_transcription_text) specgram = self._normalizer.apply(specgram) - return specgram, text_ids + return specgram, transcript_part def batch_reader_creator(self, manifest_path, diff --git a/deep_speech_2/data_utils/featurizer/speech_featurizer.py b/deep_speech_2/data_utils/featurizer/speech_featurizer.py index a947588db4a29d7d49b9650c2da28731259cc0e0..4555dc31da89367b4775b712d3876168aae268f4 100644 --- a/deep_speech_2/data_utils/featurizer/speech_featurizer.py +++ b/deep_speech_2/data_utils/featurizer/speech_featurizer.py @@ -60,12 +60,12 @@ class SpeechFeaturizer(object): target_dB=target_dB) self._text_featurizer = TextFeaturizer(vocab_filepath) - def featurize(self, speech_segment): + def featurize(self, speech_segment, keep_transcription_text): """Extract features for speech segment. 1. For audio parts, extract the audio features. - 2. For transcript parts, convert text string to a list of token indices - in char-level. + 2. For transcript parts, keep the original text or convert text string + to a list of token indices in char-level. :param audio_segment: Speech segment to extract features from. :type audio_segment: SpeechSegment @@ -74,6 +74,8 @@ class SpeechFeaturizer(object): :rtype: tuple """ audio_feature = self._audio_featurizer.featurize(speech_segment) + if keep_transcription_text: + return audio_feature, speech_segment.transcript text_ids = self._text_featurizer.featurize(speech_segment.transcript) return audio_feature, text_ids diff --git a/deep_speech_2/deploy/demo_server.py b/deep_speech_2/deploy/demo_server.py index b007c751e730726b4a7f08b9ac40684a54ec7e04..3e81c0c5b92a873cc5854c727429c75a3ae1fb0d 100644 --- a/deep_speech_2/deploy/demo_server.py +++ b/deep_speech_2/deploy/demo_server.py @@ -146,7 +146,8 @@ def start_server(): mean_std_filepath=args.mean_std_path, augmentation_config='{}', specgram_type=args.specgram_type, - num_threads=1) + num_threads=1, + keep_transcription_text=True) # prepare ASR model ds2_model = DeepSpeech2Model( vocab_size=data_generator.vocab_size, diff --git a/deep_speech_2/infer.py b/deep_speech_2/infer.py index a30d48d6de567f62261a4adae55820321d8ca156..9ac3e632efc792b629fa670943c3593037f186b9 100644 --- a/deep_speech_2/infer.py +++ b/deep_speech_2/infer.py @@ -68,7 +68,8 @@ def infer(): mean_std_filepath=args.mean_std_path, augmentation_config='{}', specgram_type=args.specgram_type, - num_threads=1) + num_threads=1, + keep_transcription_text=True) batch_reader = data_generator.batch_reader_creator( manifest_path=args.infer_manifest, batch_size=args.num_samples, @@ -102,10 +103,7 @@ def infer(): num_processes=args.num_proc_bsearch) error_rate_func = cer if args.error_rate_type == 'cer' else wer - target_transcripts = [ - ''.join([data_generator.vocab_list[token] for token in transcript]) - for _, transcript in infer_data - ] + target_transcripts = [transcript for _, transcript in infer_data] for target, result in zip(target_transcripts, result_transcripts): print("\nTarget Transcription: %s\nOutput Transcription: %s" % (target, result)) diff --git a/deep_speech_2/test.py b/deep_speech_2/test.py index 94c09150ca5b3f54db11ef4e27f798cf469510cf..63fc4f65c9254969007eff372fcee9c8bb87c621 100644 --- a/deep_speech_2/test.py +++ b/deep_speech_2/test.py @@ -69,7 +69,8 @@ def evaluate(): mean_std_filepath=args.mean_std_path, augmentation_config='{}', specgram_type=args.specgram_type, - num_threads=args.num_proc_data) + num_threads=args.num_proc_data, + keep_transcription_text=True) batch_reader = data_generator.batch_reader_creator( manifest_path=args.test_manifest, batch_size=args.batch_size, @@ -103,10 +104,7 @@ def evaluate(): vocab_list=vocab_list, language_model_path=args.lang_model_path, num_processes=args.num_proc_bsearch) - target_transcripts = [ - ''.join([data_generator.vocab_list[token] for token in transcript]) - for _, transcript in infer_data - ] + target_transcripts = [transcript for _, transcript in infer_data] for target, result in zip(target_transcripts, result_transcripts): error_sum += error_rate_func(target, result) num_ins += 1 diff --git a/deep_speech_2/tools/tune.py b/deep_speech_2/tools/tune.py index 233ec4ab84a88f3805cc532e40ee0417384ce8a6..966029a8259369a63582c059dd4072967a57212f 100644 --- a/deep_speech_2/tools/tune.py +++ b/deep_speech_2/tools/tune.py @@ -87,7 +87,8 @@ def tune(): mean_std_filepath=args.mean_std_path, augmentation_config='{}', specgram_type=args.specgram_type, - num_threads=args.num_proc_data) + num_threads=args.num_proc_data, + keep_transcription_text=True) audio_data = paddle.layer.data( name="audio_spectrogram", @@ -163,10 +164,7 @@ def tune(): for i in xrange(len(infer_data)) ] - target_transcripts = [ - ''.join([data_generator.vocab_list[token] for token in transcript]) - for _, transcript in infer_data - ] + target_transcripts = [transcript for _, transcript in infer_data] num_ins += len(target_transcripts) # grid search