提交 043127b6 编写于 作者: H Haoxin Ma

revise collator

上级 1ec93dbd
......@@ -104,7 +104,7 @@ class AugmentationPipeline():
for augmentor, rate in zip(self._augmentors, self._rates):
augmentor.randomize_parameters()
def randomize_parameters_feature_transform(self, audio):
def randomize_parameters_feature_transform(self, n_frames, n_bins):
"""Run the pre-processing pipeline for data augmentation.
Note that this is an in-place transformation.
......@@ -112,8 +112,8 @@ class AugmentationPipeline():
:param audio_segment: Audio segment to process.
:type audio_segment: AudioSegmenet|SpeechSegment
"""
for augmentor, rate in zip(self._augmentors, self._rates):
augmentor.randomize_parameters(audio)
for augmentor, rate in zip(self._spec_augmentors, self._rates):
augmentor.randomize_parameters(n_frames, n_bins)
def apply_audio_transform(self, audio_segment):
"""Run the pre-processing pipeline for data augmentation.
......
......@@ -37,17 +37,17 @@ class ShiftPerturbAugmentor(AugmentorBase):
def apply(self, audio_segment):
audio_segment.shift(self.shift_ms)
def transform_audio(self, audio_segment, single):
"""Shift audio.
# def transform_audio(self, audio_segment, single):
# """Shift audio.
Note that this is an in-place transformation.
# Note that this is an in-place transformation.
:param audio_segment: Audio segment to add effects to.
:type audio_segment: AudioSegmenet|SpeechSegment
"""
if(single):
self.randomize_parameters()
self.apply(audio_segment)
# :param audio_segment: Audio segment to add effects to.
# :type audio_segment: AudioSegmenet|SpeechSegment
# """
# if(single):
# self.randomize_parameters()
# self.apply(audio_segment)
# def transform_audio(self, audio_segment):
......
......@@ -124,7 +124,7 @@ class SpecAugmentor(AugmentorBase):
def time_warp(xs, W=40):
raise NotImplementedError
def randomize_parameters(self, n_frame, n_bins):
def randomize_parameters(self, n_frames, n_bins):
# n_bins = xs.shape[0]
# n_frames = xs.shape[1]
......
......@@ -110,7 +110,8 @@ class SpeechCollator():
use_dB_normalization=config.collator.use_dB_normalization,
target_dB=config.collator.target_dB,
dither=config.collator.dither,
keep_transcription_text=config.collator.keep_transcription_text)
keep_transcription_text=config.collator.keep_transcription_text,
randomize_each_batch=config.collator.randomize_each_batch)
return speech_collator
def __init__(
......@@ -132,7 +133,8 @@ class SpeechCollator():
use_dB_normalization=True,
target_dB=-20,
dither=1.0,
keep_transcription_text=True):
keep_transcription_text=True,
randomize_each_batch=False):
"""SpeechCollator Collator
Args:
......@@ -160,6 +162,7 @@ class SpeechCollator():
a user-defined shape) within one batch.
"""
self._keep_transcription_text = keep_transcription_text
self._randomize_each_batch = randomize_each_batch
self._local_data = TarLocalData(tar2info={}, tar2object={})
self._augmentation_pipeline = AugmentationPipeline(
......@@ -171,6 +174,7 @@ class SpeechCollator():
self._stride_ms = stride_ms
self._target_sample_rate = target_sample_rate
self._speech_featurizer = SpeechFeaturizer(
unit_type=unit_type,
vocab_filepath=vocab_filepath,
......@@ -224,10 +228,10 @@ class SpeechCollator():
return speech_segment
def randomize_audio_parameters(self):
self._augmentation_pipeline.andomize_parameters_audio_transform()
self._augmentation_pipeline.randomize_parameters_audio_transform()
def randomize_feature_parameters(self, n_bins, n_frames):
self._augmentation_pipeline.andomize_parameters_feature_transform(n_bins, n_frames)
def randomize_feature_parameters(self, n_frames, n_bins):
self._augmentation_pipeline.randomize_parameters_feature_transform(n_frames, n_bins)
def process_feature_and_transform(self, audio_file, transcript):
"""Load, augment, featurize and normalize for speech data.
......@@ -317,12 +321,15 @@ class SpeechCollator():
# print(len(batch))
self.randomize_audio_parameters()
for utt, audio, text in batch:
if not self.config.randomize_each_batch:
if not self._randomize_each_batch:
self.randomize_audio_parameters()
audio, text = self.process_feature_and_transform(audio, text)
#utt
utts.append(utt)
# audio
# print("---debug---")
# print(audio.shape)
audio=audio.T
audios.append(audio) # [T, D]
audio_lens.append(audio.shape[0])
# text
......@@ -350,7 +357,7 @@ class SpeechCollator():
n_bins=padded_audios.shape[2]
self.randomize_feature_parameters(min(audio_lens), n_bins)
for i in range(len(padded_audios)):
if not self.config.randomize_each_batch:
if not self._randomize_each_batch:
self.randomize_feature_parameters(n_bins, audio_lens[i])
padded_audios[i] = self._augmentation_pipeline.apply_feature_transform(padded_audios[i])
......
......@@ -11,7 +11,8 @@ data:
max_output_input_ratio: .inf
collator:
batch_size: 64 # one gpu
batch_size: 32 #64 # one gpu
randomize_each_batch: False
mean_std_filepath: data/mean_std.json
unit_type: char
vocab_filepath: data/vocab.txt
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册