"""Contains the data augmentation pipeline."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json
import random
from data_utils.augmentor.volume_perturb import VolumePerturbAugmentor
from data_utils.augmentor.speed_perturb import SpeedPerturbAugmentor
from data_utils.augmentor.resample import ResampleAugmentor
from data_utils.augmentor.online_bayesian_normalization import OnlineBayesianNormalizationAugmentor


class AugmentationPipeline(object):
    """Build a pre-processing pipeline with various augmentation models.

    Such a data augmentation pipeline is often leveraged to augment the
    training samples to make the model invariant to certain types of
    perturbations in the real world, improving the model's generalization
    ability.

    The pipeline is built according to the augmentation configuration in a
    json string, e.g.

    .. code-block::

        '[{"type": "volume",
           "params": {"min_gain_dBFS": -15,
                      "max_gain_dBFS": 15},
           "prob": 0.5},
          {"type": "speed",
           "params": {"min_speed_rate": 0.8,
                      "max_speed_rate": 1.2},
           "prob": 0.5}
         ]'

    This augmentation configuration inserts two augmentation models into the
    pipeline, one being a VolumePerturbAugmentor and the other a
    SpeedPerturbAugmentor. "prob" indicates the probability of the current
    augmentor taking effect.

    Supported "type" values are "volume", "speed", "resample" and
    "bayesian_normal". Every entry must provide the "type", "params" and
    "prob" keys.

    :param augmentation_config: Augmentation configuration in json string.
    :type augmentation_config: str
    :param random_seed: Random seed.
    :type random_seed: int
    :raises ValueError: If the augmentation json config is in incorrect format.
    """

    def __init__(self, augmentation_config, random_seed=0):
        # One private RNG drives both the "should this augmentor fire?"
        # decision and the augmentors themselves (it is handed to each of
        # them on construction).
        self._rng = random.Random(random_seed)
        self._augmentors, self._rates = self._parse_pipeline_from(
            augmentation_config)

    def transform_audio(self, audio_segment):
        """Run the pre-processing pipeline for data augmentation.

        Note that this is an in-place transformation.

        :param audio_segment: Audio segment to process.
        :type audio_segment: AudioSegment|SpeechSegment
        """
        for augmentor, rate in zip(self._augmentors, self._rates):
            # random() draws from [0.0, 1.0), so a strict comparison never
            # fires for rate 0.0 and always fires for rate 1.0. Exactly one
            # draw is consumed per augmentor, as with uniform(0., 1.).
            if self._rng.random() < rate:
                augmentor.transform_audio(audio_segment)

    def _parse_pipeline_from(self, config_json):
        """Parse the config json to build an augmentation pipeline.

        :param config_json: Augmentation configuration in json string.
        :type config_json: str
        :return: The augmentor instances and their matching probabilities,
                 in configuration order.
        :rtype: tuple
        :raises ValueError: If the json cannot be decoded, an entry misses a
                            required key, or an augmentor cannot be built.
        """
        try:
            configs = json.loads(config_json)
            augmentors = [
                self._get_augmentor(config["type"], config["params"])
                for config in configs
            ]
            rates = [config["prob"] for config in configs]
        except Exception as e:
            # Any failure (bad json, missing key, bad constructor params,
            # unknown type) is reported uniformly as a config error.
            raise ValueError("Failed to parse the augmentation config json: "
                             "%s" % str(e))
        return augmentors, rates

    def _get_augmentor(self, augmentor_type, params):
        """Return an augmentation model by the type name, and pass in params.

        :param augmentor_type: One of "volume", "speed", "resample" or
                               "bayesian_normal".
        :type augmentor_type: str
        :param params: Keyword arguments forwarded to the augmentor
                       constructor.
        :type params: dict
        :raises ValueError: If the augmentor type is unknown.
        """
        if augmentor_type == "volume":
            return VolumePerturbAugmentor(self._rng, **params)
        elif augmentor_type == "speed":
            return SpeedPerturbAugmentor(self._rng, **params)
        elif augmentor_type == "resample":
            return ResampleAugmentor(self._rng, **params)
        elif augmentor_type == "bayesian_normal":
            return OnlineBayesianNormalizationAugmentor(self._rng, **params)
        raise ValueError("Unknown augmentor type [%s]." % augmentor_type)