“3c7cde0ce1ec9669ecf6954bb9c4ab0daf94bf30”上不存在“mobile/src/operators/quantize_op.cpp”
提交 0c9fbaf7 编写于 作者: H Hui Zhang

refactor augmentor

上级 756be8fb
...@@ -13,18 +13,29 @@ ...@@ -13,18 +13,29 @@
# limitations under the License. # limitations under the License.
"""Contains the data augmentation pipeline.""" """Contains the data augmentation pipeline."""
import json import json
from pprint import pformat
from collections.abc import Sequence
from inspect import signature
import numpy as np import numpy as np
from deepspeech.frontend.augmentor.impulse_response import ImpulseResponseAugmentor from deepspeech.frontend.augmentor.base import AugmentorBase
from deepspeech.frontend.augmentor.noise_perturb import NoisePerturbAugmentor from deepspeech.utils.dynamic_import import dynamic_import
from deepspeech.frontend.augmentor.online_bayesian_normalization import \ from deepspeech.utils.log import Log
OnlineBayesianNormalizationAugmentor
from deepspeech.frontend.augmentor.resample import ResampleAugmentor logger = Log(__name__).getlog()
from deepspeech.frontend.augmentor.shift_perturb import ShiftPerturbAugmentor
from deepspeech.frontend.augmentor.spec_augment import SpecAugmentor __all__ = ["AugmentationPipeline"]
from deepspeech.frontend.augmentor.speed_perturb import SpeedPerturbAugmentor
from deepspeech.frontend.augmentor.volume_perturb import VolumePerturbAugmentor import_alias = dict(
volume="deepspeech.frontend.augmentor.impulse_response:VolumePerturbAugmentor",
shift="deepspeech.frontend.augmentor.shift_perturb:ShiftPerturbAugmentor",
speed="deepspeech.frontend.augmentor.speed_perturb:SpeedPerturbAugmentor",
resample="deepspeech.frontend.augmentor.resample:ResampleAugmentor",
bayesian_normal="deepspeech.frontend.augmentor.online_bayesian_normalization:OnlineBayesianNormalizationAugmentor",
noise="deepspeech.frontend.augmentor.noise_perturb:NoisePerturbAugmentor",
impulse="deepspeech.frontend.augmentor.impulse_response:ImpulseResponseAugmentor",
specaug="deepspeech.frontend.augmentor.spec_augment:SpecAugmentor", )
class AugmentationPipeline(): class AugmentationPipeline():
...@@ -78,20 +89,75 @@ class AugmentationPipeline(): ...@@ -78,20 +89,75 @@ class AugmentationPipeline():
augmentor to take effect. If "prob" is zero, the augmentor does not take augmentor to take effect. If "prob" is zero, the augmentor does not take
effect. effect.
:param augmentation_config: Augmentation configuration in json string. Params:
:type augmentation_config: str augmentation_config(str): Augmentation configuration in json string.
:param random_seed: Random seed. random_seed(int): Random seed.
:type random_seed: int train(bool): whether is train mode.
:raises ValueError: If the augmentation json config is in incorrect format".
Raises:
ValueError: If the augmentation json config is in incorrect format".
""" """
def __init__(self, augmentation_config: str, random_seed=0): SPEC_TYPES = {'specaug'}
def __init__(self, augmentation_config: str, random_seed: int=0):
self._rng = np.random.RandomState(random_seed) self._rng = np.random.RandomState(random_seed)
self._spec_types = ('specaug') self.conf = {'mode': 'sequential', 'process': []}
self._augmentors, self._rates = self._parse_pipeline_from( if augmentation_config:
augmentation_config, 'audio') process = json.loads(augmentation_config)
self.conf['process'] += process
self._augmentors, self._rates = self._parse_pipeline_from('all')
self._audio_augmentors, self._audio_rates = self._parse_pipeline_from(
'audio')
self._spec_augmentors, self._spec_rates = self._parse_pipeline_from( self._spec_augmentors, self._spec_rates = self._parse_pipeline_from(
augmentation_config, 'feature') 'feature')
logger.info(f"Augmentation: {pformat(list(zip(self._augmentors, self._rates)))}")
def __call__(self, xs, uttid_list=None, **kwargs):
if not isinstance(xs, Sequence):
is_batch = False
xs = [xs]
else:
is_batch = True
if isinstance(uttid_list, str):
uttid_list = [uttid_list for _ in range(len(xs))]
if self.conf.get("mode", "sequential") == "sequential":
for idx, (func, rate) in enumerate(
zip(self._augmentors, self._rates), 0):
if self._rng.uniform(0., 1.) >= rate:
continue
# Derive only the args which the func has
try:
param = signature(func).parameters
except ValueError:
# Some function, e.g. built-in function, are failed
param = {}
_kwargs = {k: v for k, v in kwargs.items() if k in param}
try:
if uttid_list is not None and "uttid" in param:
xs = [
func(x, u, **_kwargs)
for x, u in zip(xs, uttid_list)
]
else:
xs = [func(x, **_kwargs) for x in xs]
except Exception:
logger.fatal("Catch a exception from {}th func: {}".format(
idx, func))
raise
else:
raise NotImplementedError(
"Not supporting mode={}".format(self.conf["mode"]))
if is_batch:
return xs
else:
return xs[0]
def transform_audio(self, audio_segment): def transform_audio(self, audio_segment):
"""Run the pre-processing pipeline for data augmentation. """Run the pre-processing pipeline for data augmentation.
...@@ -101,7 +167,7 @@ class AugmentationPipeline(): ...@@ -101,7 +167,7 @@ class AugmentationPipeline():
:param audio_segment: Audio segment to process. :param audio_segment: Audio segment to process.
:type audio_segment: AudioSegmenet|SpeechSegment :type audio_segment: AudioSegmenet|SpeechSegment
""" """
for augmentor, rate in zip(self._augmentors, self._rates): for augmentor, rate in zip(self._audio_augmentors, self._audio_rates):
if self._rng.uniform(0., 1.) < rate: if self._rng.uniform(0., 1.) < rate:
augmentor.transform_audio(audio_segment) augmentor.transform_audio(audio_segment)
...@@ -116,52 +182,41 @@ class AugmentationPipeline(): ...@@ -116,52 +182,41 @@ class AugmentationPipeline():
spec_segment = augmentor.transform_feature(spec_segment) spec_segment = augmentor.transform_feature(spec_segment)
return spec_segment return spec_segment
def _parse_pipeline_from(self, config_json, aug_type='audio'): def _parse_pipeline_from(self, aug_type='all'):
"""Parse the config json to build a augmentation pipelien.""" """Parse the config json to build a augmentation pipelien."""
assert aug_type in ('audio', 'feature'), aug_type assert aug_type in ('audio', 'feature', 'all'), aug_type
try: audio_confs = []
configs = json.loads(config_json) feature_confs = []
audio_confs = [] all_confs = []
feature_confs = [] for config in self.conf['process']:
for config in configs: all_confs.append(config)
if config["type"] in self._spec_types: if config["type"] in self.SPEC_TYPES:
feature_confs.append(config) feature_confs.append(config)
else: else:
audio_confs.append(config) audio_confs.append(config)
if aug_type == 'audio': if aug_type == 'audio':
aug_confs = audio_confs aug_confs = audio_confs
elif aug_type == 'feature': elif aug_type == 'feature':
aug_confs = feature_confs aug_confs = feature_confs
elif aug_type == 'all':
augmentors = [ aug_confs = all_confs
self._get_augmentor(config["type"], config["params"]) else:
for config in aug_confs raise ValueError(f"Not support: {aug_type}")
]
rates = [config["prob"] for config in aug_confs] augmentors = [
self._get_augmentor(config["type"], config["params"])
except Exception as e: for config in aug_confs
raise ValueError("Failed to parse the augmentation config json: " ]
"%s" % str(e)) rates = [config["prob"] for config in aug_confs]
return augmentors, rates return augmentors, rates
def _get_augmentor(self, augmentor_type, params): def _get_augmentor(self, augmentor_type, params):
"""Return an augmentation model by the type name, and pass in params.""" """Return an augmentation model by the type name, and pass in params."""
if augmentor_type == "volume": class_obj = dynamic_import(augmentor_type, import_alias)
return VolumePerturbAugmentor(self._rng, **params) assert issubclass(class_obj, AugmentorBase)
elif augmentor_type == "shift": try:
return ShiftPerturbAugmentor(self._rng, **params) obj = class_obj(self._rng, **params)
elif augmentor_type == "speed": except Exception:
return SpeedPerturbAugmentor(self._rng, **params)
elif augmentor_type == "resample":
return ResampleAugmentor(self._rng, **params)
elif augmentor_type == "bayesian_normal":
return OnlineBayesianNormalizationAugmentor(self._rng, **params)
elif augmentor_type == "noise":
return NoisePerturbAugmentor(self._rng, **params)
elif augmentor_type == "impulse":
return ImpulseResponseAugmentor(self._rng, **params)
elif augmentor_type == "specaug":
return SpecAugmentor(self._rng, **params)
else:
raise ValueError("Unknown augmentor type [%s]." % augmentor_type) raise ValueError("Unknown augmentor type [%s]." % augmentor_type)
return obj
...@@ -28,6 +28,10 @@ class AugmentorBase(): ...@@ -28,6 +28,10 @@ class AugmentorBase():
def __init__(self): def __init__(self):
pass pass
@abstractmethod
def __call__(self, xs):
raise NotImplementedError("AugmentorBase: Not impl __call__")
@abstractmethod @abstractmethod
def transform_audio(self, audio_segment): def transform_audio(self, audio_segment):
"""Adds various effects to the input audio segment. Such effects """Adds various effects to the input audio segment. Such effects
...@@ -40,7 +44,7 @@ class AugmentorBase(): ...@@ -40,7 +44,7 @@ class AugmentorBase():
:param audio_segment: Audio segment to add effects to. :param audio_segment: Audio segment to add effects to.
:type audio_segment: AudioSegmenet|SpeechSegment :type audio_segment: AudioSegmenet|SpeechSegment
""" """
raise NotImplementedError raise NotImplementedError("AugmentorBase: Not impl transform_audio")
@abstractmethod @abstractmethod
def transform_feature(self, spec_segment): def transform_feature(self, spec_segment):
...@@ -52,4 +56,4 @@ class AugmentorBase(): ...@@ -52,4 +56,4 @@ class AugmentorBase():
Args: Args:
spec_segment (Spectrogram): Spectrogram segment to add effects to. spec_segment (Spectrogram): Spectrogram segment to add effects to.
""" """
raise NotImplementedError raise NotImplementedError("AugmentorBase: Not impl transform_feature")
...@@ -30,6 +30,12 @@ class ImpulseResponseAugmentor(AugmentorBase): ...@@ -30,6 +30,12 @@ class ImpulseResponseAugmentor(AugmentorBase):
self._rng = rng self._rng = rng
self._impulse_manifest = read_manifest(impulse_manifest_path) self._impulse_manifest = read_manifest(impulse_manifest_path)
def __call__(self, x, uttid=None, train=True):
if not train:
return x
self.transform_audio(x)
return x
def transform_audio(self, audio_segment): def transform_audio(self, audio_segment):
"""Add impulse response effect. """Add impulse response effect.
......
...@@ -36,6 +36,12 @@ class NoisePerturbAugmentor(AugmentorBase): ...@@ -36,6 +36,12 @@ class NoisePerturbAugmentor(AugmentorBase):
self._rng = rng self._rng = rng
self._noise_manifest = read_manifest(manifest_path=noise_manifest_path) self._noise_manifest = read_manifest(manifest_path=noise_manifest_path)
def __call__(self, x, uttid=None, train=True):
if not train:
return x
self.transform_audio(x)
return x
def transform_audio(self, audio_segment): def transform_audio(self, audio_segment):
"""Add background noise audio. """Add background noise audio.
......
...@@ -44,6 +44,12 @@ class OnlineBayesianNormalizationAugmentor(AugmentorBase): ...@@ -44,6 +44,12 @@ class OnlineBayesianNormalizationAugmentor(AugmentorBase):
self._rng = rng self._rng = rng
self._startup_delay = startup_delay self._startup_delay = startup_delay
def __call__(self, x, uttid=None, train=True):
if not train:
return x
self.transform_audio(x)
return x
def transform_audio(self, audio_segment): def transform_audio(self, audio_segment):
"""Normalizes the input audio using the online Bayesian approach. """Normalizes the input audio using the online Bayesian approach.
......
...@@ -31,6 +31,12 @@ class ResampleAugmentor(AugmentorBase): ...@@ -31,6 +31,12 @@ class ResampleAugmentor(AugmentorBase):
self._new_sample_rate = new_sample_rate self._new_sample_rate = new_sample_rate
self._rng = rng self._rng = rng
def __call__(self, x, uttid=None, train=True):
if not train:
return x
self.transform_audio(x)
return x
def transform_audio(self, audio_segment): def transform_audio(self, audio_segment):
"""Resamples the input audio to a target sample rate. """Resamples the input audio to a target sample rate.
......
...@@ -31,6 +31,12 @@ class ShiftPerturbAugmentor(AugmentorBase): ...@@ -31,6 +31,12 @@ class ShiftPerturbAugmentor(AugmentorBase):
self._max_shift_ms = max_shift_ms self._max_shift_ms = max_shift_ms
self._rng = rng self._rng = rng
def __call__(self, x, uttid=None, train=True):
if not train:
return x
self.transform_audio(x)
return x
def transform_audio(self, audio_segment): def transform_audio(self, audio_segment):
"""Shift audio. """Shift audio.
......
...@@ -12,7 +12,11 @@ ...@@ -12,7 +12,11 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Contains the volume perturb augmentation model.""" """Contains the volume perturb augmentation model."""
import random
import numpy as np import numpy as np
from PIL import Image
from PIL.Image import BICUBIC
from deepspeech.frontend.augmentor.base import AugmentorBase from deepspeech.frontend.augmentor.base import AugmentorBase
from deepspeech.utils.log import Log from deepspeech.utils.log import Log
...@@ -25,10 +29,10 @@ class SpecAugmentor(AugmentorBase): ...@@ -25,10 +29,10 @@ class SpecAugmentor(AugmentorBase):
SpecAugment: A Simple Data Augmentation Method for Automatic Speech Recognition SpecAugment: A Simple Data Augmentation Method for Automatic Speech Recognition
https://arxiv.org/abs/1904.08779 https://arxiv.org/abs/1904.08779
SpecAugment on Large Scale Datasets SpecAugment on Large Scale Datasets
https://arxiv.org/abs/1912.05533 https://arxiv.org/abs/1912.05533
""" """
def __init__(self, def __init__(self,
...@@ -42,7 +46,8 @@ class SpecAugmentor(AugmentorBase): ...@@ -42,7 +46,8 @@ class SpecAugmentor(AugmentorBase):
adaptive_number_ratio=0, adaptive_number_ratio=0,
adaptive_size_ratio=0, adaptive_size_ratio=0,
max_n_time_masks=20, max_n_time_masks=20,
**kwargs): replace_with_zero=True,
warp_mode='PIL'):
"""SpecAugment class. """SpecAugment class.
Args: Args:
rng (random.Random): random generator object. rng (random.Random): random generator object.
...@@ -55,17 +60,22 @@ class SpecAugmentor(AugmentorBase): ...@@ -55,17 +60,22 @@ class SpecAugmentor(AugmentorBase):
adaptive_number_ratio (float): adaptive multiplicity ratio for time masking adaptive_number_ratio (float): adaptive multiplicity ratio for time masking
adaptive_size_ratio (float): adaptive size ratio for time masking adaptive_size_ratio (float): adaptive size ratio for time masking
max_n_time_masks (int): maximum number of time masking max_n_time_masks (int): maximum number of time masking
replace_with_zero (bool): pad zero on mask if true else use mean
warp_mode (str): "PIL" (default, fast, not differentiable)
or "sparse_image_warp" (slow, differentiable)
""" """
super().__init__() super().__init__()
self._rng = rng self._rng = rng
self.inplace = True
self.replace_with_zero = replace_with_zero
self.mode = warp_mode
self.W = W self.W = W
self.F = F self.F = F
self.T = T self.T = T
self.n_freq_masks = n_freq_masks self.n_freq_masks = n_freq_masks
self.n_time_masks = n_time_masks self.n_time_masks = n_time_masks
self.p = p self.p = p
#logger.info(f"specaug: F-{F}, T-{T}, F-n-{n_freq_masks}, T-n-{n_time_masks}")
# adaptive SpecAugment # adaptive SpecAugment
self.adaptive_number_ratio = adaptive_number_ratio self.adaptive_number_ratio = adaptive_number_ratio
...@@ -122,21 +132,86 @@ class SpecAugmentor(AugmentorBase): ...@@ -122,21 +132,86 @@ class SpecAugmentor(AugmentorBase):
def time_mask(self): def time_mask(self):
return self._time_mask return self._time_mask
def time_warp(self, xs, W=40): def __repr__(self):
raise NotImplementedError return f"specaug: F-{F}, T-{T}, F-n-{n_freq_masks}, T-n-{n_time_masks}"
def time_warp(self, x, mode='PIL'):
"""time warp for spec augment
move random center frame by the random width ~ uniform(-window, window)
Args:
x (np.ndarray): spectrogram (time, freq)
mode (str): PIL or sparse_image_warp
Raises:
NotImplementedError: [description]
NotImplementedError: [description]
Returns:
np.ndarray: time warped spectrogram (time, freq)
"""
window = max_time_warp = self.W
if window == 0:
return x
if mode == "PIL":
t = x.shape[0]
if t - window <= window:
return x
# NOTE: randrange(a, b) emits a, a + 1, ..., b - 1
center = random.randrange(window, t - window)
warped = random.randrange(center - window, center +
window) + 1 # 1 ... t - 1
left = Image.fromarray(x[:center]).resize((x.shape[1], warped),
BICUBIC)
right = Image.fromarray(x[center:]).resize((x.shape[1], t - warped),
BICUBIC)
if self.inplace:
x[:warped] = left
x[warped:] = right
return x
return np.concatenate((left, right), 0)
elif mode == "sparse_image_warp":
raise NotImplementedError('sparse_image_warp')
else:
raise NotImplementedError(
"unknown resize mode: " + mode +
", choose one from (PIL, sparse_image_warp).")
def mask_freq(self, x, replace_with_zero=False):
"""freq mask
Args:
x (np.ndarray): spectrogram (time, freq)
replace_with_zero (bool, optional): Defaults to False.
def mask_freq(self, xs, replace_with_zero=False): Returns:
n_bins = xs.shape[0] np.ndarray: freq mask spectrogram (time, freq)
"""
n_bins = x.shape[1]
for i in range(0, self.n_freq_masks): for i in range(0, self.n_freq_masks):
f = int(self._rng.uniform(low=0, high=self.F)) f = int(self._rng.uniform(low=0, high=self.F))
f_0 = int(self._rng.uniform(low=0, high=n_bins - f)) f_0 = int(self._rng.uniform(low=0, high=n_bins - f))
xs[f_0:f_0 + f, :] = 0
assert f_0 <= f_0 + f assert f_0 <= f_0 + f
if replace_with_zero:
x[:, f_0:f_0 + f] = 0
else:
x[:, f_0:f_0 + f] = x.mean()
self._freq_mask = (f_0, f_0 + f) self._freq_mask = (f_0, f_0 + f)
return xs return x
def mask_time(self, xs, replace_with_zero=False): def mask_time(self, x, replace_with_zero=False):
n_frames = xs.shape[1] """time mask
Args:
x (np.ndarray): spectrogram (time, freq)
replace_with_zero (bool, optional): Defaults to False.
Returns:
np.ndarray: time mask spectrogram (time, freq)
"""
n_frames = x.shape[0]
if self.adaptive_number_ratio > 0: if self.adaptive_number_ratio > 0:
n_masks = int(n_frames * self.adaptive_number_ratio) n_masks = int(n_frames * self.adaptive_number_ratio)
...@@ -153,19 +228,29 @@ class SpecAugmentor(AugmentorBase): ...@@ -153,19 +228,29 @@ class SpecAugmentor(AugmentorBase):
t = int(self._rng.uniform(low=0, high=T)) t = int(self._rng.uniform(low=0, high=T))
t = min(t, int(n_frames * self.p)) t = min(t, int(n_frames * self.p))
t_0 = int(self._rng.uniform(low=0, high=n_frames - t)) t_0 = int(self._rng.uniform(low=0, high=n_frames - t))
xs[:, t_0:t_0 + t] = 0
assert t_0 <= t_0 + t assert t_0 <= t_0 + t
if replace_with_zero:
x[t_0:t_0 + t, :] = 0
else:
x[t_0:t_0 + t, :] = x.mean()
self._time_mask = (t_0, t_0 + t) self._time_mask = (t_0, t_0 + t)
return xs return x
def __call__(self, x, train=True):
if not train:
return x
return self.transform_feature(x)
def transform_feature(self, xs: np.ndarray): def transform_feature(self, x: np.ndarray):
""" """
Args: Args:
xs (FloatTensor): `[F, T]` x (np.ndarray): `[T, F]`
Returns: Returns:
xs (FloatTensor): `[F, T]` x (np.ndarray): `[T, F]`
""" """
# xs = self.time_warp(xs) assert isinstance(x, np.ndarray)
xs = self.mask_freq(xs) assert x.ndim == 2
xs = self.mask_time(xs) x = self.time_warp(x, self.mode)
return xs x = self.mask_freq(x, self.replace_with_zero)
x = self.mask_time(x, self.replace_with_zero)
return x
...@@ -79,6 +79,12 @@ class SpeedPerturbAugmentor(AugmentorBase): ...@@ -79,6 +79,12 @@ class SpeedPerturbAugmentor(AugmentorBase):
self._rates = np.linspace( self._rates = np.linspace(
self._min_rate, self._max_rate, self._num_rates, endpoint=True) self._min_rate, self._max_rate, self._num_rates, endpoint=True)
def __call__(self, x, uttid=None, train=True):
if not train:
return x
self.transform_audio(x)
return x
def transform_audio(self, audio_segment): def transform_audio(self, audio_segment):
"""Sample a new speed rate from the given range and """Sample a new speed rate from the given range and
changes the speed of the given audio clip. changes the speed of the given audio clip.
......
...@@ -37,6 +37,12 @@ class VolumePerturbAugmentor(AugmentorBase): ...@@ -37,6 +37,12 @@ class VolumePerturbAugmentor(AugmentorBase):
self._max_gain_dBFS = max_gain_dBFS self._max_gain_dBFS = max_gain_dBFS
self._rng = rng self._rng = rng
def __call__(self, x, uttid=None, train=True):
if not train:
return x
self.transform_audio(x)
return x
def transform_audio(self, audio_segment): def transform_audio(self, audio_segment):
"""Change audio loadness. """Change audio loadness.
......
...@@ -19,15 +19,17 @@ ...@@ -19,15 +19,17 @@
{ {
"type": "specaug", "type": "specaug",
"params": { "params": {
"W": 0,
"warp_mode": "PIL",
"F": 10, "F": 10,
"T": 50, "T": 50,
"n_freq_masks": 2, "n_freq_masks": 2,
"n_time_masks": 2, "n_time_masks": 2,
"p": 1.0, "p": 1.0,
"W": 80,
"adaptive_number_ratio": 0, "adaptive_number_ratio": 0,
"adaptive_size_ratio": 0, "adaptive_size_ratio": 0,
"max_n_time_masks": 20 "max_n_time_masks": 20,
"replace_with_zero": true
}, },
"prob": 1.0 "prob": 1.0
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册