提交 99e819e8 编写于 作者: X Xinghai Sun

Add ImpulseResponseAugmentor and augmentation.config file.

上级 ad82c877
[
{
"type": "noise",
"params": {"min_snr_dB": 50,
"max_snr_dB": 50,
"noise_manifest": "datasets/manifest.noise"},
"prob": 0.0
},
{
"type": "speed",
"params": {"min_speed_rate": 0.9,
"max_speed_rate": 1.1},
"prob": 0.0
},
{
"type": "shift",
"params": {"min_shift_ms": -5,
"max_shift_ms": 5},
"prob": 1.0
},
{
"type": "volume",
"params": {"min_gain_dBFS": -10,
"max_gain_dBFS": 10},
"prob": 0.0
},
{
"type": "bayesian_normal",
"params": {"target_db": -20,
"prior_db": -20,
"prior_samples": 100},
"prob": 0.0
}
]
...@@ -204,7 +204,7 @@ class AudioSegment(object): ...@@ -204,7 +204,7 @@ class AudioSegment(object):
:raise ValueError: If the sample rates of the two segments are not :raise ValueError: If the sample rates of the two segments are not
equal, or if the lengths of segments don't match. equal, or if the lengths of segments don't match.
""" """
if type(self) != type(other): if isinstance(other, type(self)):
raise TypeError("Cannot add segments of different types: %s " raise TypeError("Cannot add segments of different types: %s "
"and %s." % (type(self), type(other))) "and %s." % (type(self), type(other)))
if self._sample_rate != other._sample_rate: if self._sample_rate != other._sample_rate:
...@@ -231,7 +231,7 @@ class AudioSegment(object): ...@@ -231,7 +231,7 @@ class AudioSegment(object):
Note that this is an in-place transformation. Note that this is an in-place transformation.
:param gain: Gain in decibels to apply to samples. :param gain: Gain in decibels to apply to samples.
:type gain: float :type gain: float|1darray
""" """
self._samples *= 10.**(gain / 20.) self._samples *= 10.**(gain / 20.)
...@@ -457,9 +457,9 @@ class AudioSegment(object): ...@@ -457,9 +457,9 @@ class AudioSegment(object):
audio segments when resample is not allowed. audio segments when resample is not allowed.
""" """
if allow_resample and self.sample_rate != impulse_segment.sample_rate: if allow_resample and self.sample_rate != impulse_segment.sample_rate:
impulse_segment = impulse_segment.resample(self.sample_rate) impulse_segment.resample(self.sample_rate)
if self.sample_rate != impulse_segment.sample_rate: if self.sample_rate != impulse_segment.sample_rate:
raise ValueError("Impulse segment's sample rate (%d Hz) is not" raise ValueError("Impulse segment's sample rate (%d Hz) is not "
"equal to base signal sample rate (%d Hz)." % "equal to base signal sample rate (%d Hz)." %
(impulse_segment.sample_rate, self.sample_rate)) (impulse_segment.sample_rate, self.sample_rate))
samples = signal.fftconvolve(self.samples, impulse_segment.samples, samples = signal.fftconvolve(self.samples, impulse_segment.samples,
......
...@@ -9,6 +9,7 @@ from data_utils.augmentor.volume_perturb import VolumePerturbAugmentor ...@@ -9,6 +9,7 @@ from data_utils.augmentor.volume_perturb import VolumePerturbAugmentor
from data_utils.augmentor.shift_perturb import ShiftPerturbAugmentor from data_utils.augmentor.shift_perturb import ShiftPerturbAugmentor
from data_utils.augmentor.speed_perturb import SpeedPerturbAugmentor from data_utils.augmentor.speed_perturb import SpeedPerturbAugmentor
from data_utils.augmentor.noise_perturb import NoisePerturbAugmentor from data_utils.augmentor.noise_perturb import NoisePerturbAugmentor
from data_utils.augmentor.impulse_response import ImpulseResponseAugmentor
from data_utils.augmentor.resample import ResampleAugmentor from data_utils.augmentor.resample import ResampleAugmentor
from data_utils.augmentor.online_bayesian_normalization import \ from data_utils.augmentor.online_bayesian_normalization import \
OnlineBayesianNormalizationAugmentor OnlineBayesianNormalizationAugmentor
...@@ -24,21 +25,46 @@ class AugmentationPipeline(object): ...@@ -24,21 +25,46 @@ class AugmentationPipeline(object):
string, e.g. string, e.g.
.. code-block:: .. code-block::
'[{"type": "volume",
"params": {"min_gain_dBFS": -15,
"max_gain_dBFS": 15},
"prob": 0.5},
{"type": "speed",
"params": {"min_speed_rate": 0.8,
"max_speed_rate": 1.2},
"prob": 0.5}
]'
[ {
"type": "noise",
"params": {"min_snr_dB": 10,
"max_snr_dB": 20,
"noise_manifest": "datasets/manifest.noise"},
"prob": 0.0
},
{
"type": "speed",
"params": {"min_speed_rate": 0.9,
"max_speed_rate": 1.1},
"prob": 1.0
},
{
"type": "shift",
"params": {"min_shift_ms": -5,
"max_shift_ms": 5},
"prob": 1.0
},
{
"type": "volume",
"params": {"min_gain_dBFS": -10,
"max_gain_dBFS": 10},
"prob": 0.0
},
{
"type": "bayesian_normal",
"params": {"target_db": -20,
"prior_db": -20,
"prior_samples": 100},
"prob": 0.0
}
]
This augmentation configuration inserts two augmentation models This augmentation configuration inserts two augmentation models
into the pipeline, with one is VolumePerturbAugmentor and the other into the pipeline, with one is VolumePerturbAugmentor and the other
SpeedPerturbAugmentor. "prob" indicates the probability of the current SpeedPerturbAugmentor. "prob" indicates the probability of the current
augmentor to take effect. augmentor to take effect. If "prob" is zero, the augmentor does not take
effect.
:param augmentation_config: Augmentation configuration in json string. :param augmentation_config: Augmentation configuration in json string.
:type augmentation_config: str :type augmentation_config: str
...@@ -61,7 +87,7 @@ class AugmentationPipeline(object): ...@@ -61,7 +87,7 @@ class AugmentationPipeline(object):
:type audio_segment: AudioSegmenet|SpeechSegment :type audio_segment: AudioSegmenet|SpeechSegment
""" """
for augmentor, rate in zip(self._augmentors, self._rates): for augmentor, rate in zip(self._augmentors, self._rates):
if self._rng.uniform(0., 1.) <= rate: if self._rng.uniform(0., 1.) < rate:
augmentor.transform_audio(audio_segment) augmentor.transform_audio(audio_segment)
def _parse_pipeline_from(self, config_json): def _parse_pipeline_from(self, config_json):
...@@ -92,5 +118,7 @@ class AugmentationPipeline(object): ...@@ -92,5 +118,7 @@ class AugmentationPipeline(object):
return OnlineBayesianNormalizationAugmentor(self._rng, **params) return OnlineBayesianNormalizationAugmentor(self._rng, **params)
elif augmentor_type == "noise": elif augmentor_type == "noise":
return NoisePerturbAugmentor(self._rng, **params) return NoisePerturbAugmentor(self._rng, **params)
elif augmentor_type == "impulse":
return ImpulseResponseAugmentor(self._rng, **params)
else: else:
raise ValueError("Unknown augmentor type [%s]." % augmentor_type) raise ValueError("Unknown augmentor type [%s]." % augmentor_type)
"""Contains the impulse response augmentation model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from data_utils.augmentor.base import AugmentorBase
from data_utils import utils
from data_utils.audio import AudioSegment
class ImpulseResponseAugmentor(AugmentorBase):
"""Augmentation model for adding impulse response effect.
:param rng: Random generator object.
:type rng: random.Random
:param impulse_manifest: Manifest path for impulse audio data.
:type impulse_manifest: basestring
"""
def __init__(self, rng, impulse_manifest):
self._rng = rng
self._manifest = utils.read_manifest(manifest_path=impulse_manifest)
def transform_audio(self, audio_segment):
"""Add impulse response effect.
Note that this is an in-place transformation.
:param audio_segment: Audio segment to add effects to.
:type audio_segment: AudioSegmenet|SpeechSegment
"""
noise_json = self._rng.sample(self._manifest, 1)[0]
noise_segment = AudioSegment.from_file(noise_json['audio_filepath'])
audio_segment.convolve(noise_segment, allow_resample=True)
...@@ -5,7 +5,7 @@ from __future__ import print_function ...@@ -5,7 +5,7 @@ from __future__ import print_function
from data_utils.augmentor.base import AugmentorBase from data_utils.augmentor.base import AugmentorBase
from data_utils import utils from data_utils import utils
from data_utils.speech import SpeechSegment from data_utils.audio import AudioSegment
class NoisePerturbAugmentor(AugmentorBase): class NoisePerturbAugmentor(AugmentorBase):
...@@ -17,6 +17,8 @@ class NoisePerturbAugmentor(AugmentorBase): ...@@ -17,6 +17,8 @@ class NoisePerturbAugmentor(AugmentorBase):
:type min_snr_dB: float :type min_snr_dB: float
:param max_snr_dB: Maximal signal noise ratio, in decibels. :param max_snr_dB: Maximal signal noise ratio, in decibels.
:type max_snr_dB: float :type max_snr_dB: float
:param noise_manifest: Manifest path for noise audio data.
:type noise_manifest: basestring
""" """
def __init__(self, rng, min_snr_dB, max_snr_dB, noise_manifest): def __init__(self, rng, min_snr_dB, max_snr_dB, noise_manifest):
...@@ -40,8 +42,8 @@ class NoisePerturbAugmentor(AugmentorBase): ...@@ -40,8 +42,8 @@ class NoisePerturbAugmentor(AugmentorBase):
diff_duration = noise_json['duration'] - audio_segment.duration diff_duration = noise_json['duration'] - audio_segment.duration
start = self._rng.uniform(0, diff_duration) start = self._rng.uniform(0, diff_duration)
end = start + audio_segment.duration end = start + audio_segment.duration
noise_segment = SpeechSegment.slice_from_file( noise_segment = AudioSegment.slice_from_file(
noise_json['audio_filepath'], transcript="", start=start, end=end) noise_json['audio_filepath'], start=start, end=end)
snr_dB = self._rng.uniform(self._min_snr_dB, self._max_snr_dB) snr_dB = self._rng.uniform(self._min_snr_dB, self._max_snr_dB)
audio_segment.add_noise( audio_segment.add_noise(
noise_segment, snr_dB, allow_downsampling=True, rng=self._rng) noise_segment, snr_dB, allow_downsampling=True, rng=self._rng)
...@@ -169,7 +169,7 @@ class DataGenerator(object): ...@@ -169,7 +169,7 @@ class DataGenerator(object):
manifest, batch_size, clipped=True) manifest, batch_size, clipped=True)
elif shuffle_method == "instance_shuffle": elif shuffle_method == "instance_shuffle":
self._rng.shuffle(manifest) self._rng.shuffle(manifest)
elif not shuffle_method: elif shuffle_method == None:
pass pass
else: else:
raise ValueError("Unknown shuffle method %s." % raise ValueError("Unknown shuffle method %s." %
......
...@@ -115,7 +115,7 @@ class SpeechSegment(AudioSegment): ...@@ -115,7 +115,7 @@ class SpeechSegment(AudioSegment):
speech file. speech file.
:rtype: SpeechSegment :rtype: SpeechSegment
""" """
audio = Audiosegment.slice_from_file(filepath, start, end) audio = AudioSegment.slice_from_file(filepath, start, end)
return cls(audio.samples, audio.sample_rate, transcript) return cls(audio.samples, audio.sample_rate, transcript)
@classmethod @classmethod
......
...@@ -123,9 +123,7 @@ parser.add_argument( ...@@ -123,9 +123,7 @@ parser.add_argument(
help="Directory for saving models. (default: %(default)s)") help="Directory for saving models. (default: %(default)s)")
parser.add_argument( parser.add_argument(
"--augmentation_config", "--augmentation_config",
default='[{"type": "shift", ' default=open('augmentation.config', 'r').read(),
'"params": {"min_shift_ms": -5, "max_shift_ms": 5},'
'"prob": 1.0}]',
type=str, type=str,
help="Augmentation configuration in json-format. " help="Augmentation configuration in json-format. "
"(default: %(default)s)") "(default: %(default)s)")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册