“22a5e478f3b6ecc0e43d31abce39a686b6331165”上不存在“paddle/legacy/gserver/layers/NormProjectionLayer.cpp”
提交 a061d1b8 编写于 作者: chrisxu2014's avatar chrisxu2014

add augmentor class

上级 1371a9a0
...@@ -308,7 +308,7 @@ class AudioSegment(object): ...@@ -308,7 +308,7 @@ class AudioSegment(object):
prior_mean_squared = 10.**(prior_db / 10.) prior_mean_squared = 10.**(prior_db / 10.)
prior_sum_of_squares = prior_mean_squared * prior_samples prior_sum_of_squares = prior_mean_squared * prior_samples
cumsum_of_squares = np.cumsum(self.samples**2) cumsum_of_squares = np.cumsum(self.samples**2)
sample_count = np.arange(len(self.num_samples)) + 1 sample_count = np.arange(self.num_samples) + 1
if startup_sample_idx > 0: if startup_sample_idx > 0:
cumsum_of_squares[:startup_sample_idx] = \ cumsum_of_squares[:startup_sample_idx] = \
cumsum_of_squares[startup_sample_idx] cumsum_of_squares[startup_sample_idx]
......
...@@ -6,6 +6,9 @@ from __future__ import print_function ...@@ -6,6 +6,9 @@ from __future__ import print_function
import json import json
import random import random
from data_utils.augmentor.volume_perturb import VolumePerturbAugmentor from data_utils.augmentor.volume_perturb import VolumePerturbAugmentor
from data_utils.augmentor.speed_perturb import SpeedPerturbAugmentor
from data_utils.augmentor.resample import ResampleAugmentor
from data_utils.augmentor.online_bayesian_normalization import OnlineBayesianNormalizationAugmentor
class AugmentationPipeline(object): class AugmentationPipeline(object):
...@@ -76,5 +79,11 @@ class AugmentationPipeline(object): ...@@ -76,5 +79,11 @@ class AugmentationPipeline(object):
"""Return an augmentation model by the type name, and pass in params.""" """Return an augmentation model by the type name, and pass in params."""
if augmentor_type == "volume": if augmentor_type == "volume":
return VolumePerturbAugmentor(self._rng, **params) return VolumePerturbAugmentor(self._rng, **params)
if augmentor_type == "speed":
return SpeedPerturbAugmentor(self._rng, **params)
if augmentor_type == "resample":
return ResampleAugmentor(self._rng, **params)
if augmentor_type == "baysian_normal":
return OnlineBayesianNormalizationAugmentor(self._rng, **params)
else: else:
raise ValueError("Unknown augmentor type [%s]." % augmentor_type) raise ValueError("Unknown augmentor type [%s]." % augmentor_type)
"""Contain the online bayesian normalization augmentation model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from data_utils.augmentor.base import AugmentorBase
class OnlineBayesianNormalizationAugmentor(AugmentorBase):
"""Augmentation model for adding online bayesian normalization.
:param rng: Random generator object.
:type rng: random.Random
:param target_db: Target RMS value in decibels.
:type target_db: float
:param prior_db: Prior RMS estimate in decibels.
:type prior_db: float
:param prior_samples: Prior strength in number of samples.
:type prior_samples: int
:param startup_delay: Default 0.0s. If provided, this function will
accrue statistics for the first startup_delay
seconds before applying online normalization.
:type starup_delay: float.
"""
def __init__(self,
rng,
target_db,
prior_db,
prior_samples,
startup_delay=0.0):
self._target_db = target_db
self._prior_db = prior_db
self._prior_samples = prior_samples
self._startup_delay = startup_delay
self._rng = rng
self._startup_delay=startup_delay
def transform_audio(self, audio_segment):
"""Normalizes the input audio using the online Bayesian approach.
Note that this is an in-place transformation.
:param audio_segment: Audio segment to add effects to.
:type audio_segment: AudioSegment|SpeechSegment
"""
audio_segment.normalize_online_bayesian(self._target_db,
self._prior_db,
self._prior_samples,
self._startup_delay)
"""Contain the resample augmentation model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from data_utils.augmentor.base import AugmentorBase
class ResampleAugmentor(AugmentorBase):
"""Augmentation model for resampling.
:param rng: Random generator object.
:type rng: random.Random
:param new_sample_rate: New sample rate in Hz
:type new_sample_rate: int
"""
def __init__(self, rng, new_sample_rate):
self._new_sample_rate = new_sample_rate
self._rng = rng
def transform_audio(self, audio_segment):
"""Resamples the input audio to a target sample rate.
Note that this is an in-place transformation.
:param audio: Audio segment to add effects to.
:type audio: AudioSegment|SpeechSegment
"""
audio_segment.resample(self._new_sample_rate)
\ No newline at end of file
"""Contain the speech perturbation augmentation model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from data_utils.augmentor.base import AugmentorBase
class SpeedPerturbAugmentor(AugmentorBase):
"""Augmentation model for adding speed perturbation.
See reference paper here:
http://www.danielpovey.com/files/2015_interspeech_augmentation.pdf
:param rng: Random generator object.
:type rng: random.Random
:param min_speed_rate: Lower bound of new speed rate to sample.
:type min_speed_rate: float
:param max_speed_rate: Upper bound of new speed rate to sample.
:type max_speed_rate: float
"""
def __init__(self, rng, min_speed_rate, max_speed_rate):
if (min_speed_rate < 0.5):
raise ValueError("Sampling speed below 0.9 can cause unnatural effects")
if (max_speed_rate > 1.5):
raise ValueError("Sampling speed above 1.1 can cause unnatural effects")
self._min_speed_rate = min_speed_rate
self._max_speed_rate = max_speed_rate
self._rng = rng
def transform_audio(self, audio_segment):
"""Sample a new speed rate from the given range and
changes the speed of the given audio clip.
Note that this is an in-place transformation.
:param audio_segment: Audio segment to add effects to.
:type audio_segment: AudioSegment|SpeechSegment
"""
sampled_speed = self._rng.uniform(self._min_speed_rate, self._max_speed_rate)
audio_segment.change_speed(sampled_speed)
...@@ -36,5 +36,5 @@ class VolumePerturbAugmentor(AugmentorBase): ...@@ -36,5 +36,5 @@ class VolumePerturbAugmentor(AugmentorBase):
:param audio_segment: Audio segment to add effects to. :param audio_segment: Audio segment to add effects to.
:type audio_segment: AudioSegmenet|SpeechSegment :type audio_segment: AudioSegmenet|SpeechSegment
""" """
gain = self._rng.uniform(min_gain_dBFS, max_gain_dBFS) gain = self._rng.uniform(self._min_gain_dBFS, self._max_gain_dBFS)
audio_segment.apply_gain(gain) audio_segment.apply_gain(gain)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册