提交 0c9fbaf7 编写于 作者: H Hui Zhang

refactor augmentor

上级 756be8fb
......@@ -13,18 +13,29 @@
# limitations under the License.
"""Contains the data augmentation pipeline."""
import json
from pprint import pformat
from collections.abc import Sequence
from inspect import signature
import numpy as np
from deepspeech.frontend.augmentor.impulse_response import ImpulseResponseAugmentor
from deepspeech.frontend.augmentor.noise_perturb import NoisePerturbAugmentor
from deepspeech.frontend.augmentor.online_bayesian_normalization import \
OnlineBayesianNormalizationAugmentor
from deepspeech.frontend.augmentor.resample import ResampleAugmentor
from deepspeech.frontend.augmentor.shift_perturb import ShiftPerturbAugmentor
from deepspeech.frontend.augmentor.spec_augment import SpecAugmentor
from deepspeech.frontend.augmentor.speed_perturb import SpeedPerturbAugmentor
from deepspeech.frontend.augmentor.volume_perturb import VolumePerturbAugmentor
from deepspeech.frontend.augmentor.base import AugmentorBase
from deepspeech.utils.dynamic_import import dynamic_import
from deepspeech.utils.log import Log
logger = Log(__name__).getlog()
__all__ = ["AugmentationPipeline"]
import_alias = dict(
volume="deepspeech.frontend.augmentor.impulse_response:VolumePerturbAugmentor",
shift="deepspeech.frontend.augmentor.shift_perturb:ShiftPerturbAugmentor",
speed="deepspeech.frontend.augmentor.speed_perturb:SpeedPerturbAugmentor",
resample="deepspeech.frontend.augmentor.resample:ResampleAugmentor",
bayesian_normal="deepspeech.frontend.augmentor.online_bayesian_normalization:OnlineBayesianNormalizationAugmentor",
noise="deepspeech.frontend.augmentor.noise_perturb:NoisePerturbAugmentor",
impulse="deepspeech.frontend.augmentor.impulse_response:ImpulseResponseAugmentor",
specaug="deepspeech.frontend.augmentor.spec_augment:SpecAugmentor", )
class AugmentationPipeline():
......@@ -78,20 +89,75 @@ class AugmentationPipeline():
augmentor to take effect. If "prob" is zero, the augmentor does not take
effect.
:param augmentation_config: Augmentation configuration in json string.
:type augmentation_config: str
:param random_seed: Random seed.
:type random_seed: int
:raises ValueError: If the augmentation json config is in incorrect format".
Params:
augmentation_config(str): Augmentation configuration in json string.
random_seed(int): Random seed.
train(bool): whether is train mode.
Raises:
ValueError: If the augmentation json config is in incorrect format".
"""
def __init__(self, augmentation_config: str, random_seed=0):
SPEC_TYPES = {'specaug'}
def __init__(self, augmentation_config: str, random_seed: int=0):
self._rng = np.random.RandomState(random_seed)
self._spec_types = ('specaug')
self._augmentors, self._rates = self._parse_pipeline_from(
augmentation_config, 'audio')
self.conf = {'mode': 'sequential', 'process': []}
if augmentation_config:
process = json.loads(augmentation_config)
self.conf['process'] += process
self._augmentors, self._rates = self._parse_pipeline_from('all')
self._audio_augmentors, self._audio_rates = self._parse_pipeline_from(
'audio')
self._spec_augmentors, self._spec_rates = self._parse_pipeline_from(
augmentation_config, 'feature')
'feature')
logger.info(f"Augmentation: {pformat(list(zip(self._augmentors, self._rates)))}")
def __call__(self, xs, uttid_list=None, **kwargs):
if not isinstance(xs, Sequence):
is_batch = False
xs = [xs]
else:
is_batch = True
if isinstance(uttid_list, str):
uttid_list = [uttid_list for _ in range(len(xs))]
if self.conf.get("mode", "sequential") == "sequential":
for idx, (func, rate) in enumerate(
zip(self._augmentors, self._rates), 0):
if self._rng.uniform(0., 1.) >= rate:
continue
# Derive only the args which the func has
try:
param = signature(func).parameters
except ValueError:
# Some function, e.g. built-in function, are failed
param = {}
_kwargs = {k: v for k, v in kwargs.items() if k in param}
try:
if uttid_list is not None and "uttid" in param:
xs = [
func(x, u, **_kwargs)
for x, u in zip(xs, uttid_list)
]
else:
xs = [func(x, **_kwargs) for x in xs]
except Exception:
logger.fatal("Catch a exception from {}th func: {}".format(
idx, func))
raise
else:
raise NotImplementedError(
"Not supporting mode={}".format(self.conf["mode"]))
if is_batch:
return xs
else:
return xs[0]
def transform_audio(self, audio_segment):
"""Run the pre-processing pipeline for data augmentation.
......@@ -101,7 +167,7 @@ class AugmentationPipeline():
:param audio_segment: Audio segment to process.
:type audio_segment: AudioSegmenet|SpeechSegment
"""
for augmentor, rate in zip(self._augmentors, self._rates):
for augmentor, rate in zip(self._audio_augmentors, self._audio_rates):
if self._rng.uniform(0., 1.) < rate:
augmentor.transform_audio(audio_segment)
......@@ -116,52 +182,41 @@ class AugmentationPipeline():
spec_segment = augmentor.transform_feature(spec_segment)
return spec_segment
def _parse_pipeline_from(self, config_json, aug_type='audio'):
def _parse_pipeline_from(self, aug_type='all'):
"""Parse the config json to build a augmentation pipelien."""
assert aug_type in ('audio', 'feature'), aug_type
try:
configs = json.loads(config_json)
audio_confs = []
feature_confs = []
for config in configs:
if config["type"] in self._spec_types:
feature_confs.append(config)
else:
audio_confs.append(config)
if aug_type == 'audio':
aug_confs = audio_confs
elif aug_type == 'feature':
aug_confs = feature_confs
augmentors = [
self._get_augmentor(config["type"], config["params"])
for config in aug_confs
]
rates = [config["prob"] for config in aug_confs]
except Exception as e:
raise ValueError("Failed to parse the augmentation config json: "
"%s" % str(e))
assert aug_type in ('audio', 'feature', 'all'), aug_type
audio_confs = []
feature_confs = []
all_confs = []
for config in self.conf['process']:
all_confs.append(config)
if config["type"] in self.SPEC_TYPES:
feature_confs.append(config)
else:
audio_confs.append(config)
if aug_type == 'audio':
aug_confs = audio_confs
elif aug_type == 'feature':
aug_confs = feature_confs
elif aug_type == 'all':
aug_confs = all_confs
else:
raise ValueError(f"Not support: {aug_type}")
augmentors = [
self._get_augmentor(config["type"], config["params"])
for config in aug_confs
]
rates = [config["prob"] for config in aug_confs]
return augmentors, rates
def _get_augmentor(self, augmentor_type, params):
"""Return an augmentation model by the type name, and pass in params."""
if augmentor_type == "volume":
return VolumePerturbAugmentor(self._rng, **params)
elif augmentor_type == "shift":
return ShiftPerturbAugmentor(self._rng, **params)
elif augmentor_type == "speed":
return SpeedPerturbAugmentor(self._rng, **params)
elif augmentor_type == "resample":
return ResampleAugmentor(self._rng, **params)
elif augmentor_type == "bayesian_normal":
return OnlineBayesianNormalizationAugmentor(self._rng, **params)
elif augmentor_type == "noise":
return NoisePerturbAugmentor(self._rng, **params)
elif augmentor_type == "impulse":
return ImpulseResponseAugmentor(self._rng, **params)
elif augmentor_type == "specaug":
return SpecAugmentor(self._rng, **params)
else:
class_obj = dynamic_import(augmentor_type, import_alias)
assert issubclass(class_obj, AugmentorBase)
try:
obj = class_obj(self._rng, **params)
except Exception:
raise ValueError("Unknown augmentor type [%s]." % augmentor_type)
return obj
......@@ -28,6 +28,10 @@ class AugmentorBase():
def __init__(self):
pass
@abstractmethod
def __call__(self, xs):
raise NotImplementedError("AugmentorBase: Not impl __call__")
@abstractmethod
def transform_audio(self, audio_segment):
"""Adds various effects to the input audio segment. Such effects
......@@ -40,7 +44,7 @@ class AugmentorBase():
:param audio_segment: Audio segment to add effects to.
:type audio_segment: AudioSegmenet|SpeechSegment
"""
raise NotImplementedError
raise NotImplementedError("AugmentorBase: Not impl transform_audio")
@abstractmethod
def transform_feature(self, spec_segment):
......@@ -52,4 +56,4 @@ class AugmentorBase():
Args:
spec_segment (Spectrogram): Spectrogram segment to add effects to.
"""
raise NotImplementedError
raise NotImplementedError("AugmentorBase: Not impl transform_feature")
......@@ -30,6 +30,12 @@ class ImpulseResponseAugmentor(AugmentorBase):
self._rng = rng
self._impulse_manifest = read_manifest(impulse_manifest_path)
def __call__(self, x, uttid=None, train=True):
if not train:
return x
self.transform_audio(x)
return x
def transform_audio(self, audio_segment):
"""Add impulse response effect.
......
......@@ -36,6 +36,12 @@ class NoisePerturbAugmentor(AugmentorBase):
self._rng = rng
self._noise_manifest = read_manifest(manifest_path=noise_manifest_path)
def __call__(self, x, uttid=None, train=True):
if not train:
return x
self.transform_audio(x)
return x
def transform_audio(self, audio_segment):
"""Add background noise audio.
......
......@@ -44,6 +44,12 @@ class OnlineBayesianNormalizationAugmentor(AugmentorBase):
self._rng = rng
self._startup_delay = startup_delay
def __call__(self, x, uttid=None, train=True):
if not train:
return x
self.transform_audio(x)
return x
def transform_audio(self, audio_segment):
"""Normalizes the input audio using the online Bayesian approach.
......
......@@ -31,6 +31,12 @@ class ResampleAugmentor(AugmentorBase):
self._new_sample_rate = new_sample_rate
self._rng = rng
def __call__(self, x, uttid=None, train=True):
if not train:
return x
self.transform_audio(x)
return x
def transform_audio(self, audio_segment):
"""Resamples the input audio to a target sample rate.
......
......@@ -31,6 +31,12 @@ class ShiftPerturbAugmentor(AugmentorBase):
self._max_shift_ms = max_shift_ms
self._rng = rng
def __call__(self, x, uttid=None, train=True):
if not train:
return x
self.transform_audio(x)
return x
def transform_audio(self, audio_segment):
"""Shift audio.
......
......@@ -12,7 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains the volume perturb augmentation model."""
import random
import numpy as np
from PIL import Image
from PIL.Image import BICUBIC
from deepspeech.frontend.augmentor.base import AugmentorBase
from deepspeech.utils.log import Log
......@@ -25,10 +29,10 @@ class SpecAugmentor(AugmentorBase):
SpecAugment: A Simple Data Augmentation Method for Automatic Speech Recognition
https://arxiv.org/abs/1904.08779
SpecAugment on Large Scale Datasets
https://arxiv.org/abs/1912.05533
"""
def __init__(self,
......@@ -42,7 +46,8 @@ class SpecAugmentor(AugmentorBase):
adaptive_number_ratio=0,
adaptive_size_ratio=0,
max_n_time_masks=20,
**kwargs):
replace_with_zero=True,
warp_mode='PIL'):
"""SpecAugment class.
Args:
rng (random.Random): random generator object.
......@@ -55,17 +60,22 @@ class SpecAugmentor(AugmentorBase):
adaptive_number_ratio (float): adaptive multiplicity ratio for time masking
adaptive_size_ratio (float): adaptive size ratio for time masking
max_n_time_masks (int): maximum number of time masking
replace_with_zero (bool): pad zero on mask if true else use mean
warp_mode (str): "PIL" (default, fast, not differentiable)
or "sparse_image_warp" (slow, differentiable)
"""
super().__init__()
self._rng = rng
self.inplace = True
self.replace_with_zero = replace_with_zero
self.mode = warp_mode
self.W = W
self.F = F
self.T = T
self.n_freq_masks = n_freq_masks
self.n_time_masks = n_time_masks
self.p = p
#logger.info(f"specaug: F-{F}, T-{T}, F-n-{n_freq_masks}, T-n-{n_time_masks}")
# adaptive SpecAugment
self.adaptive_number_ratio = adaptive_number_ratio
......@@ -122,21 +132,86 @@ class SpecAugmentor(AugmentorBase):
def time_mask(self):
return self._time_mask
def time_warp(self, xs, W=40):
raise NotImplementedError
def __repr__(self):
return f"specaug: F-{F}, T-{T}, F-n-{n_freq_masks}, T-n-{n_time_masks}"
def time_warp(self, x, mode='PIL'):
"""time warp for spec augment
move random center frame by the random width ~ uniform(-window, window)
Args:
x (np.ndarray): spectrogram (time, freq)
mode (str): PIL or sparse_image_warp
Raises:
NotImplementedError: [description]
NotImplementedError: [description]
Returns:
np.ndarray: time warped spectrogram (time, freq)
"""
window = max_time_warp = self.W
if window == 0:
return x
if mode == "PIL":
t = x.shape[0]
if t - window <= window:
return x
# NOTE: randrange(a, b) emits a, a + 1, ..., b - 1
center = random.randrange(window, t - window)
warped = random.randrange(center - window, center +
window) + 1 # 1 ... t - 1
left = Image.fromarray(x[:center]).resize((x.shape[1], warped),
BICUBIC)
right = Image.fromarray(x[center:]).resize((x.shape[1], t - warped),
BICUBIC)
if self.inplace:
x[:warped] = left
x[warped:] = right
return x
return np.concatenate((left, right), 0)
elif mode == "sparse_image_warp":
raise NotImplementedError('sparse_image_warp')
else:
raise NotImplementedError(
"unknown resize mode: " + mode +
", choose one from (PIL, sparse_image_warp).")
def mask_freq(self, x, replace_with_zero=False):
"""freq mask
Args:
x (np.ndarray): spectrogram (time, freq)
replace_with_zero (bool, optional): Defaults to False.
def mask_freq(self, xs, replace_with_zero=False):
n_bins = xs.shape[0]
Returns:
np.ndarray: freq mask spectrogram (time, freq)
"""
n_bins = x.shape[1]
for i in range(0, self.n_freq_masks):
f = int(self._rng.uniform(low=0, high=self.F))
f_0 = int(self._rng.uniform(low=0, high=n_bins - f))
xs[f_0:f_0 + f, :] = 0
assert f_0 <= f_0 + f
if replace_with_zero:
x[:, f_0:f_0 + f] = 0
else:
x[:, f_0:f_0 + f] = x.mean()
self._freq_mask = (f_0, f_0 + f)
return xs
return x
def mask_time(self, xs, replace_with_zero=False):
n_frames = xs.shape[1]
def mask_time(self, x, replace_with_zero=False):
"""time mask
Args:
x (np.ndarray): spectrogram (time, freq)
replace_with_zero (bool, optional): Defaults to False.
Returns:
np.ndarray: time mask spectrogram (time, freq)
"""
n_frames = x.shape[0]
if self.adaptive_number_ratio > 0:
n_masks = int(n_frames * self.adaptive_number_ratio)
......@@ -153,19 +228,29 @@ class SpecAugmentor(AugmentorBase):
t = int(self._rng.uniform(low=0, high=T))
t = min(t, int(n_frames * self.p))
t_0 = int(self._rng.uniform(low=0, high=n_frames - t))
xs[:, t_0:t_0 + t] = 0
assert t_0 <= t_0 + t
if replace_with_zero:
x[t_0:t_0 + t, :] = 0
else:
x[t_0:t_0 + t, :] = x.mean()
self._time_mask = (t_0, t_0 + t)
return xs
return x
def __call__(self, x, train=True):
if not train:
return x
return self.transform_feature(x)
def transform_feature(self, xs: np.ndarray):
def transform_feature(self, x: np.ndarray):
"""
Args:
xs (FloatTensor): `[F, T]`
x (np.ndarray): `[T, F]`
Returns:
xs (FloatTensor): `[F, T]`
x (np.ndarray): `[T, F]`
"""
# xs = self.time_warp(xs)
xs = self.mask_freq(xs)
xs = self.mask_time(xs)
return xs
assert isinstance(x, np.ndarray)
assert x.ndim == 2
x = self.time_warp(x, self.mode)
x = self.mask_freq(x, self.replace_with_zero)
x = self.mask_time(x, self.replace_with_zero)
return x
......@@ -79,6 +79,12 @@ class SpeedPerturbAugmentor(AugmentorBase):
self._rates = np.linspace(
self._min_rate, self._max_rate, self._num_rates, endpoint=True)
def __call__(self, x, uttid=None, train=True):
if not train:
return x
self.transform_audio(x)
return x
def transform_audio(self, audio_segment):
"""Sample a new speed rate from the given range and
changes the speed of the given audio clip.
......
......@@ -37,6 +37,12 @@ class VolumePerturbAugmentor(AugmentorBase):
self._max_gain_dBFS = max_gain_dBFS
self._rng = rng
def __call__(self, x, uttid=None, train=True):
if not train:
return x
self.transform_audio(x)
return x
def transform_audio(self, audio_segment):
"""Change audio loadness.
......
......@@ -19,15 +19,17 @@
{
"type": "specaug",
"params": {
"W": 0,
"warp_mode": "PIL",
"F": 10,
"T": 50,
"n_freq_masks": 2,
"n_time_masks": 2,
"p": 1.0,
"W": 80,
"adaptive_number_ratio": 0,
"adaptive_size_ratio": 0,
"max_n_time_masks": 20
"max_n_time_masks": 20,
"replace_with_zero": true
},
"prob": 1.0
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册