Commit c5995bb7 authored by xushaoyong

add 3 augmentors and unit tests

...@@ -51,13 +51,13 @@ python compute_mean_std.py --help
For GPU Training:
```
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python train.py
```
For CPU Training:
```
python train.py --use_gpu False
```
More help for arguments:
...
...@@ -66,6 +66,54 @@ class AudioSegment(object):
        samples, sample_rate = soundfile.read(file, dtype='float32')
        return cls(samples, sample_rate)
    @classmethod
    def slice_from_file(cls, file, start=None, end=None):
        """Loads a small section of an audio file without having to load
        the entire file into memory, which can be incredibly wasteful.

        :param file: Input audio filepath or file object.
        :type file: basestring|file
        :param start: Start time in seconds. If start is negative, it wraps
                      around from the end. If not provided, this function
                      reads from the very beginning.
        :type start: float
        :param end: End time in seconds. If end is negative, it wraps around
                    from the end. If not provided, the default behavior is
                    to read to the end of the file.
        :type end: float
        :return: AudioSegment instance of the specified slice of the input
                 audio file.
        :rtype: AudioSegment
        :raise ValueError: If start or end is incorrectly set, e.g. out of
                           bounds in time.
        """
        sndfile = soundfile.SoundFile(file)
        sample_rate = sndfile.samplerate
        duration = float(len(sndfile)) / sample_rate
        start = 0. if start is None else start
        end = duration if end is None else end
        if start < 0.0:
            start += duration
        if end < 0.0:
            end += duration
        if start < 0.0:
            raise ValueError("The slice start position (%f s) is out of "
                             "bounds." % start)
        if end < 0.0:
            raise ValueError("The slice end position (%f s) is out of "
                             "bounds." % end)
        if start > end:
            raise ValueError("The slice start position (%f s) is later than "
                             "the slice end position (%f s)." % (start, end))
        if end > duration:
            raise ValueError("The slice end position (%f s) is out of bounds "
                             "(> %f s)" % (end, duration))
        start_frame = int(start * sample_rate)
        end_frame = int(end * sample_rate)
        sndfile.seek(start_frame)
        data = sndfile.read(frames=end_frame - start_frame, dtype='float32')
        return cls(data, sample_rate)
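
A minimal usage sketch for the new `slice_from_file` API; the WAV path is a hypothetical placeholder:

```
# Hypothetical usage: extract seconds 1.0-3.0 of a WAV file without
# reading the whole file into memory. "example.wav" is a placeholder path.
from data_utils.audio import AudioSegment

segment = AudioSegment.slice_from_file("example.wav", start=1.0, end=3.0)
print(segment.duration, segment.sample_rate)
```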
    @classmethod
    def from_bytes(cls, bytes):
        """Create audio segment from a byte string containing audio samples.
...@@ -105,6 +153,20 @@ class AudioSegment(object):
        samples = np.concatenate([seg.samples for seg in segments])
        return cls(samples, sample_rate)
    @classmethod
    def make_silence(cls, duration, sample_rate):
        """Creates a silent audio segment of the given duration and sample rate.

        :param duration: Length of silence in seconds.
        :type duration: float
        :param sample_rate: Sample rate.
        :type sample_rate: float
        :return: Silent AudioSegment instance of the given duration.
        :rtype: AudioSegment
        """
        samples = np.zeros(int(duration * sample_rate))
        return cls(samples, sample_rate)
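
A short, illustrative sketch of combining `make_silence` with the existing concatenation helper (the `concatenate` call signature is assumed from the `data_utils.audio` module shown above; sample values are made up):

```
import numpy as np
from data_utils.audio import AudioSegment

# Illustrative only: build a short segment from raw samples, then append
# one second of silence at the same sample rate.
speech = AudioSegment(np.random.randn(16000).astype('float32') * 1e-3, 16000)
silence = AudioSegment.make_silence(duration=1.0, sample_rate=16000)
# AudioSegment.concatenate is assumed to accept segments as positional
# arguments, as suggested by the concatenation context lines above.
padded = AudioSegment.concatenate(speech, silence)
```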
    def to_wav_file(self, filepath, dtype='float32'):
        """Save audio segment to disk as wav file.
...@@ -130,68 +192,6 @@ class AudioSegment(object):
            format='WAV',
            subtype=subtype_map[dtype])
    def superimpose(self, other):
        """Add samples from another segment to those of this segment
        (sample-wise addition, not segment concatenation).
...@@ -225,7 +225,7 @@ class AudioSegment(object):
        samples = self._convert_samples_from_float32(self._samples, dtype)
        return samples.tostring()
    def gain_db(self, gain):
        """Apply gain in decibels to samples.
        Note that this is an in-place transformation.
...@@ -278,7 +278,7 @@ class AudioSegment(object):
                "Unable to normalize segment to %f dB because the probable "
                "gain exceeds max_gain_db (%f dB)." %
                (target_db, max_gain_db))
        self.gain_db(min(max_gain_db, target_db - self.rms_db))
    def normalize_online_bayesian(self,
                                  target_db,
...@@ -319,7 +319,7 @@ class AudioSegment(object):
        rms_estimate_db = 10 * np.log10(mean_squared_estimate)
        # Compute required time-varying gain.
        gain_db = target_db - rms_estimate_db
        self.gain_db(gain_db)
    def resample(self, target_sample_rate, filter='kaiser_best'):
        """Resample the audio to a target sample rate.
...@@ -329,9 +329,10 @@ class AudioSegment(object):
        :param target_sample_rate: Target sample rate.
        :type target_sample_rate: int
        :param filter: The resampling filter to use, one of {'kaiser_best',
                       'kaiser_fast'}.
        :type filter: str
        """
        resample_ratio = target_sample_rate / self._sample_rate
        self._samples = resampy.resample(
            self.samples, self.sample_rate, target_sample_rate, filter=filter)
        self._sample_rate = target_sample_rate
...@@ -364,6 +365,31 @@ class AudioSegment(object):
            raise ValueError("Unknown value for the sides %s" % sides)
        self._samples = padded._samples
    def shift(self, shift_ms):
        """Shift the audio in time. If `shift_ms` is positive, shift with time
        advance; if negative, shift with time delay. Silence is padded to
        keep the duration unchanged.

        Note that this is an in-place transformation.

        :param shift_ms: Shift time in milliseconds. If positive, shift with
                         time advance; if negative, shift with time delay.
        :type shift_ms: float
        :raises ValueError: If shift_ms is longer than audio duration.
        """
        if abs(shift_ms) / 1000.0 > self.duration:
            raise ValueError("Absolute value of shift_ms should be smaller "
                             "than audio duration.")
        shift_samples = int(shift_ms * self._sample_rate / 1000)
        if shift_samples > 0:
            # time advance
            self._samples[:-shift_samples] = self._samples[shift_samples:]
            self._samples[-shift_samples:] = 0
        elif shift_samples < 0:
            # time delay
            self._samples[-shift_samples:] = self._samples[:shift_samples]
            self._samples[:-shift_samples] = 0
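
A small, hypothetical sketch of the new in-place `shift` transformation (random samples used only for illustration):

```
import numpy as np
from data_utils.audio import AudioSegment

# Illustrative only: advance the audio by 100 ms; the vacated tail is
# zero-padded so the duration stays unchanged.
seg = AudioSegment(np.random.randn(16000).astype('float32') * 1e-3, 16000)
seg.shift(shift_ms=100.0)   # positive: time advance
seg.shift(shift_ms=-50.0)   # negative: time delay
```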
    def subsegment(self, start_sec=None, end_sec=None):
        """Cut the AudioSegment between given boundaries.
...@@ -503,7 +529,7 @@ class AudioSegment(object):
        noise_gain_db = min(self.rms_db - noise.rms_db - snr_dB, max_gain_db)
        noise_new = copy.deepcopy(noise)
        noise_new.random_subsegment(self.duration, rng=rng)
        noise_new.gain_db(noise_gain_db)
        self.superimpose(noise_new)

    @property
...
...@@ -6,6 +6,7 @@ from __future__ import print_function
import json
import random
from data_utils.augmentor.volume_perturb import VolumePerturbAugmentor
from data_utils.augmentor.shift_perturb import ShiftPerturbAugmentor
from data_utils.augmentor.speed_perturb import SpeedPerturbAugmentor
from data_utils.augmentor.resample import ResampleAugmentor
from data_utils.augmentor.online_bayesian_normalization import OnlineBayesianNormalizationAugmentor
...@@ -79,11 +80,13 @@ class AugmentationPipeline(object):
        """Return an augmentation model by the type name, and pass in params."""
        if augmentor_type == "volume":
            return VolumePerturbAugmentor(self._rng, **params)
        elif augmentor_type == "shift":
            return ShiftPerturbAugmentor(self._rng, **params)
        elif augmentor_type == "speed":
            return SpeedPerturbAugmentor(self._rng, **params)
        elif augmentor_type == "resample":
            return ResampleAugmentor(self._rng, **params)
        elif augmentor_type == "bayesian_normal":
            return OnlineBayesianNormalizationAugmentor(self._rng, **params)
        else:
            raise ValueError("Unknown augmentor type [%s]." % augmentor_type)
...@@ -30,4 +30,4 @@ class ResampleAugmentor(AugmentorBase):
        :param audio_segment: Audio segment to add effects to.
        :type audio_segment: AudioSegment|SpeechSegment
        """
        audio_segment.resample(self._new_sample_rate)
\ No newline at end of file
"""Contains the volume perturb augmentation model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from data_utils.augmentor.base import AugmentorBase
class ShiftPerturbAugmentor(AugmentorBase):
"""Augmentation model for adding random shift perturbation.
:param rng: Random generator object.
:type rng: random.Random
:param min_shift_ms: Minimal shift in milliseconds.
:type min_shift_ms: float
:param max_shift_ms: Maximal shift in milliseconds.
:type max_shift_ms: float
"""
def __init__(self, rng, min_shift_ms, max_shift_ms):
self._min_shift_ms = min_shift_ms
self._max_shift_ms = max_shift_ms
self._rng = rng
def transform_audio(self, audio_segment):
"""Shift audio.
Note that this is an in-place transformation.
:param audio_segment: Audio segment to add effects to.
:type audio_segment: AudioSegmenet|SpeechSegment
"""
shift_ms = self._rng.uniform(self._min_shift_ms, self._max_shift_ms)
audio_segment.shift(shift_ms)
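
The augmentor can also be used directly, outside the pipeline; a minimal sketch with made-up sample data:

```
import random
import numpy as np
from data_utils.audio import AudioSegment
from data_utils.augmentor.shift_perturb import ShiftPerturbAugmentor

# Illustrative direct use: shift a segment by a random offset in [-5, 5] ms.
augmentor = ShiftPerturbAugmentor(random.Random(0), min_shift_ms=-5, max_shift_ms=5)
seg = AudioSegment(np.random.randn(16000).astype('float32') * 1e-3, 16000)
augmentor.transform_audio(seg)
```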
...@@ -14,20 +14,21 @@ class SpeedPerturbAugmentor(AugmentorBase):
    :param rng: Random generator object.
    :type rng: random.Random
    :param min_speed_rate: Lower bound of new speed rate to sample; should
                           not be below 0.9.
    :type min_speed_rate: float
    :param max_speed_rate: Upper bound of new speed rate to sample; should
                           not be above 1.1.
    :type max_speed_rate: float
    """

    def __init__(self, rng, min_speed_rate, max_speed_rate):
        if min_speed_rate < 0.9:
            raise ValueError(
                "Sampling speed below 0.9 can cause unnatural effects")
        if max_speed_rate > 1.1:
            raise ValueError(
                "Sampling speed above 1.1 can cause unnatural effects")
        self._min_speed_rate = min_speed_rate
        self._max_speed_rate = max_speed_rate
        self._rng = rng
...
...@@ -37,4 +37,4 @@ class VolumePerturbAugmentor(AugmentorBase):
        :type audio_segment: AudioSegment|SpeechSegment
        """
        gain = self._rng.uniform(self._min_gain_dBFS, self._max_gain_dBFS)
        audio_segment.gain_db(gain)
...@@ -7,6 +7,7 @@ from __future__ import print_function
import random
import numpy as np
import multiprocessing
import paddle.v2 as paddle
from data_utils import utils
from data_utils.augmentor.augmentation import AugmentationPipeline
...@@ -44,6 +45,11 @@ class DataGenerator(object):
    :type max_freq: None|float
    :param specgram_type: Specgram feature type. Options: 'linear'.
    :type specgram_type: str
    :param use_dB_normalization: Whether to normalize the audio to -20 dB
                                 before extracting the features.
    :type use_dB_normalization: bool
    :param num_threads: Number of CPU threads for processing data.
    :type num_threads: int
    :param random_seed: Random seed.
    :type random_seed: int
    """
...@@ -58,6 +64,8 @@ class DataGenerator(object):
                 window_ms=20.0,
                 max_freq=None,
                 specgram_type='linear',
                 use_dB_normalization=True,
                 num_threads=multiprocessing.cpu_count(),
                 random_seed=0):
        self._max_duration = max_duration
        self._min_duration = min_duration
...@@ -69,7 +77,9 @@ class DataGenerator(object):
            specgram_type=specgram_type,
            stride_ms=stride_ms,
            window_ms=window_ms,
            max_freq=max_freq,
            use_dB_normalization=use_dB_normalization)
        self._num_threads = num_threads
        self._rng = random.Random(random_seed)
        self._epoch = 0
...@@ -207,10 +217,14 @@ class DataGenerator(object):
        def reader():
            for instance in manifest:
                yield instance

        def mapper(instance):
            return self._process_utterance(instance["audio_filepath"],
                                           instance["text"])

        return paddle.reader.xmap_readers(
            mapper, reader, self._num_threads, 1024, order=True)
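
The reader now yields raw manifest entries and defers `_process_utterance` to `paddle.reader.xmap_readers`, which runs the mapper over the reader's output in `num_threads` worker threads (buffer size 1024, order preserved). A rough, library-free sketch of that threaded-map idea, for illustration only (not Paddle's implementation, and assuming Python 3 or the `futures` backport):

```
from concurrent.futures import ThreadPoolExecutor


def threaded_map_reader(mapper, reader, num_threads):
    # Illustrative sketch: map `mapper` over `reader()`'s items in a thread
    # pool. Executor.map preserves input order, like order=True above, but
    # unlike xmap_readers it does not bound its internal buffer.
    def new_reader():
        with ThreadPoolExecutor(max_workers=num_threads) as pool:
            for result in pool.map(mapper, reader()):
                yield result
    return new_reader
```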
    def _padding_batch(self, batch, padding_to=-1, flatten=False):
        """
...
...@@ -24,26 +24,64 @@ class AudioFeaturizer(object):
                     corresponding to frequencies between [0, max_freq] are
                     returned.
    :type max_freq: None|float
    :param target_sample_rate: Audio is resampled (if upsampling or
                               downsampling is allowed) to this sample rate
                               before extracting spectrogram features.
    :type target_sample_rate: float
    :param use_dB_normalization: Whether to normalize the audio to a certain
                                 decibel level before extracting the features.
    :type use_dB_normalization: bool
    :param target_dB: Target audio decibels for normalization.
    :type target_dB: float
    """

    def __init__(self,
                 specgram_type='linear',
                 stride_ms=10.0,
                 window_ms=20.0,
                 max_freq=None,
                 target_sample_rate=16000,
                 use_dB_normalization=True,
                 target_dB=-20):
        self._specgram_type = specgram_type
        self._stride_ms = stride_ms
        self._window_ms = window_ms
        self._max_freq = max_freq
        self._target_sample_rate = target_sample_rate
        self._use_dB_normalization = use_dB_normalization
        self._target_dB = target_dB

    def featurize(self,
                  audio_segment,
                  allow_downsampling=True,
                  allow_upsampling=True):
        """Extract audio features from AudioSegment or SpeechSegment.

        :param audio_segment: Audio/speech segment to extract features from.
        :type audio_segment: AudioSegment|SpeechSegment
        :param allow_downsampling: Whether to allow audio downsampling before
                                   featurizing.
        :type allow_downsampling: bool
        :param allow_upsampling: Whether to allow audio upsampling before
                                 featurizing.
        :type allow_upsampling: bool
        :return: Spectrogram audio feature in 2darray.
        :rtype: ndarray
        :raises ValueError: If audio sample rate is not supported.
        """
        # upsampling or downsampling
        if ((audio_segment.sample_rate > self._target_sample_rate and
             allow_downsampling) or
            (audio_segment.sample_rate < self._target_sample_rate and
             allow_upsampling)):
            audio_segment.resample(self._target_sample_rate)
        if audio_segment.sample_rate != self._target_sample_rate:
            raise ValueError("Audio sample rate is not supported. "
                             "Turn allow_downsampling or allow_upsampling on.")
        # decibel normalization
        if self._use_dB_normalization:
            audio_segment.normalize(target_db=self._target_dB)
        # extract spectrogram
        return self._compute_specgram(audio_segment.samples,
                                      audio_segment.sample_rate)
...
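
A hypothetical end-to-end sketch of the updated `featurize` flow (resample to the 16 kHz target if allowed, normalize to -20 dB, then compute the spectrogram); the module path and sample data are assumptions for illustration:

```
import numpy as np
from data_utils.audio import AudioSegment
from data_utils.featurizer.audio_featurizer import AudioFeaturizer  # path assumed

# Illustrative only: an 8 kHz segment is upsampled to the 16 kHz target,
# normalized to -20 dB, and converted to a linear spectrogram.
featurizer = AudioFeaturizer(specgram_type='linear', target_sample_rate=16000)
seg = AudioSegment(np.random.randn(8000).astype('float32') * 1e-3, 8000)
feat = featurizer.featurize(seg, allow_upsampling=True)
print(feat.shape)
```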
...@@ -29,6 +29,15 @@ class SpeechFeaturizer(object):
                     corresponding to frequencies between [0, max_freq] are
                     returned.
    :type max_freq: None|float
    :param target_sample_rate: Speech is resampled (if upsampling or
                               downsampling is allowed) to this sample rate
                               before extracting spectrogram features.
    :type target_sample_rate: float
    :param use_dB_normalization: Whether to normalize the audio to a certain
                                 decibel level before extracting the features.
    :type use_dB_normalization: bool
    :param target_dB: Target audio decibels for normalization.
    :type target_dB: float
    """

    def __init__(self,
...@@ -36,9 +45,18 @@ class SpeechFeaturizer(object):
                 specgram_type='linear',
                 stride_ms=10.0,
                 window_ms=20.0,
                 max_freq=None,
                 target_sample_rate=16000,
                 use_dB_normalization=True,
                 target_dB=-20):
        self._audio_featurizer = AudioFeaturizer(
            specgram_type=specgram_type,
            stride_ms=stride_ms,
            window_ms=window_ms,
            max_freq=max_freq,
            target_sample_rate=target_sample_rate,
            use_dB_normalization=use_dB_normalization,
            target_dB=target_dB)
        self._text_featurizer = TextFeaturizer(vocab_filepath)

    def featurize(self, speech_segment):
...
...@@ -94,7 +94,7 @@ class SpeechSegment(AudioSegment):
        return cls(samples, sample_rate, transcripts)

    @classmethod
    def slice_from_file(cls, filepath, transcript, start=None, end=None):
        """Loads a small section of a speech file without having to load
        the entire file into memory, which can be incredibly wasteful.
...
...@@ -6,6 +6,7 @@ from __future__ import print_function
import argparse
import gzip
import distutils.util
import multiprocessing
import paddle.v2 as paddle
from data_utils.data import DataGenerator
from model import deep_speech2
...@@ -38,6 +39,11 @@ parser.add_argument(
    default=True,
    type=distutils.util.strtobool,
    help="Use gpu or not. (default: %(default)s)")
parser.add_argument(
    "--num_threads_data",
    default=multiprocessing.cpu_count(),
    type=int,
    help="Number of cpu threads for preprocessing data. (default: %(default)s)")
parser.add_argument(
    "--mean_std_filepath",
    default='mean_std.npz',
...@@ -50,7 +56,7 @@ parser.add_argument(
    help="Manifest path for decoding. (default: %(default)s)")
parser.add_argument(
    "--model_filepath",
    default='checkpoints/params.latest.tar.gz',
    type=str,
    help="Model filepath. (default: %(default)s)")
parser.add_argument(
...@@ -67,7 +73,8 @@ def infer():
    data_generator = DataGenerator(
        vocab_filepath=args.vocab_filepath,
        mean_std_filepath=args.mean_std_filepath,
        augmentation_config='{}',
        num_threads=args.num_threads_data)
    # create network config
    # paddle.data_type.dense_array is used for variable batch input.
...
"""Test augmentor class."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import unittest
from data_utils import audio
from data_utils.augmentor.augmentation import AugmentationPipeline
import random
import numpy as np
random_seed = 0
#audio instance
audio_data = [3.0517571e-05, -8.54492188e-04, -1.09863281e-03, -9.4604492e-04,\
-1.31225586e-03, -1.09863281e-03, -1.73950195e-03, -2.1057189e-03,\
-2.04467773e-03, -1.46484375e-03, -1.43432617e-03, -9.4604492e-04,\
-1.95312500e-03, -1.86157227e-03, -2.10571289e-03, -2.3193354e-03,\
-2.01416016e-03, -2.62451172e-03, -2.07519531e-03, -2.3803719e-03]
audio_data = np.array(audio_data)
samplerate = 10
class TestAugmentor(unittest.TestCase):
def test_volume(self):
config_json = '[{"type": "volume","params": {"min_gain_dBFS": -15, '\
'"max_gain_dBFS": 15},"prob": 1.0}]'
aug_pipeline = AugmentationPipeline(
augmentation_config=config_json, random_seed=random_seed)
audio_seg = audio.AudioSegment(audio_data, samplerate)
aug_pipeline.transform_audio(audio_seg)
orig_audio = audio.AudioSegment(audio_data, samplerate)
self.assertFalse(np.any(audio_seg.samples == orig_audio.samples))
def test_speed(self):
config_json = '[{"type":"speed","params": {"min_speed_rate": 0.9,' \
'"max_speed_rate": 1.1},"prob": 1.0}]'
aug_pipeline = AugmentationPipeline(
augmentation_config=config_json, random_seed=random_seed)
audio_seg = audio.AudioSegment(audio_data, samplerate)
aug_pipeline.transform_audio(audio_seg)
orig_audio = audio.AudioSegment(audio_data, samplerate)
self.assertFalse(np.any(audio_seg.samples == orig_audio.samples))
def test_resample(self):
config_json = '[{"type":"resample","params": {"new_sample_rate":5},'\
'"prob": 1.0}]'
aug_pipeline = AugmentationPipeline(
augmentation_config=config_json, random_seed=random_seed)
audio_seg = audio.AudioSegment(audio_data, samplerate)
aug_pipeline.transform_audio(audio_seg)
self.assertTrue(audio_seg.sample_rate == 5)
def test_bayesial(self):
config_json = '[{"type":"bayesian_normal","params":{"target_db":-20,' \
'"prior_db":-4, "prior_samples": -8, "startup_delay": 0.0},"prob":1.0}]'
aug_pipeline = AugmentationPipeline(
augmentation_config=config_json, random_seed=random_seed)
audio_seg = audio.AudioSegment(audio_data, samplerate)
aug_pipeline.transform_audio(audio_seg)
orig_audio = audio.AudioSegment(audio_data, samplerate)
self.assertFalse(np.any(audio_seg.samples == orig_audio.samples))
if __name__ == '__main__':
unittest.main()
...@@ -9,6 +9,7 @@ import argparse
import gzip
import time
import distutils.util
import multiprocessing
import paddle.v2 as paddle
from model import deep_speech2
from data_utils.data import DataGenerator
...@@ -16,10 +17,10 @@ import utils

parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
    "--batch_size", default=256, type=int, help="Minibatch size.")
parser.add_argument(
    "--num_passes",
    default=200,
    type=int,
    help="Training pass number. (default: %(default)s)")
parser.add_argument(
...@@ -52,17 +53,34 @@ parser.add_argument(
    default=True,
    type=distutils.util.strtobool,
    help="Use sortagrad or not. (default: %(default)s)")
parser.add_argument(
    "--max_duration",
    default=27.0,
    type=float,
    help="Audios with duration larger than this will be discarded. "
    "(default: %(default)s)")
parser.add_argument(
    "--min_duration",
    default=0.0,
    type=float,
    help="Audios with duration smaller than this will be discarded. "
    "(default: %(default)s)")
parser.add_argument(
    "--shuffle_method",
    default='batch_shuffle_clipped',
    type=str,
    help="Shuffle method: 'instance_shuffle', 'batch_shuffle', "
    "'batch_shuffle_clipped'. (default: %(default)s)")
parser.add_argument(
    "--trainer_count",
    default=8,
    type=int,
    help="Trainer number. (default: %(default)s)")
parser.add_argument(
    "--num_threads_data",
    default=multiprocessing.cpu_count(),
    type=int,
    help="Number of cpu threads for preprocessing data. (default: %(default)s)")
parser.add_argument(
    "--mean_std_filepath",
    default='mean_std.npz',
...@@ -92,7 +110,9 @@ parser.add_argument(
    "the existing model of this path. (default: %(default)s)")
parser.add_argument(
    "--augmentation_config",
    default='[{"type": "shift", '
    '"params": {"min_shift_ms": -5, "max_shift_ms": 5},'
    '"prob": 1.0}]',
    type=str,
    help="Augmentation configuration in json-format. "
    "(default: %(default)s)")
...@@ -107,7 +127,10 @@ def train():
        return DataGenerator(
            vocab_filepath=args.vocab_filepath,
            mean_std_filepath=args.mean_std_filepath,
            augmentation_config=args.augmentation_config,
            max_duration=args.max_duration,
            min_duration=args.min_duration,
            num_threads=args.num_threads_data)
    train_generator = data_generator()
    test_generator = data_generator()
...@@ -168,7 +191,7 @@ def train():
                print("\nPass: %d, Batch: %d, TrainCost: %f" % (
                    event.pass_id, event.batch_id + 1, cost_sum / cost_counter))
                cost_sum, cost_counter = 0.0, 0
                with gzip.open("checkpoints/params.latest.tar.gz", 'w') as f:
                    parameters.to_tar(f)
            else:
                sys.stdout.write('.')
...@@ -181,6 +204,9 @@ def train():
                reader=test_batch_reader, feeding=test_generator.feeding)
            print("\n------- Time: %d sec, Pass: %d, ValidationCost: %s" %
                  (time.time() - start_time, event.pass_id, result.cost))
            with gzip.open("checkpoints/params.pass-%d.tar.gz" % event.pass_id,
                           'w') as f:
                parameters.to_tar(f)
    # run train
    trainer.train(
...