Unverified · Commit b37ece04 · Authored by Xinghai Sun, committed by GitHub

Merge pull request #396 from pkuyym/fix-393

Give option to disable converting from transcription text to ids.
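For context, a minimal usage sketch of the new flag. The module path and the vocab_filepath argument are assumptions taken from the surrounding project rather than this diff; the other constructor arguments mirror the callers changed below.

from data_utils.data import DataGenerator  # assumed module path, not shown in this diff

# With keep_transcription_text=True the generator yields
# (spectrogram, raw transcript string) pairs; with the default False it
# keeps the old behaviour of (spectrogram, char-level token ids).
data_generator = DataGenerator(
    vocab_filepath='data/vocab.txt',        # assumed argument, not in this diff
    mean_std_filepath='data/mean_std.npz',  # illustrative path
    augmentation_config='{}',
    specgram_type='linear',
    num_threads=1,
    keep_transcription_text=True)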
@@ -55,6 +55,10 @@ class DataGenerator(object):
     :type num_threads: int
     :param random_seed: Random seed.
     :type random_seed: int
+    :param keep_transcription_text: If set to True, transcription text will
+                                    be passed forward directly without
+                                    converting to index sequence.
+    :type keep_transcription_text: bool
     """

     def __init__(self,
@@ -69,7 +73,8 @@ class DataGenerator(object):
                  specgram_type='linear',
                  use_dB_normalization=True,
                  num_threads=multiprocessing.cpu_count() // 2,
-                 random_seed=0):
+                 random_seed=0,
+                 keep_transcription_text=False):
         self._max_duration = max_duration
         self._min_duration = min_duration
         self._normalizer = FeatureNormalizer(mean_std_filepath)
@@ -84,6 +89,7 @@ class DataGenerator(object):
             use_dB_normalization=use_dB_normalization)
         self._num_threads = num_threads
         self._rng = random.Random(random_seed)
+        self._keep_transcription_text = keep_transcription_text
         self._epoch = 0
         # for caching tar files info
         self._local_data = local()
@@ -97,8 +103,8 @@ class DataGenerator(object):
         :type filename: basestring | file
         :param transcript: Transcription text.
         :type transcript: basestring
-        :return: Tuple of audio feature tensor and list of token ids for
-                 transcription.
+        :return: Tuple of audio feature tensor and data of transcription part,
+                 where transcription part could be token ids or text.
         :rtype: tuple of (2darray, list)
         """
         if filename.startswith('tar:'):
@@ -107,9 +113,10 @@ class DataGenerator(object):
         else:
             speech_segment = SpeechSegment.from_file(filename, transcript)
         self._augmentation_pipeline.transform_audio(speech_segment)
-        specgram, text_ids = self._speech_featurizer.featurize(speech_segment)
+        specgram, transcript_part = self._speech_featurizer.featurize(
+            speech_segment, self._keep_transcription_text)
         specgram = self._normalizer.apply(specgram)
-        return specgram, text_ids
+        return specgram, transcript_part

     def batch_reader_creator(self,
                              manifest_path,
......
@@ -60,12 +60,12 @@ class SpeechFeaturizer(object):
             target_dB=target_dB)
         self._text_featurizer = TextFeaturizer(vocab_filepath)

-    def featurize(self, speech_segment):
+    def featurize(self, speech_segment, keep_transcription_text):
         """Extract features for speech segment.

         1. For audio parts, extract the audio features.
-        2. For transcript parts, convert text string to a list of token indices
-           in char-level.
+        2. For transcript parts, keep the original text or convert text string
+           to a list of token indices in char-level.

         :param audio_segment: Speech segment to extract features from.
         :type audio_segment: SpeechSegment
@@ -74,6 +74,8 @@ class SpeechFeaturizer(object):
         :rtype: tuple
         """
         audio_feature = self._audio_featurizer.featurize(speech_segment)
+        if keep_transcription_text:
+            return audio_feature, speech_segment.transcript
         text_ids = self._text_featurizer.featurize(speech_segment.transcript)
         return audio_feature, text_ids
......
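A hedged sketch of the two return modes of the updated featurize method; it assumes an already-constructed SpeechFeaturizer instance named featurizer, and the audio file and transcript are illustrative.

# Illustrative only: `featurizer` is an assumed SpeechFeaturizer instance.
segment = SpeechSegment.from_file('sample.wav', 'hello world')

# keep_transcription_text=True: the transcript string is passed through as-is.
feat, transcript = featurizer.featurize(segment, True)
# transcript == 'hello world'

# keep_transcription_text=False: the previous behaviour, char-level token ids.
feat, token_ids = featurizer.featurize(segment, False)
# token_ids is a list of indices produced by the TextFeaturizer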
@@ -146,7 +146,8 @@ def start_server():
         mean_std_filepath=args.mean_std_path,
         augmentation_config='{}',
         specgram_type=args.specgram_type,
-        num_threads=1)
+        num_threads=1,
+        keep_transcription_text=True)
     # prepare ASR model
     ds2_model = DeepSpeech2Model(
         vocab_size=data_generator.vocab_size,
......
@@ -68,7 +68,8 @@ def infer():
         mean_std_filepath=args.mean_std_path,
         augmentation_config='{}',
         specgram_type=args.specgram_type,
-        num_threads=1)
+        num_threads=1,
+        keep_transcription_text=True)
     batch_reader = data_generator.batch_reader_creator(
         manifest_path=args.infer_manifest,
         batch_size=args.num_samples,
@@ -102,10 +103,7 @@ def infer():
         num_processes=args.num_proc_bsearch)
     error_rate_func = cer if args.error_rate_type == 'cer' else wer
-    target_transcripts = [
-        ''.join([data_generator.vocab_list[token] for token in transcript])
-        for _, transcript in infer_data
-    ]
+    target_transcripts = [transcript for _, transcript in infer_data]
     for target, result in zip(target_transcripts, result_transcripts):
         print("\nTarget Transcription: %s\nOutput Transcription: %s" %
               (target, result))
......
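The same simplification recurs in infer(), evaluate() and tune(): because the reader now yields raw text, the manual id-to-character join over data_generator.vocab_list is no longer needed.

# Before: infer_data carried token ids that had to be mapped back to characters.
target_transcripts = [
    ''.join([data_generator.vocab_list[token] for token in transcript])
    for _, transcript in infer_data
]

# After: with keep_transcription_text=True the transcript is already a string.
target_transcripts = [transcript for _, transcript in infer_data]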
@@ -69,7 +69,8 @@ def evaluate():
         mean_std_filepath=args.mean_std_path,
         augmentation_config='{}',
         specgram_type=args.specgram_type,
-        num_threads=args.num_proc_data)
+        num_threads=args.num_proc_data,
+        keep_transcription_text=True)
     batch_reader = data_generator.batch_reader_creator(
         manifest_path=args.test_manifest,
         batch_size=args.batch_size,
@@ -103,10 +104,7 @@ def evaluate():
         vocab_list=vocab_list,
         language_model_path=args.lang_model_path,
         num_processes=args.num_proc_bsearch)
-    target_transcripts = [
-        ''.join([data_generator.vocab_list[token] for token in transcript])
-        for _, transcript in infer_data
-    ]
+    target_transcripts = [transcript for _, transcript in infer_data]
     for target, result in zip(target_transcripts, result_transcripts):
         error_sum += error_rate_func(target, result)
         num_ins += 1
......
@@ -87,7 +87,8 @@ def tune():
         mean_std_filepath=args.mean_std_path,
         augmentation_config='{}',
         specgram_type=args.specgram_type,
-        num_threads=args.num_proc_data)
+        num_threads=args.num_proc_data,
+        keep_transcription_text=True)
     audio_data = paddle.layer.data(
         name="audio_spectrogram",
@@ -163,10 +164,7 @@ def tune():
         for i in xrange(len(infer_data))
     ]
-    target_transcripts = [
-        ''.join([data_generator.vocab_list[token] for token in transcript])
-        for _, transcript in infer_data
-    ]
+    target_transcripts = [transcript for _, transcript in infer_data]
     num_ins += len(target_transcripts)
     # grid search
......