未验证 提交 b37ece04 编写于 作者: X Xinghai Sun 提交者: GitHub

Merge pull request #396 from pkuyym/fix-393

Give option to disable converting from transcription text to ids.
......@@ -55,6 +55,10 @@ class DataGenerator(object):
:type num_threads: int
:param random_seed: Random seed.
:type random_seed: int
:param keep_transcription_text: If set to True, transcription text will
be passed forward directly without
converting to index sequence.
:type keep_transcription_text: bool
"""
def __init__(self,
......@@ -69,7 +73,8 @@ class DataGenerator(object):
specgram_type='linear',
use_dB_normalization=True,
num_threads=multiprocessing.cpu_count() // 2,
random_seed=0):
random_seed=0,
keep_transcription_text=False):
self._max_duration = max_duration
self._min_duration = min_duration
self._normalizer = FeatureNormalizer(mean_std_filepath)
......@@ -84,6 +89,7 @@ class DataGenerator(object):
use_dB_normalization=use_dB_normalization)
self._num_threads = num_threads
self._rng = random.Random(random_seed)
self._keep_transcription_text = keep_transcription_text
self._epoch = 0
# for caching tar files info
self._local_data = local()
......@@ -97,8 +103,8 @@ class DataGenerator(object):
:type filename: basestring | file
:param transcript: Transcription text.
:type transcript: basestring
:return: Tuple of audio feature tensor and list of token ids for
transcription.
:return: Tuple of audio feature tensor and data of transcription part,
where transcription part could be token ids or text.
:rtype: tuple of (2darray, list)
"""
if filename.startswith('tar:'):
......@@ -107,9 +113,10 @@ class DataGenerator(object):
else:
speech_segment = SpeechSegment.from_file(filename, transcript)
self._augmentation_pipeline.transform_audio(speech_segment)
specgram, text_ids = self._speech_featurizer.featurize(speech_segment)
specgram, transcript_part = self._speech_featurizer.featurize(
speech_segment, self._keep_transcription_text)
specgram = self._normalizer.apply(specgram)
return specgram, text_ids
return specgram, transcript_part
def batch_reader_creator(self,
manifest_path,
......
......@@ -60,12 +60,12 @@ class SpeechFeaturizer(object):
target_dB=target_dB)
self._text_featurizer = TextFeaturizer(vocab_filepath)
def featurize(self, speech_segment):
def featurize(self, speech_segment, keep_transcription_text):
"""Extract features for speech segment.
1. For audio parts, extract the audio features.
2. For transcript parts, convert text string to a list of token indices
in char-level.
2. For transcript parts, keep the original text or convert text string
to a list of token indices in char-level.
:param audio_segment: Speech segment to extract features from.
:type audio_segment: SpeechSegment
......@@ -74,6 +74,8 @@ class SpeechFeaturizer(object):
:rtype: tuple
"""
audio_feature = self._audio_featurizer.featurize(speech_segment)
if keep_transcription_text:
return audio_feature, speech_segment.transcript
text_ids = self._text_featurizer.featurize(speech_segment.transcript)
return audio_feature, text_ids
......
......@@ -146,7 +146,8 @@ def start_server():
mean_std_filepath=args.mean_std_path,
augmentation_config='{}',
specgram_type=args.specgram_type,
num_threads=1)
num_threads=1,
keep_transcription_text=True)
# prepare ASR model
ds2_model = DeepSpeech2Model(
vocab_size=data_generator.vocab_size,
......
......@@ -68,7 +68,8 @@ def infer():
mean_std_filepath=args.mean_std_path,
augmentation_config='{}',
specgram_type=args.specgram_type,
num_threads=1)
num_threads=1,
keep_transcription_text=True)
batch_reader = data_generator.batch_reader_creator(
manifest_path=args.infer_manifest,
batch_size=args.num_samples,
......@@ -102,10 +103,7 @@ def infer():
num_processes=args.num_proc_bsearch)
error_rate_func = cer if args.error_rate_type == 'cer' else wer
target_transcripts = [
''.join([data_generator.vocab_list[token] for token in transcript])
for _, transcript in infer_data
]
target_transcripts = [transcript for _, transcript in infer_data]
for target, result in zip(target_transcripts, result_transcripts):
print("\nTarget Transcription: %s\nOutput Transcription: %s" %
(target, result))
......
......@@ -69,7 +69,8 @@ def evaluate():
mean_std_filepath=args.mean_std_path,
augmentation_config='{}',
specgram_type=args.specgram_type,
num_threads=args.num_proc_data)
num_threads=args.num_proc_data,
keep_transcription_text=True)
batch_reader = data_generator.batch_reader_creator(
manifest_path=args.test_manifest,
batch_size=args.batch_size,
......@@ -103,10 +104,7 @@ def evaluate():
vocab_list=vocab_list,
language_model_path=args.lang_model_path,
num_processes=args.num_proc_bsearch)
target_transcripts = [
''.join([data_generator.vocab_list[token] for token in transcript])
for _, transcript in infer_data
]
target_transcripts = [transcript for _, transcript in infer_data]
for target, result in zip(target_transcripts, result_transcripts):
error_sum += error_rate_func(target, result)
num_ins += 1
......
......@@ -87,7 +87,8 @@ def tune():
mean_std_filepath=args.mean_std_path,
augmentation_config='{}',
specgram_type=args.specgram_type,
num_threads=args.num_proc_data)
num_threads=args.num_proc_data,
keep_transcription_text=True)
audio_data = paddle.layer.data(
name="audio_spectrogram",
......@@ -163,10 +164,7 @@ def tune():
for i in xrange(len(infer_data))
]
target_transcripts = [
''.join([data_generator.vocab_list[token] for token in transcript])
for _, transcript in infer_data
]
target_transcripts = [transcript for _, transcript in infer_data]
num_ins += len(target_transcripts)
# grid search
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册