diff --git a/deepspeech/frontend/augmentor/augmentation.py b/deepspeech/frontend/augmentor/augmentation.py index cc0564daf6745c7bc353ae33fd6bd32b34fc9cba..7998bfbd74aa6c462fee188e66b823d74b4331b4 100644 --- a/deepspeech/frontend/augmentor/augmentation.py +++ b/deepspeech/frontend/augmentor/augmentation.py @@ -13,18 +13,29 @@ # limitations under the License. """Contains the data augmentation pipeline.""" import json +from pprint import pformat +from collections.abc import Sequence +from inspect import signature import numpy as np -from deepspeech.frontend.augmentor.impulse_response import ImpulseResponseAugmentor -from deepspeech.frontend.augmentor.noise_perturb import NoisePerturbAugmentor -from deepspeech.frontend.augmentor.online_bayesian_normalization import \ - OnlineBayesianNormalizationAugmentor -from deepspeech.frontend.augmentor.resample import ResampleAugmentor -from deepspeech.frontend.augmentor.shift_perturb import ShiftPerturbAugmentor -from deepspeech.frontend.augmentor.spec_augment import SpecAugmentor -from deepspeech.frontend.augmentor.speed_perturb import SpeedPerturbAugmentor -from deepspeech.frontend.augmentor.volume_perturb import VolumePerturbAugmentor +from deepspeech.frontend.augmentor.base import AugmentorBase +from deepspeech.utils.dynamic_import import dynamic_import +from deepspeech.utils.log import Log + +logger = Log(__name__).getlog() + +__all__ = ["AugmentationPipeline"] + +import_alias = dict( + volume="deepspeech.frontend.augmentor.impulse_response:VolumePerturbAugmentor", + shift="deepspeech.frontend.augmentor.shift_perturb:ShiftPerturbAugmentor", + speed="deepspeech.frontend.augmentor.speed_perturb:SpeedPerturbAugmentor", + resample="deepspeech.frontend.augmentor.resample:ResampleAugmentor", + bayesian_normal="deepspeech.frontend.augmentor.online_bayesian_normalization:OnlineBayesianNormalizationAugmentor", + noise="deepspeech.frontend.augmentor.noise_perturb:NoisePerturbAugmentor", + impulse="deepspeech.frontend.augmentor.impulse_response:ImpulseResponseAugmentor", + specaug="deepspeech.frontend.augmentor.spec_augment:SpecAugmentor", ) class AugmentationPipeline(): @@ -78,20 +89,75 @@ class AugmentationPipeline(): augmentor to take effect. If "prob" is zero, the augmentor does not take effect. - :param augmentation_config: Augmentation configuration in json string. - :type augmentation_config: str - :param random_seed: Random seed. - :type random_seed: int - :raises ValueError: If the augmentation json config is in incorrect format". + Params: + augmentation_config(str): Augmentation configuration in json string. + random_seed(int): Random seed. + train(bool): whether is train mode. + + Raises: + ValueError: If the augmentation json config is in incorrect format". """ - def __init__(self, augmentation_config: str, random_seed=0): + SPEC_TYPES = {'specaug'} + + def __init__(self, augmentation_config: str, random_seed: int=0): self._rng = np.random.RandomState(random_seed) - self._spec_types = ('specaug') - self._augmentors, self._rates = self._parse_pipeline_from( - augmentation_config, 'audio') + self.conf = {'mode': 'sequential', 'process': []} + if augmentation_config: + process = json.loads(augmentation_config) + self.conf['process'] += process + + self._augmentors, self._rates = self._parse_pipeline_from('all') + self._audio_augmentors, self._audio_rates = self._parse_pipeline_from( + 'audio') self._spec_augmentors, self._spec_rates = self._parse_pipeline_from( - augmentation_config, 'feature') + 'feature') + logger.info(f"Augmentation: {pformat(list(zip(self._augmentors, self._rates)))}") + + def __call__(self, xs, uttid_list=None, **kwargs): + if not isinstance(xs, Sequence): + is_batch = False + xs = [xs] + else: + is_batch = True + + if isinstance(uttid_list, str): + uttid_list = [uttid_list for _ in range(len(xs))] + + if self.conf.get("mode", "sequential") == "sequential": + for idx, (func, rate) in enumerate( + zip(self._augmentors, self._rates), 0): + if self._rng.uniform(0., 1.) >= rate: + continue + + # Derive only the args which the func has + try: + param = signature(func).parameters + except ValueError: + # Some function, e.g. built-in function, are failed + param = {} + _kwargs = {k: v for k, v in kwargs.items() if k in param} + + try: + if uttid_list is not None and "uttid" in param: + xs = [ + func(x, u, **_kwargs) + for x, u in zip(xs, uttid_list) + ] + else: + xs = [func(x, **_kwargs) for x in xs] + except Exception: + logger.fatal("Catch a exception from {}th func: {}".format( + idx, func)) + raise + else: + raise NotImplementedError( + "Not supporting mode={}".format(self.conf["mode"])) + + if is_batch: + return xs + else: + return xs[0] def transform_audio(self, audio_segment): """Run the pre-processing pipeline for data augmentation. @@ -101,7 +167,7 @@ class AugmentationPipeline(): :param audio_segment: Audio segment to process. :type audio_segment: AudioSegmenet|SpeechSegment """ - for augmentor, rate in zip(self._augmentors, self._rates): + for augmentor, rate in zip(self._audio_augmentors, self._audio_rates): if self._rng.uniform(0., 1.) < rate: augmentor.transform_audio(audio_segment) @@ -116,52 +182,41 @@ class AugmentationPipeline(): spec_segment = augmentor.transform_feature(spec_segment) return spec_segment - def _parse_pipeline_from(self, config_json, aug_type='audio'): + def _parse_pipeline_from(self, aug_type='all'): """Parse the config json to build a augmentation pipelien.""" - assert aug_type in ('audio', 'feature'), aug_type - try: - configs = json.loads(config_json) - audio_confs = [] - feature_confs = [] - for config in configs: - if config["type"] in self._spec_types: - feature_confs.append(config) - else: - audio_confs.append(config) - - if aug_type == 'audio': - aug_confs = audio_confs - elif aug_type == 'feature': - aug_confs = feature_confs - - augmentors = [ - self._get_augmentor(config["type"], config["params"]) - for config in aug_confs - ] - rates = [config["prob"] for config in aug_confs] - - except Exception as e: - raise ValueError("Failed to parse the augmentation config json: " - "%s" % str(e)) + assert aug_type in ('audio', 'feature', 'all'), aug_type + audio_confs = [] + feature_confs = [] + all_confs = [] + for config in self.conf['process']: + all_confs.append(config) + if config["type"] in self.SPEC_TYPES: + feature_confs.append(config) + else: + audio_confs.append(config) + + if aug_type == 'audio': + aug_confs = audio_confs + elif aug_type == 'feature': + aug_confs = feature_confs + elif aug_type == 'all': + aug_confs = all_confs + else: + raise ValueError(f"Not support: {aug_type}") + + augmentors = [ + self._get_augmentor(config["type"], config["params"]) + for config in aug_confs + ] + rates = [config["prob"] for config in aug_confs] return augmentors, rates def _get_augmentor(self, augmentor_type, params): """Return an augmentation model by the type name, and pass in params.""" - if augmentor_type == "volume": - return VolumePerturbAugmentor(self._rng, **params) - elif augmentor_type == "shift": - return ShiftPerturbAugmentor(self._rng, **params) - elif augmentor_type == "speed": - return SpeedPerturbAugmentor(self._rng, **params) - elif augmentor_type == "resample": - return ResampleAugmentor(self._rng, **params) - elif augmentor_type == "bayesian_normal": - return OnlineBayesianNormalizationAugmentor(self._rng, **params) - elif augmentor_type == "noise": - return NoisePerturbAugmentor(self._rng, **params) - elif augmentor_type == "impulse": - return ImpulseResponseAugmentor(self._rng, **params) - elif augmentor_type == "specaug": - return SpecAugmentor(self._rng, **params) - else: + class_obj = dynamic_import(augmentor_type, import_alias) + assert issubclass(class_obj, AugmentorBase) + try: + obj = class_obj(self._rng, **params) + except Exception: raise ValueError("Unknown augmentor type [%s]." % augmentor_type) + return obj diff --git a/deepspeech/frontend/augmentor/base.py b/deepspeech/frontend/augmentor/base.py index e6f5c1e9f4c39e6964449912b4d322cfa51465d4..18d003c0b125c76a6016e830227d4ee3f5ddc19e 100644 --- a/deepspeech/frontend/augmentor/base.py +++ b/deepspeech/frontend/augmentor/base.py @@ -28,6 +28,10 @@ class AugmentorBase(): def __init__(self): pass + @abstractmethod + def __call__(self, xs): + raise NotImplementedError("AugmentorBase: Not impl __call__") + @abstractmethod def transform_audio(self, audio_segment): """Adds various effects to the input audio segment. Such effects @@ -40,7 +44,7 @@ class AugmentorBase(): :param audio_segment: Audio segment to add effects to. :type audio_segment: AudioSegmenet|SpeechSegment """ - raise NotImplementedError + raise NotImplementedError("AugmentorBase: Not impl transform_audio") @abstractmethod def transform_feature(self, spec_segment): @@ -52,4 +56,4 @@ class AugmentorBase(): Args: spec_segment (Spectrogram): Spectrogram segment to add effects to. """ - raise NotImplementedError + raise NotImplementedError("AugmentorBase: Not impl transform_feature") diff --git a/deepspeech/frontend/augmentor/impulse_response.py b/deepspeech/frontend/augmentor/impulse_response.py index fbd617b42e01d823456fa7568ba54e3be2dd2c09..818251ed8c82dfc21547fea06ae21ee05c7c8d38 100644 --- a/deepspeech/frontend/augmentor/impulse_response.py +++ b/deepspeech/frontend/augmentor/impulse_response.py @@ -30,6 +30,12 @@ class ImpulseResponseAugmentor(AugmentorBase): self._rng = rng self._impulse_manifest = read_manifest(impulse_manifest_path) + def __call__(self, x, uttid=None, train=True): + if not train: + return x + self.transform_audio(x) + return x + def transform_audio(self, audio_segment): """Add impulse response effect. diff --git a/deepspeech/frontend/augmentor/noise_perturb.py b/deepspeech/frontend/augmentor/noise_perturb.py index b3c07f5c1c0c84f3ffe7eaaf7d377cab6cc481c4..790b0c39682933c1feb2c6fab90ea0c2e8d189c6 100644 --- a/deepspeech/frontend/augmentor/noise_perturb.py +++ b/deepspeech/frontend/augmentor/noise_perturb.py @@ -36,6 +36,12 @@ class NoisePerturbAugmentor(AugmentorBase): self._rng = rng self._noise_manifest = read_manifest(manifest_path=noise_manifest_path) + def __call__(self, x, uttid=None, train=True): + if not train: + return x + self.transform_audio(x) + return x + def transform_audio(self, audio_segment): """Add background noise audio. diff --git a/deepspeech/frontend/augmentor/online_bayesian_normalization.py b/deepspeech/frontend/augmentor/online_bayesian_normalization.py index 5af3b9b03eb9e3805e6ac303c66f7949a32f87b6..0f9d3ef6fbfba1b4c6895996d180ce33d0c18891 100644 --- a/deepspeech/frontend/augmentor/online_bayesian_normalization.py +++ b/deepspeech/frontend/augmentor/online_bayesian_normalization.py @@ -44,6 +44,12 @@ class OnlineBayesianNormalizationAugmentor(AugmentorBase): self._rng = rng self._startup_delay = startup_delay + def __call__(self, x, uttid=None, train=True): + if not train: + return x + self.transform_audio(x) + return x + def transform_audio(self, audio_segment): """Normalizes the input audio using the online Bayesian approach. diff --git a/deepspeech/frontend/augmentor/resample.py b/deepspeech/frontend/augmentor/resample.py index 9afce635d00f7821ff9cd7cfd1cb3fe31409853d..509fe003df11503d88403bfd0afd870665a87397 100644 --- a/deepspeech/frontend/augmentor/resample.py +++ b/deepspeech/frontend/augmentor/resample.py @@ -31,6 +31,12 @@ class ResampleAugmentor(AugmentorBase): self._new_sample_rate = new_sample_rate self._rng = rng + def __call__(self, x, uttid=None, train=True): + if not train: + return x + self.transform_audio(x) + return x + def transform_audio(self, audio_segment): """Resamples the input audio to a target sample rate. diff --git a/deepspeech/frontend/augmentor/shift_perturb.py b/deepspeech/frontend/augmentor/shift_perturb.py index 9cc3fe2d0df6e2ce66fe5d40b746a997d2daf2be..8b7439fe58aa6626d4b7873bb1511f1c59a8dd16 100644 --- a/deepspeech/frontend/augmentor/shift_perturb.py +++ b/deepspeech/frontend/augmentor/shift_perturb.py @@ -31,6 +31,12 @@ class ShiftPerturbAugmentor(AugmentorBase): self._max_shift_ms = max_shift_ms self._rng = rng + def __call__(self, x, uttid=None, train=True): + if not train: + return x + self.transform_audio(x) + return x + def transform_audio(self, audio_segment): """Shift audio. diff --git a/deepspeech/frontend/augmentor/spec_augment.py b/deepspeech/frontend/augmentor/spec_augment.py index 956975c6b1375e28f0a184c391d52ae22e064fa8..26c94d41639ce280309ba44b039f49dd15f0a862 100644 --- a/deepspeech/frontend/augmentor/spec_augment.py +++ b/deepspeech/frontend/augmentor/spec_augment.py @@ -12,7 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. """Contains the volume perturb augmentation model.""" +import random + import numpy as np +from PIL import Image +from PIL.Image import BICUBIC from deepspeech.frontend.augmentor.base import AugmentorBase from deepspeech.utils.log import Log @@ -25,10 +29,10 @@ class SpecAugmentor(AugmentorBase): SpecAugment: A Simple Data Augmentation Method for Automatic Speech Recognition https://arxiv.org/abs/1904.08779 - + SpecAugment on Large Scale Datasets https://arxiv.org/abs/1912.05533 - + """ def __init__(self, @@ -42,7 +46,8 @@ class SpecAugmentor(AugmentorBase): adaptive_number_ratio=0, adaptive_size_ratio=0, max_n_time_masks=20, - **kwargs): + replace_with_zero=True, + warp_mode='PIL'): """SpecAugment class. Args: rng (random.Random): random generator object. @@ -55,17 +60,22 @@ class SpecAugmentor(AugmentorBase): adaptive_number_ratio (float): adaptive multiplicity ratio for time masking adaptive_size_ratio (float): adaptive size ratio for time masking max_n_time_masks (int): maximum number of time masking + replace_with_zero (bool): pad zero on mask if true else use mean + warp_mode (str): "PIL" (default, fast, not differentiable) + or "sparse_image_warp" (slow, differentiable) """ super().__init__() self._rng = rng + self.inplace = True + self.replace_with_zero = replace_with_zero + self.mode = warp_mode self.W = W self.F = F self.T = T self.n_freq_masks = n_freq_masks self.n_time_masks = n_time_masks self.p = p - #logger.info(f"specaug: F-{F}, T-{T}, F-n-{n_freq_masks}, T-n-{n_time_masks}") # adaptive SpecAugment self.adaptive_number_ratio = adaptive_number_ratio @@ -122,21 +132,86 @@ class SpecAugmentor(AugmentorBase): def time_mask(self): return self._time_mask - def time_warp(self, xs, W=40): - raise NotImplementedError + def __repr__(self): + return f"specaug: F-{F}, T-{T}, F-n-{n_freq_masks}, T-n-{n_time_masks}" + + def time_warp(self, x, mode='PIL'): + """time warp for spec augment + move random center frame by the random width ~ uniform(-window, window) + + Args: + x (np.ndarray): spectrogram (time, freq) + mode (str): PIL or sparse_image_warp + + Raises: + NotImplementedError: [description] + NotImplementedError: [description] + + Returns: + np.ndarray: time warped spectrogram (time, freq) + """ + window = max_time_warp = self.W + if window == 0: + return x + + if mode == "PIL": + t = x.shape[0] + if t - window <= window: + return x + # NOTE: randrange(a, b) emits a, a + 1, ..., b - 1 + center = random.randrange(window, t - window) + warped = random.randrange(center - window, center + + window) + 1 # 1 ... t - 1 + + left = Image.fromarray(x[:center]).resize((x.shape[1], warped), + BICUBIC) + right = Image.fromarray(x[center:]).resize((x.shape[1], t - warped), + BICUBIC) + if self.inplace: + x[:warped] = left + x[warped:] = right + return x + return np.concatenate((left, right), 0) + elif mode == "sparse_image_warp": + raise NotImplementedError('sparse_image_warp') + else: + raise NotImplementedError( + "unknown resize mode: " + mode + + ", choose one from (PIL, sparse_image_warp).") + + def mask_freq(self, x, replace_with_zero=False): + """freq mask + + Args: + x (np.ndarray): spectrogram (time, freq) + replace_with_zero (bool, optional): Defaults to False. - def mask_freq(self, xs, replace_with_zero=False): - n_bins = xs.shape[0] + Returns: + np.ndarray: freq mask spectrogram (time, freq) + """ + n_bins = x.shape[1] for i in range(0, self.n_freq_masks): f = int(self._rng.uniform(low=0, high=self.F)) f_0 = int(self._rng.uniform(low=0, high=n_bins - f)) - xs[f_0:f_0 + f, :] = 0 assert f_0 <= f_0 + f + if replace_with_zero: + x[:, f_0:f_0 + f] = 0 + else: + x[:, f_0:f_0 + f] = x.mean() self._freq_mask = (f_0, f_0 + f) - return xs + return x - def mask_time(self, xs, replace_with_zero=False): - n_frames = xs.shape[1] + def mask_time(self, x, replace_with_zero=False): + """time mask + + Args: + x (np.ndarray): spectrogram (time, freq) + replace_with_zero (bool, optional): Defaults to False. + + Returns: + np.ndarray: time mask spectrogram (time, freq) + """ + n_frames = x.shape[0] if self.adaptive_number_ratio > 0: n_masks = int(n_frames * self.adaptive_number_ratio) @@ -153,19 +228,29 @@ class SpecAugmentor(AugmentorBase): t = int(self._rng.uniform(low=0, high=T)) t = min(t, int(n_frames * self.p)) t_0 = int(self._rng.uniform(low=0, high=n_frames - t)) - xs[:, t_0:t_0 + t] = 0 assert t_0 <= t_0 + t + if replace_with_zero: + x[t_0:t_0 + t, :] = 0 + else: + x[t_0:t_0 + t, :] = x.mean() self._time_mask = (t_0, t_0 + t) - return xs + return x + + def __call__(self, x, train=True): + if not train: + return x + return self.transform_feature(x) - def transform_feature(self, xs: np.ndarray): + def transform_feature(self, x: np.ndarray): """ Args: - xs (FloatTensor): `[F, T]` + x (np.ndarray): `[T, F]` Returns: - xs (FloatTensor): `[F, T]` + x (np.ndarray): `[T, F]` """ - # xs = self.time_warp(xs) - xs = self.mask_freq(xs) - xs = self.mask_time(xs) - return xs + assert isinstance(x, np.ndarray) + assert x.ndim == 2 + x = self.time_warp(x, self.mode) + x = self.mask_freq(x, self.replace_with_zero) + x = self.mask_time(x, self.replace_with_zero) + return x diff --git a/deepspeech/frontend/augmentor/speed_perturb.py b/deepspeech/frontend/augmentor/speed_perturb.py index d0977c13197109a792bbfd8f294214627b220610..ce8dfde0a674f39459bec31169aeb614e842052b 100644 --- a/deepspeech/frontend/augmentor/speed_perturb.py +++ b/deepspeech/frontend/augmentor/speed_perturb.py @@ -79,6 +79,12 @@ class SpeedPerturbAugmentor(AugmentorBase): self._rates = np.linspace( self._min_rate, self._max_rate, self._num_rates, endpoint=True) + def __call__(self, x, uttid=None, train=True): + if not train: + return x + self.transform_audio(x) + return x + def transform_audio(self, audio_segment): """Sample a new speed rate from the given range and changes the speed of the given audio clip. diff --git a/deepspeech/frontend/augmentor/volume_perturb.py b/deepspeech/frontend/augmentor/volume_perturb.py index 0d76e7a054a3eb6a4fa6016b789a9a9e4d020fe7..70cb2889706c355048db1df5e18f1bb155c3ffd1 100644 --- a/deepspeech/frontend/augmentor/volume_perturb.py +++ b/deepspeech/frontend/augmentor/volume_perturb.py @@ -37,6 +37,12 @@ class VolumePerturbAugmentor(AugmentorBase): self._max_gain_dBFS = max_gain_dBFS self._rng = rng + def __call__(self, x, uttid=None, train=True): + if not train: + return x + self.transform_audio(x) + return x + def transform_audio(self, audio_segment): """Change audio loadness. diff --git a/examples/librispeech/s1/conf/augmentation.json b/examples/librispeech/s1/conf/augmentation.json index c1078393d2f2f57fbcb3b48ce0975c2612c39dcb..ac35d6e43996e482041b439af8bf196438b47c9e 100644 --- a/examples/librispeech/s1/conf/augmentation.json +++ b/examples/librispeech/s1/conf/augmentation.json @@ -19,15 +19,17 @@ { "type": "specaug", "params": { + "W": 0, + "warp_mode": "PIL", "F": 10, "T": 50, "n_freq_masks": 2, "n_time_masks": 2, "p": 1.0, - "W": 80, "adaptive_number_ratio": 0, "adaptive_size_ratio": 0, - "max_n_time_masks": 20 + "max_n_time_masks": 20, + "replace_with_zero": true }, "prob": 1.0 }