diff --git a/deep_speech_2/data_utils/audio.py b/deep_speech_2/data_utils/audio.py index d55fae1efc951bf6025b2a6ba02852b1640fa10f..3891f5b923f6d73c6b87dcb90bede0183b0e081c 100644 --- a/deep_speech_2/data_utils/audio.py +++ b/deep_speech_2/data_utils/audio.py @@ -6,7 +6,7 @@ from __future__ import print_function import numpy as np import io import soundfile -import scikits.samplerate +import resampy from scipy import signal import random import copy @@ -308,7 +308,7 @@ class AudioSegment(object): prior_mean_squared = 10.**(prior_db / 10.) prior_sum_of_squares = prior_mean_squared * prior_samples cumsum_of_squares = np.cumsum(self.samples**2) - sample_count = np.arange(len(self.num_samples)) + 1 + sample_count = np.arange(self.num_samples) + 1 if startup_sample_idx > 0: cumsum_of_squares[:startup_sample_idx] = \ cumsum_of_squares[startup_sample_idx] @@ -321,21 +321,19 @@ class AudioSegment(object): gain_db = target_db - rms_estimate_db self.gain_db(gain_db) - def resample(self, target_sample_rate, quality='sinc_medium'): + def resample(self, target_sample_rate, filter='kaiser_best'): """Resample the audio to a target sample rate. Note that this is an in-place transformation. :param target_sample_rate: Target sample rate. :type target_sample_rate: int - :param quality: One of {'sinc_fastest', 'sinc_medium', 'sinc_best'}. - Sets resampling speed/quality tradeoff. - See http://www.mega-nerd.com/SRC/api_misc.html#Converters - :type quality: str + :param filter: The resampling filter to use one of {'kaiser_best', + 'kaiser_fast'}. + :type filter: str """ - resample_ratio = target_sample_rate / self._sample_rate - self._samples = scikits.samplerate.resample( - self._samples, r=resample_ratio, type=quality) + self._samples = resampy.resample( + self.samples, self.sample_rate, target_sample_rate, filter=filter) self._sample_rate = target_sample_rate def pad_silence(self, duration, sides='both'): diff --git a/deep_speech_2/data_utils/augmentor/augmentation.py b/deep_speech_2/data_utils/augmentor/augmentation.py index 0d60bbdb9cdd25b6df9177140576cb2bd6641fac..9dced47314a81f52dc0eafd6e592e240953f291d 100644 --- a/deep_speech_2/data_utils/augmentor/augmentation.py +++ b/deep_speech_2/data_utils/augmentor/augmentation.py @@ -7,6 +7,10 @@ import json import random from data_utils.augmentor.volume_perturb import VolumePerturbAugmentor from data_utils.augmentor.shift_perturb import ShiftPerturbAugmentor +from data_utils.augmentor.speed_perturb import SpeedPerturbAugmentor +from data_utils.augmentor.resample import ResampleAugmentor +from data_utils.augmentor.online_bayesian_normalization import \ + OnlineBayesianNormalizationAugmentor class AugmentationPipeline(object): @@ -79,5 +83,11 @@ class AugmentationPipeline(object): return VolumePerturbAugmentor(self._rng, **params) elif augmentor_type == "shift": return ShiftPerturbAugmentor(self._rng, **params) + elif augmentor_type == "speed": + return SpeedPerturbAugmentor(self._rng, **params) + elif augmentor_type == "resample": + return ResampleAugmentor(self._rng, **params) + elif augmentor_type == "bayesian_normal": + return OnlineBayesianNormalizationAugmentor(self._rng, **params) else: raise ValueError("Unknown augmentor type [%s]." % augmentor_type) diff --git a/deep_speech_2/data_utils/augmentor/online_bayesian_normalization.py b/deep_speech_2/data_utils/augmentor/online_bayesian_normalization.py new file mode 100755 index 0000000000000000000000000000000000000000..e488ac7d67833631919f88b9e660a99b363b90d0 --- /dev/null +++ b/deep_speech_2/data_utils/augmentor/online_bayesian_normalization.py @@ -0,0 +1,48 @@ +"""Contain the online bayesian normalization augmentation model.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from data_utils.augmentor.base import AugmentorBase + + +class OnlineBayesianNormalizationAugmentor(AugmentorBase): + """Augmentation model for adding online bayesian normalization. + + :param rng: Random generator object. + :type rng: random.Random + :param target_db: Target RMS value in decibels. + :type target_db: float + :param prior_db: Prior RMS estimate in decibels. + :type prior_db: float + :param prior_samples: Prior strength in number of samples. + :type prior_samples: int + :param startup_delay: Default 0.0s. If provided, this function will + accrue statistics for the first startup_delay + seconds before applying online normalization. + :type starup_delay: float. + """ + + def __init__(self, + rng, + target_db, + prior_db, + prior_samples, + startup_delay=0.0): + self._target_db = target_db + self._prior_db = prior_db + self._prior_samples = prior_samples + self._rng = rng + self._startup_delay = startup_delay + + def transform_audio(self, audio_segment): + """Normalizes the input audio using the online Bayesian approach. + + Note that this is an in-place transformation. + + :param audio_segment: Audio segment to add effects to. + :type audio_segment: AudioSegment|SpeechSegment + """ + audio_segment.normalize_online_bayesian(self._target_db, self._prior_db, + self._prior_samples, + self._startup_delay) diff --git a/deep_speech_2/data_utils/augmentor/resample.py b/deep_speech_2/data_utils/augmentor/resample.py new file mode 100755 index 0000000000000000000000000000000000000000..8df17f3a869420bca1e4e6c0ae9b4035f7d50d8d --- /dev/null +++ b/deep_speech_2/data_utils/augmentor/resample.py @@ -0,0 +1,33 @@ +"""Contain the resample augmentation model.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from data_utils.augmentor.base import AugmentorBase + + +class ResampleAugmentor(AugmentorBase): + """Augmentation model for resampling. + + See more info here: + https://ccrma.stanford.edu/~jos/resample/index.html + + :param rng: Random generator object. + :type rng: random.Random + :param new_sample_rate: New sample rate in Hz. + :type new_sample_rate: int + """ + + def __init__(self, rng, new_sample_rate): + self._new_sample_rate = new_sample_rate + self._rng = rng + + def transform_audio(self, audio_segment): + """Resamples the input audio to a target sample rate. + + Note that this is an in-place transformation. + + :param audio: Audio segment to add effects to. + :type audio: AudioSegment|SpeechSegment + """ + audio_segment.resample(self._new_sample_rate) diff --git a/deep_speech_2/data_utils/augmentor/speed_perturb.py b/deep_speech_2/data_utils/augmentor/speed_perturb.py new file mode 100644 index 0000000000000000000000000000000000000000..cc5738bd155a5871817039f5ccb3c4707ff87a6c --- /dev/null +++ b/deep_speech_2/data_utils/augmentor/speed_perturb.py @@ -0,0 +1,47 @@ +"""Contain the speech perturbation augmentation model.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from data_utils.augmentor.base import AugmentorBase + + +class SpeedPerturbAugmentor(AugmentorBase): + """Augmentation model for adding speed perturbation. + + See reference paper here: + http://www.danielpovey.com/files/2015_interspeech_augmentation.pdf + + :param rng: Random generator object. + :type rng: random.Random + :param min_speed_rate: Lower bound of new speed rate to sample and should + not be smaller than 0.9. + :type min_speed_rate: float + :param max_speed_rate: Upper bound of new speed rate to sample and should + not be larger than 1.1. + :type max_speed_rate: float + """ + + def __init__(self, rng, min_speed_rate, max_speed_rate): + if min_speed_rate < 0.9: + raise ValueError( + "Sampling speed below 0.9 can cause unnatural effects") + if max_speed_rate > 1.1: + raise ValueError( + "Sampling speed above 1.1 can cause unnatural effects") + self._min_speed_rate = min_speed_rate + self._max_speed_rate = max_speed_rate + self._rng = rng + + def transform_audio(self, audio_segment): + """Sample a new speed rate from the given range and + changes the speed of the given audio clip. + + Note that this is an in-place transformation. + + :param audio_segment: Audio segment to add effects to. + :type audio_segment: AudioSegment|SpeechSegment + """ + sampled_speed = self._rng.uniform(self._min_speed_rate, + self._max_speed_rate) + audio_segment.change_speed(sampled_speed) diff --git a/deep_speech_2/data_utils/augmentor/volume_perturb.py b/deep_speech_2/data_utils/augmentor/volume_perturb.py index 62631fb041c45350811b2cd2dd78d6758a622db8..758676d558d8e4d77191504d0d7b75cefe020549 100644 --- a/deep_speech_2/data_utils/augmentor/volume_perturb.py +++ b/deep_speech_2/data_utils/augmentor/volume_perturb.py @@ -37,4 +37,4 @@ class VolumePerturbAugmentor(AugmentorBase): :type audio_segment: AudioSegmenet|SpeechSegment """ gain = self._rng.uniform(self._min_gain_dBFS, self._max_gain_dBFS) - audio_segment.apply_gain(gain) + audio_segment.gain_db(gain) diff --git a/deep_speech_2/requirements.txt b/deep_speech_2/requirements.txt old mode 100644 new mode 100755 index 79272e7e6c898226ecb11e80cf9f97efd7a2d6f2..967b4f8c3148c62cd5b7a511567848af6c5c8f93 --- a/deep_speech_2/requirements.txt +++ b/deep_speech_2/requirements.txt @@ -1,2 +1,3 @@ wget==3.2 scipy==0.13.1 +resampy==0.1.5 \ No newline at end of file