提交 93ae5999 编写于 作者: H Haoxin Ma

add resampler and apply

上级 89a00eab
...@@ -93,7 +93,7 @@ class AugmentationPipeline(): ...@@ -93,7 +93,7 @@ class AugmentationPipeline():
self._spec_augmentors, self._spec_rates = self._parse_pipeline_from( self._spec_augmentors, self._spec_rates = self._parse_pipeline_from(
augmentation_config, 'feature') augmentation_config, 'feature')
def transform_audio(self, audio_segment): def transform_audio(self, audio_segment, single=True):
"""Run the pre-processing pipeline for data augmentation. """Run the pre-processing pipeline for data augmentation.
Note that this is an in-place transformation. Note that this is an in-place transformation.
...@@ -103,9 +103,9 @@ class AugmentationPipeline(): ...@@ -103,9 +103,9 @@ class AugmentationPipeline():
""" """
for augmentor, rate in zip(self._augmentors, self._rates): for augmentor, rate in zip(self._augmentors, self._rates):
if self._rng.uniform(0., 1.) < rate: if self._rng.uniform(0., 1.) < rate:
augmentor.transform_audio(audio_segment) augmentor.transform_audio(audio_segment, single)
def transform_feature(self, spec_segment): def transform_feature(self, spec_segment, single=True):
"""spectrogram augmentation. """spectrogram augmentation.
Args: Args:
...@@ -113,7 +113,7 @@ class AugmentationPipeline(): ...@@ -113,7 +113,7 @@ class AugmentationPipeline():
""" """
for augmentor, rate in zip(self._spec_augmentors, self._spec_rates): for augmentor, rate in zip(self._spec_augmentors, self._spec_rates):
if self._rng.uniform(0., 1.) < rate: if self._rng.uniform(0., 1.) < rate:
spec_segment = augmentor.transform_feature(spec_segment) spec_segment = augmentor.transform_feature(spec_segment, single)
return spec_segment return spec_segment
def _parse_pipeline_from(self, config_json, aug_type='audio'): def _parse_pipeline_from(self, config_json, aug_type='audio'):
......
...@@ -31,7 +31,13 @@ class ShiftPerturbAugmentor(AugmentorBase): ...@@ -31,7 +31,13 @@ class ShiftPerturbAugmentor(AugmentorBase):
self._max_shift_ms = max_shift_ms self._max_shift_ms = max_shift_ms
self._rng = rng self._rng = rng
def transform_audio(self, audio_segment): def randomize_parameters(self):
self.shift_ms = self._rng.uniform(self._min_shift_ms, self._max_shift_ms)
def apply(self, audio_segment):
audio_segment.shift(self.shift_ms)
def transform_audio(self, audio_segment, single):
"""Shift audio. """Shift audio.
Note that this is an in-place transformation. Note that this is an in-place transformation.
...@@ -39,5 +45,20 @@ class ShiftPerturbAugmentor(AugmentorBase): ...@@ -39,5 +45,20 @@ class ShiftPerturbAugmentor(AugmentorBase):
:param audio_segment: Audio segment to add effects to. :param audio_segment: Audio segment to add effects to.
:type audio_segment: AudioSegmenet|SpeechSegment :type audio_segment: AudioSegmenet|SpeechSegment
""" """
shift_ms = self._rng.uniform(self._min_shift_ms, self._max_shift_ms) if(single):
audio_segment.shift(shift_ms) self.randomize_parameters()
self.apply(audio_segment)
# def transform_audio(self, audio_segment):
# """Shift audio.
# Note that this is an in-place transformation.
# :param audio_segment: Audio segment to add effects to.
# :type audio_segment: AudioSegmenet|SpeechSegment
# """
# shift_ms = self._rng.uniform(self._min_shift_ms, self._max_shift_ms)
# audio_segment.shift(shift_ms)
...@@ -124,6 +124,54 @@ class SpecAugmentor(AugmentorBase): ...@@ -124,6 +124,54 @@ class SpecAugmentor(AugmentorBase):
def time_warp(xs, W=40): def time_warp(xs, W=40):
raise NotImplementedError raise NotImplementedError
def randomize_parameters(self, xs):
n_bins = xs.shape[0]
n_frames = xs.shape[1]
self.f=[]
self.f_0=[]
self.t=[]
self.t_0=[]
for i in range(0, self.n_freq_masks):
f=int(self._rng.uniform(low=0, high=self.F))
self.f.append(f)
self.f_0.append(int(self._rng.uniform(low=0, high=n_bins - f)))
if self.adaptive_number_ratio > 0:
n_masks = int(n_frames * self.adaptive_number_ratio)
self.n_masks = min(n_masks, self.max_n_time_masks)
else:
self.n_masks = self.n_time_masks
if self.adaptive_size_ratio > 0:
T = self.adaptive_size_ratio * n_frames
else:
T = self.T
for i in range(self.n_masks):
t = int(self._rng.uniform(low=0, high=T))
t = min(t, int(n_frames * self.p))
self.t.append(t)
self.t_0.append(int(self._rng.uniform(low=0, high=n_frames - t)))
def apply(self, xs: np.ndarray):
n_bins = xs.shape[0]
n_frames = xs.shape[1]
for i in range(0, self.n_freq_masks):
f = self.f[i]
f_0 = self.f_0[i]
xs[f_0:f_0 + f, :] = 0
assert f_0 <= f_0 + f
for i in range(self.n_masks):
t = self.t[i]
t_0 = self.t_0[i]
xs[:, t_0:t_0 + t] = 0
assert t_0 <= t_0 + t
return xs
def mask_freq(self, xs, replace_with_zero=False): def mask_freq(self, xs, replace_with_zero=False):
n_bins = xs.shape[0] n_bins = xs.shape[0]
for i in range(0, self.n_freq_masks): for i in range(0, self.n_freq_masks):
...@@ -157,14 +205,26 @@ class SpecAugmentor(AugmentorBase): ...@@ -157,14 +205,26 @@ class SpecAugmentor(AugmentorBase):
self._time_mask = (t_0, t_0 + t) self._time_mask = (t_0, t_0 + t)
return xs return xs
def transform_feature(self, xs: np.ndarray):
def transform_feature(self, xs: np.ndarray, single=True):
""" """
Args: Args:
xs (FloatTensor): `[F, T]` xs (FloatTensor): `[F, T]`
Returns: Returns:
xs (FloatTensor): `[F, T]` xs (FloatTensor): `[F, T]`
""" """
# xs = self.time_warp(xs) if(single):
xs = self.mask_freq(xs) self.randomize_parameters(xs)
xs = self.mask_time(xs) return self.apply(xs)
return xs
# def transform_feature(self, xs: np.ndarray):
# """
# Args:
# xs (FloatTensor): `[F, T]`
# Returns:
# xs (FloatTensor): `[F, T]`
# """
# # xs = self.time_warp(xs)
# xs = self.mask_freq(xs)
# xs = self.mask_time(xs)
# return xs
...@@ -79,7 +79,21 @@ class SpeedPerturbAugmentor(AugmentorBase): ...@@ -79,7 +79,21 @@ class SpeedPerturbAugmentor(AugmentorBase):
self._rates = np.linspace( self._rates = np.linspace(
self._min_rate, self._max_rate, self._num_rates, endpoint=True) self._min_rate, self._max_rate, self._num_rates, endpoint=True)
def transform_audio(self, audio_segment):
def randomize_parameters(self):
if self._num_rates < 0:
self.speed_rate = self._rng.uniform(self._min_rate, self._max_rate)
else:
self.speed_rate = self._rng.choice(self._rates)
def apply(self, audio_segment):
# Skip perturbation in case of identity speed rate
if speed_rate == 1.0:
return
audio_segment.change_speed(speed_rate)
def transform_audio(self, audio_segment,single=True):
"""Sample a new speed rate from the given range and """Sample a new speed rate from the given range and
changes the speed of the given audio clip. changes the speed of the given audio clip.
...@@ -88,13 +102,26 @@ class SpeedPerturbAugmentor(AugmentorBase): ...@@ -88,13 +102,26 @@ class SpeedPerturbAugmentor(AugmentorBase):
:param audio_segment: Audio segment to add effects to. :param audio_segment: Audio segment to add effects to.
:type audio_segment: AudioSegment|SpeechSegment :type audio_segment: AudioSegment|SpeechSegment
""" """
if self._num_rates < 0: if(single):
speed_rate = self._rng.uniform(self._min_rate, self._max_rate) self.randomize_parameters()
else: self.apply(audio_segment)
speed_rate = self._rng.choice(self._rates)
# Skip perturbation in case of identity speed rate # def transform_audio(self, audio_segment):
if speed_rate == 1.0: # """Sample a new speed rate from the given range and
return # changes the speed of the given audio clip.
audio_segment.change_speed(speed_rate) # Note that this is an in-place transformation.
# :param audio_segment: Audio segment to add effects to.
# :type audio_segment: AudioSegment|SpeechSegment
# """
# if self._num_rates < 0:
# speed_rate = self._rng.uniform(self._min_rate, self._max_rate)
# else:
# speed_rate = self._rng.choice(self._rates)
# # Skip perturbation in case of identity speed rate
# if speed_rate == 1.0:
# return
# audio_segment.change_speed(speed_rate)
...@@ -192,7 +192,7 @@ class SpeechCollator(): ...@@ -192,7 +192,7 @@ class SpeechCollator():
return self._local_data.tar2object[tarpath].extractfile( return self._local_data.tar2object[tarpath].extractfile(
self._local_data.tar2info[tarpath][filename]) self._local_data.tar2info[tarpath][filename])
def process_utterance(self, audio_file, transcript): def process_utterance(self, audio_file, transcript, single=True):
"""Load, augment, featurize and normalize for speech data. """Load, augment, featurize and normalize for speech data.
:param audio_file: Filepath or file object of audio file. :param audio_file: Filepath or file object of audio file.
...@@ -214,7 +214,7 @@ class SpeechCollator(): ...@@ -214,7 +214,7 @@ class SpeechCollator():
# audio augment # audio augment
start_time = time.time() start_time = time.time()
self._augmentation_pipeline.transform_audio(speech_segment) self._augmentation_pipeline.transform_audio(speech_segment, single)
audio_aug_time = time.time() - start_time audio_aug_time = time.time() - start_time
#logger.debug(f"audio augmentation time: {audio_aug_time}") #logger.debug(f"audio augmentation time: {audio_aug_time}")
...@@ -228,7 +228,7 @@ class SpeechCollator(): ...@@ -228,7 +228,7 @@ class SpeechCollator():
# specgram augment # specgram augment
start_time = time.time() start_time = time.time()
specgram = self._augmentation_pipeline.transform_feature(specgram) specgram = self._augmentation_pipeline.transform_feature(specgram, single)
feature_aug_time = time.time() - start_time feature_aug_time = time.time() - start_time
#logger.debug(f"audio feature augmentation time: {feature_aug_time}") #logger.debug(f"audio feature augmentation time: {feature_aug_time}")
return specgram, transcript_part return specgram, transcript_part
...@@ -253,8 +253,14 @@ class SpeechCollator(): ...@@ -253,8 +253,14 @@ class SpeechCollator():
texts = [] texts = []
text_lens = [] text_lens = []
utts = [] utts = []
# print('----debug---')
# print(batch)
# print(type(batch))
# print(len(batch))
resample=True
for utt, audio, text in batch: for utt, audio, text in batch:
audio, text = self.process_utterance(audio, text) audio, text = self.process_utterance(audio, text, single=resample)
# resample=False
#utt #utt
utts.append(utt) utts.append(utt)
# audio # audio
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册