Commit d2c52dbc authored by yangyaming

Give an option to disable converting transcription text to ids.

Parent a9a43991
@@ -55,6 +55,10 @@ class DataGenerator(object):
     :type num_threads: int
     :param random_seed: Random seed.
     :type random_seed: int
+    :param keep_transcription_text: If set to True, transcription text will
+                                    be passed forward directly without
+                                    converting to index sequence.
+    :type keep_transcription_text: bool
     """

     def __init__(self,
@@ -69,7 +73,8 @@ class DataGenerator(object):
                  specgram_type='linear',
                  use_dB_normalization=True,
                  num_threads=multiprocessing.cpu_count() // 2,
-                 random_seed=0):
+                 random_seed=0,
+                 keep_transcription_text=False):
         self._max_duration = max_duration
         self._min_duration = min_duration
         self._normalizer = FeatureNormalizer(mean_std_filepath)
@@ -84,6 +89,7 @@ class DataGenerator(object):
             use_dB_normalization=use_dB_normalization)
         self._num_threads = num_threads
         self._rng = random.Random(random_seed)
+        self._keep_transcription_text = keep_transcription_text
         self._epoch = 0
         # for caching tar files info
         self._local_data = local()
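
For orientation, a hedged usage sketch of the new constructor flag. The import path, vocab_filepath and the file paths below are assumptions, not taken from this commit; only keep_transcription_text and the keyword names visible at the call sites further down are confirmed by the diff.

```python
# Hedged sketch: placeholder paths, assumed import path and vocab_filepath.
from data_utils.data import DataGenerator

# Training: transcripts are converted to token-id sequences (the default).
train_generator = DataGenerator(
    vocab_filepath='data/vocab.txt',        # placeholder
    mean_std_filepath='data/mean_std.npz',  # placeholder
    augmentation_config='{}',
    keep_transcription_text=False)

# Inference/evaluation: keep the raw transcript text for error-rate scoring.
infer_generator = DataGenerator(
    vocab_filepath='data/vocab.txt',
    mean_std_filepath='data/mean_std.npz',
    augmentation_config='{}',
    keep_transcription_text=True)
```

Since the flag defaults to False, existing training pipelines are unaffected by this change.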
@@ -107,9 +113,10 @@ class DataGenerator(object):
         else:
             speech_segment = SpeechSegment.from_file(filename, transcript)
         self._augmentation_pipeline.transform_audio(speech_segment)
-        specgram, text_ids = self._speech_featurizer.featurize(speech_segment)
+        specgram, transcript_part = self._speech_featurizer.featurize(
+            speech_segment, self._keep_transcription_text)
         specgram = self._normalizer.apply(specgram)
-        return specgram, text_ids
+        return specgram, transcript_part

     def batch_reader_creator(self,
                              manifest_path,
......
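With the change above, the pair produced for each utterance is now either (spectrogram, token ids) or (spectrogram, raw text). A self-contained toy, with made-up values, showing how a consumer can tell the two apart:

```python
# Toy illustration: values are made up, only the type of the second element matters.
fake_specgram = [[0.0] * 200 for _ in range(161)]  # stands in for a (freq, time) spectrogram
with_ids = (fake_specgram, [7, 4, 11, 11, 14])     # keep_transcription_text=False (made-up ids)
with_text = (fake_specgram, 'hello')               # keep_transcription_text=True

for specgram, transcript_part in (with_ids, with_text):
    if isinstance(transcript_part, str):
        print('raw transcript: %s' % transcript_part)     # ready for WER/CER scoring
    else:
        print('token id sequence: %s' % transcript_part)  # ready for the CTC training cost
```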
@@ -60,12 +60,12 @@ class SpeechFeaturizer(object):
             target_dB=target_dB)
         self._text_featurizer = TextFeaturizer(vocab_filepath)

-    def featurize(self, speech_segment):
+    def featurize(self, speech_segment, keep_transcription_text):
         """Extract features for speech segment.

         1. For audio parts, extract the audio features.
-        2. For transcript parts, convert text string to a list of token indices
-           in char-level.
+        2. For transcript parts, keep the original text or convert text string
+           to a list of token indices in char-level.

         :param audio_segment: Speech segment to extract features from.
         :type audio_segment: SpeechSegment
@@ -74,6 +74,8 @@ class SpeechFeaturizer(object):
         :rtype: tuple
         """
         audio_feature = self._audio_featurizer.featurize(speech_segment)
+        if keep_transcription_text:
+            return audio_feature, speech_segment.transcript
         text_ids = self._text_featurizer.featurize(speech_segment.transcript)
         return audio_feature, text_ids
......
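Note that featurize() now takes keep_transcription_text as a required positional argument, so any caller outside DataGenerator would have to be updated as well. Below is a minimal stand-in sketch of the branching; the toy text featurizer is hypothetical and only mirrors the dispatch, it is not the repository's TextFeaturizer.

```python
# Hypothetical stand-in for a char-level text featurizer.
class ToyTextFeaturizer(object):
    def __init__(self, vocab):
        self._vocab = {ch: i for i, ch in enumerate(vocab)}

    def featurize(self, text):
        return [self._vocab[ch] for ch in text]


def featurize(transcript, keep_transcription_text, text_featurizer):
    audio_feature = 'audio-features-placeholder'  # stands in for the spectrogram
    if keep_transcription_text:
        return audio_feature, transcript          # raw text passes straight through
    return audio_feature, text_featurizer.featurize(transcript)


featurizer = ToyTextFeaturizer(' abcdefghijklmnopqrstuvwxyz')
print(featurize('hello', False, featurizer))  # (..., [8, 5, 12, 12, 15])
print(featurize('hello', True, featurizer))   # (..., 'hello')
```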
@@ -146,7 +146,8 @@ def start_server():
         mean_std_filepath=args.mean_std_path,
         augmentation_config='{}',
         specgram_type=args.specgram_type,
-        num_threads=1)
+        num_threads=1,
+        keep_transcription_text=True)
     # prepare ASR model
     ds2_model = DeepSpeech2Model(
         vocab_size=data_generator.vocab_size,
......
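The demo server keeps transcripts as text so reference and decoded strings can be shown side by side. A hedged sketch of such a warm-up loop; warm_up, samples and transcribe are placeholders, not APIs from this file.

```python
# Hedged sketch: 'samples' stands for (spectrogram, reference_text) pairs produced
# with keep_transcription_text=True, 'transcribe' for a wrapper around the decoder.
def warm_up(samples, transcribe):
    for specgram, reference_text in samples:
        hypothesis = transcribe(specgram)
        print('Reference:  %s' % reference_text)
        print('Hypothesis: %s' % hypothesis)


# Toy call with stand-ins:
warm_up([(None, 'hello world')], lambda specgram: 'hello word')
```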
@@ -68,7 +68,8 @@ def infer():
         mean_std_filepath=args.mean_std_path,
         augmentation_config='{}',
         specgram_type=args.specgram_type,
-        num_threads=1)
+        num_threads=1,
+        keep_transcription_text=True)
     batch_reader = data_generator.batch_reader_creator(
         manifest_path=args.infer_manifest,
         batch_size=args.num_samples,
@@ -103,8 +104,7 @@ def infer():
     error_rate_func = cer if args.error_rate_type == 'cer' else wer
     target_transcripts = [
-        ''.join([data_generator.vocab_list[token] for token in transcript])
-        for _, transcript in infer_data
+        transcript for _, transcript in infer_data
     ]
     for target, result in zip(target_transcripts, result_transcripts):
         print("\nTarget Transcription: %s\nOutput Transcription: %s" %
......
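Since the generator already hands over the text, the old id-to-character reconstruction is no longer needed. A self-contained toy (vocabulary and transcript made up) showing the two paths yield the same target string:

```python
vocab_list = [' ', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i']    # made-up vocabulary
transcript_text = 'chica'                                          # made-up transcript
transcript_ids = [vocab_list.index(ch) for ch in transcript_text]  # what the old reader yielded

old_target = ''.join([vocab_list[token] for token in transcript_ids])  # removed code path
new_target = transcript_text                                           # with keep_transcription_text=True
assert old_target == new_target
print(old_target)  # chica
```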
@@ -69,7 +69,8 @@ def evaluate():
         mean_std_filepath=args.mean_std_path,
         augmentation_config='{}',
         specgram_type=args.specgram_type,
-        num_threads=args.num_proc_data)
+        num_threads=args.num_proc_data,
+        keep_transcription_text=True)
     batch_reader = data_generator.batch_reader_creator(
         manifest_path=args.test_manifest,
         batch_size=args.batch_size,
@@ -104,8 +105,7 @@ def evaluate():
             language_model_path=args.lang_model_path,
             num_processes=args.num_proc_bsearch)
         target_transcripts = [
-            ''.join([data_generator.vocab_list[token] for token in transcript])
-            for _, transcript in infer_data
+            transcript for _, transcript in infer_data
         ]
         for target, result in zip(target_transcripts, result_transcripts):
             error_sum += error_rate_func(target, result)
......
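With raw text targets, the evaluation loop feeds the selected cer/wer function directly and averages over all utterances. A hedged sketch of that accumulation pattern, using a stand-in error function and made-up batches:

```python
# Stand-in error function and made-up data; only the accumulation pattern is the point.
def toy_error_rate(target, result):
    # fraction of positions that differ, padded to the longer string
    length = max(len(target), len(result))
    mismatches = sum(a != b for a, b in zip(target.ljust(length), result.ljust(length)))
    return float(mismatches) / length


batches = [
    [('hello world', 'hello word'), ('good morning', 'good morning')],
    [('deep speech', 'deep speach')],
]

error_sum, num_ins = 0.0, 0
for batch in batches:
    for target, result in batch:
        error_sum += toy_error_rate(target, result)
        num_ins += 1
print('mean error rate over %d utterances: %.4f' % (num_ins, error_sum / num_ins))
```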
@@ -87,7 +87,8 @@ def tune():
         mean_std_filepath=args.mean_std_path,
         augmentation_config='{}',
         specgram_type=args.specgram_type,
-        num_threads=args.num_proc_data)
+        num_threads=args.num_proc_data,
+        keep_transcription_text=True)

     audio_data = paddle.layer.data(
         name="audio_spectrogram",
@@ -164,8 +165,7 @@ def tune():
         ]
         target_transcripts = [
-            ''.join([data_generator.vocab_list[token] for token in transcript])
-            for _, transcript in infer_data
+            transcript for _, transcript in infer_data
         ]
         num_ins += len(target_transcripts)
......
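tune.py uses the same text targets while sweeping the decoder's alpha/beta. A hedged sketch of picking the pair with the lowest mean error rate; decode_batch and error_rate are placeholders for the beam-search decoder and the cer/wer utility, not APIs from this file.

```python
import itertools


def pick_best(alphas, betas, batches, decode_batch, error_rate):
    best = None
    for alpha, beta in itertools.product(alphas, betas):
        errors, count = 0.0, 0
        for specgrams, target_texts in batches:
            results = decode_batch(specgrams, alpha, beta)
            errors += sum(error_rate(t, r) for t, r in zip(target_texts, results))
            count += len(target_texts)
        mean = errors / count
        if best is None or mean < best[0]:
            best = (mean, alpha, beta)
    return best


# Toy usage with trivial stand-ins:
best = pick_best(
    alphas=[0.5, 1.0], betas=[0.1, 0.3],
    batches=[((None,), ('hello',))],
    decode_batch=lambda specs, a, b: ['hello' if a == 1.0 else 'hallo'],
    error_rate=lambda t, r: float(t != r))
print(best)  # (0.0, 1.0, 0.1)
```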