diff --git a/deep_speech_2/README.md b/deep_speech_2/README.md index 0cdb203d21ef5fa854a011f2f0381078cabcb874..2912ff3143516ee21f21732f25992fadcd33c270 100644 --- a/deep_speech_2/README.md +++ b/deep_speech_2/README.md @@ -51,13 +51,13 @@ python compute_mean_std.py --help For GPU Training: ``` -CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --trainer_count 4 +CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python train.py ``` For CPU Training: ``` -python train.py --trainer_count 8 --use_gpu False +python train.py --use_gpu False ``` More help for arguments: diff --git a/deep_speech_2/data_utils/audio.py b/deep_speech_2/data_utils/audio.py index 5d02feb60d66fc91d47fd1bed96a393ef8f76e1f..d55fae1efc951bf6025b2a6ba02852b1640fa10f 100644 --- a/deep_speech_2/data_utils/audio.py +++ b/deep_speech_2/data_utils/audio.py @@ -66,6 +66,54 @@ class AudioSegment(object): samples, sample_rate = soundfile.read(file, dtype='float32') return cls(samples, sample_rate) + @classmethod + def slice_from_file(cls, file, start=None, end=None): + """Loads a small section of an audio without having to load + the entire file into the memory which can be incredibly wasteful. + + :param file: Input audio filepath or file object. + :type file: basestring|file + :param start: Start time in seconds. If start is negative, it wraps + around from the end. If not provided, this function + reads from the very beginning. + :type start: float + :param end: End time in seconds. If end is negative, it wraps around + from the end. If not provided, the default behvaior is + to read to the end of the file. + :type end: float + :return: AudioSegment instance of the specified slice of the input + audio file. + :rtype: AudioSegment + :raise ValueError: If start or end is incorrectly set, e.g. out of + bounds in time. + """ + sndfile = soundfile.SoundFile(file) + sample_rate = sndfile.samplerate + duration = float(len(sndfile)) / sample_rate + start = 0. if start is None else start + end = 0. if end is None else end + if start < 0.0: + start += duration + if end < 0.0: + end += duration + if start < 0.0: + raise ValueError("The slice start position (%f s) is out of " + "bounds." % start) + if end < 0.0: + raise ValueError("The slice end position (%f s) is out of bounds." % + end) + if start > end: + raise ValueError("The slice start position (%f s) is later than " + "the slice end position (%f s)." % (start, end)) + if end > duration: + raise ValueError("The slice end position (%f s) is out of bounds " + "(> %f s)" % (end, duration)) + start_frame = int(start * sample_rate) + end_frame = int(end * sample_rate) + sndfile.seek(start_frame) + data = sndfile.read(frames=end_frame - start_frame, dtype='float32') + return cls(data, sample_rate) + @classmethod def from_bytes(cls, bytes): """Create audio segment from a byte string containing audio samples. @@ -105,6 +153,20 @@ class AudioSegment(object): samples = np.concatenate([seg.samples for seg in segments]) return cls(samples, sample_rate) + @classmethod + def make_silence(cls, duration, sample_rate): + """Creates a silent audio segment of the given duration and sample rate. + + :param duration: Length of silence in seconds. + :type duration: float + :param sample_rate: Sample rate. + :type sample_rate: float + :return: Silent AudioSegment instance of the given duration. + :rtype: AudioSegment + """ + samples = np.zeros(int(duration * sample_rate)) + return cls(samples, sample_rate) + def to_wav_file(self, filepath, dtype='float32'): """Save audio segment to disk as wav file. @@ -130,68 +192,6 @@ class AudioSegment(object): format='WAV', subtype=subtype_map[dtype]) - @classmethod - def slice_from_file(cls, file, start=None, end=None): - """Loads a small section of an audio without having to load - the entire file into the memory which can be incredibly wasteful. - - :param file: Input audio filepath or file object. - :type file: basestring|file - :param start: Start time in seconds. If start is negative, it wraps - around from the end. If not provided, this function - reads from the very beginning. - :type start: float - :param end: End time in seconds. If end is negative, it wraps around - from the end. If not provided, the default behvaior is - to read to the end of the file. - :type end: float - :return: AudioSegment instance of the specified slice of the input - audio file. - :rtype: AudioSegment - :raise ValueError: If start or end is incorrectly set, e.g. out of - bounds in time. - """ - sndfile = soundfile.SoundFile(file) - sample_rate = sndfile.samplerate - duration = float(len(sndfile)) / sample_rate - start = 0. if start is None else start - end = 0. if end is None else end - if start < 0.0: - start += duration - if end < 0.0: - end += duration - if start < 0.0: - raise ValueError("The slice start position (%f s) is out of " - "bounds." % start) - if end < 0.0: - raise ValueError("The slice end position (%f s) is out of bounds." % - end) - if start > end: - raise ValueError("The slice start position (%f s) is later than " - "the slice end position (%f s)." % (start, end)) - if end > duration: - raise ValueError("The slice end position (%f s) is out of bounds " - "(> %f s)" % (end, duration)) - start_frame = int(start * sample_rate) - end_frame = int(end * sample_rate) - sndfile.seek(start_frame) - data = sndfile.read(frames=end_frame - start_frame, dtype='float32') - return cls(data, sample_rate) - - @classmethod - def make_silence(cls, duration, sample_rate): - """Creates a silent audio segment of the given duration and sample rate. - - :param duration: Length of silence in seconds. - :type duration: float - :param sample_rate: Sample rate. - :type sample_rate: float - :return: Silent AudioSegment instance of the given duration. - :rtype: AudioSegment - """ - samples = np.zeros(int(duration * sample_rate)) - return cls(samples, sample_rate) - def superimpose(self, other): """Add samples from another segment to those of this segment (sample-wise addition, not segment concatenation). @@ -225,7 +225,7 @@ class AudioSegment(object): samples = self._convert_samples_from_float32(self._samples, dtype) return samples.tostring() - def apply_gain(self, gain): + def gain_db(self, gain): """Apply gain in decibels to samples. Note that this is an in-place transformation. @@ -278,7 +278,7 @@ class AudioSegment(object): "Unable to normalize segment to %f dB because the " "the probable gain have exceeds max_gain_db (%f dB)" % (target_db, max_gain_db)) - self.apply_gain(min(max_gain_db, target_db - self.rms_db)) + self.gain_db(min(max_gain_db, target_db - self.rms_db)) def normalize_online_bayesian(self, target_db, @@ -319,7 +319,7 @@ class AudioSegment(object): rms_estimate_db = 10 * np.log10(mean_squared_estimate) # Compute required time-varying gain. gain_db = target_db - rms_estimate_db - self.apply_gain(gain_db) + self.gain_db(gain_db) def resample(self, target_sample_rate, quality='sinc_medium'): """Resample the audio to a target sample rate. @@ -366,6 +366,31 @@ class AudioSegment(object): raise ValueError("Unknown value for the sides %s" % sides) self._samples = padded._samples + def shift(self, shift_ms): + """Shift the audio in time. If `shift_ms` is positive, shift with time + advance; if negative, shift with time delay. Silence are padded to + keep the duration unchanged. + + Note that this is an in-place transformation. + + :param shift_ms: Shift time in millseconds. If positive, shift with + time advance; if negative; shift with time delay. + :type shift_ms: float + :raises ValueError: If shift_ms is longer than audio duration. + """ + if abs(shift_ms) / 1000.0 > self.duration: + raise ValueError("Absolute value of shift_ms should be smaller " + "than audio duration.") + shift_samples = int(shift_ms * self._sample_rate / 1000) + if shift_samples > 0: + # time advance + self._samples[:-shift_samples] = self._samples[shift_samples:] + self._samples[-shift_samples:] = 0 + elif shift_samples < 0: + # time delay + self._samples[-shift_samples:] = self._samples[:shift_samples] + self._samples[:-shift_samples] = 0 + def subsegment(self, start_sec=None, end_sec=None): """Cut the AudioSegment between given boundaries. @@ -505,7 +530,7 @@ class AudioSegment(object): noise_gain_db = min(self.rms_db - noise.rms_db - snr_dB, max_gain_db) noise_new = copy.deepcopy(noise) noise_new.random_subsegment(self.duration, rng=rng) - noise_new.apply_gain(noise_gain_db) + noise_new.gain_db(noise_gain_db) self.superimpose(noise_new) @property diff --git a/deep_speech_2/data_utils/augmentor/augmentation.py b/deep_speech_2/data_utils/augmentor/augmentation.py index abe1a0ec89c5d6fc6f8ac1822df184cf5db4d7e1..0d60bbdb9cdd25b6df9177140576cb2bd6641fac 100644 --- a/deep_speech_2/data_utils/augmentor/augmentation.py +++ b/deep_speech_2/data_utils/augmentor/augmentation.py @@ -6,6 +6,7 @@ from __future__ import print_function import json import random from data_utils.augmentor.volume_perturb import VolumePerturbAugmentor +from data_utils.augmentor.shift_perturb import ShiftPerturbAugmentor class AugmentationPipeline(object): @@ -76,5 +77,7 @@ class AugmentationPipeline(object): """Return an augmentation model by the type name, and pass in params.""" if augmentor_type == "volume": return VolumePerturbAugmentor(self._rng, **params) + elif augmentor_type == "shift": + return ShiftPerturbAugmentor(self._rng, **params) else: raise ValueError("Unknown augmentor type [%s]." % augmentor_type) diff --git a/deep_speech_2/data_utils/augmentor/volume_perturb.py b/deep_speech_2/data_utils/augmentor/volume_perturb.py index a5a9f6cadac13e651dd6902d68d0efdaa9a61dc4..62631fb041c45350811b2cd2dd78d6758a622db8 100644 --- a/deep_speech_2/data_utils/augmentor/volume_perturb.py +++ b/deep_speech_2/data_utils/augmentor/volume_perturb.py @@ -36,5 +36,5 @@ class VolumePerturbAugmentor(AugmentorBase): :param audio_segment: Audio segment to add effects to. :type audio_segment: AudioSegmenet|SpeechSegment """ - gain = self._rng.uniform(min_gain_dBFS, max_gain_dBFS) + gain = self._rng.uniform(self._min_gain_dBFS, self._max_gain_dBFS) audio_segment.apply_gain(gain) diff --git a/deep_speech_2/data_utils/data.py b/deep_speech_2/data_utils/data.py index 44af7ffaa999c618a7dcd4884f528ef60e59eefe..d01ca8cc7a9c08bcbe615e7ea2800751193d1a6e 100644 --- a/deep_speech_2/data_utils/data.py +++ b/deep_speech_2/data_utils/data.py @@ -45,6 +45,9 @@ class DataGenerator(object): :types max_freq: None|float :param specgram_type: Specgram feature type. Options: 'linear'. :type specgram_type: str + :param use_dB_normalization: Whether to normalize the audio to -20 dB + before extracting the features. + :type use_dB_normalization: bool :param num_threads: Number of CPU threads for processing data. :type num_threads: int :param random_seed: Random seed. @@ -61,6 +64,7 @@ class DataGenerator(object): window_ms=20.0, max_freq=None, specgram_type='linear', + use_dB_normalization=True, num_threads=multiprocessing.cpu_count(), random_seed=0): self._max_duration = max_duration @@ -73,7 +77,8 @@ class DataGenerator(object): specgram_type=specgram_type, stride_ms=stride_ms, window_ms=window_ms, - max_freq=max_freq) + max_freq=max_freq, + use_dB_normalization=use_dB_normalization) self._num_threads = num_threads self._rng = random.Random(random_seed) self._epoch = 0 diff --git a/deep_speech_2/data_utils/featurizer/audio_featurizer.py b/deep_speech_2/data_utils/featurizer/audio_featurizer.py index 9f9d4e505d13b4fcaf1c1411821163caa4b73bc8..4b4d02c60f4193d753badae1aaa3b17ab3b7ea43 100644 --- a/deep_speech_2/data_utils/featurizer/audio_featurizer.py +++ b/deep_speech_2/data_utils/featurizer/audio_featurizer.py @@ -24,26 +24,64 @@ class AudioFeaturizer(object): corresponding to frequencies between [0, max_freq] are returned. :types max_freq: None|float + :param target_sample_rate: Audio are resampled (if upsampling or + downsampling is allowed) to this before + extracting spectrogram features. + :type target_sample_rate: float + :param use_dB_normalization: Whether to normalize the audio to a certain + decibels before extracting the features. + :type use_dB_normalization: bool + :param target_dB: Target audio decibels for normalization. + :type target_dB: float """ def __init__(self, specgram_type='linear', stride_ms=10.0, window_ms=20.0, - max_freq=None): + max_freq=None, + target_sample_rate=16000, + use_dB_normalization=True, + target_dB=-20): self._specgram_type = specgram_type self._stride_ms = stride_ms self._window_ms = window_ms self._max_freq = max_freq + self._target_sample_rate = target_sample_rate + self._use_dB_normalization = use_dB_normalization + self._target_dB = target_dB - def featurize(self, audio_segment): + def featurize(self, + audio_segment, + allow_downsampling=True, + allow_upsamplling=True): """Extract audio features from AudioSegment or SpeechSegment. :param audio_segment: Audio/speech segment to extract features from. :type audio_segment: AudioSegment|SpeechSegment + :param allow_downsampling: Whether to allow audio downsampling before + featurizing. + :type allow_downsampling: bool + :param allow_upsampling: Whether to allow audio upsampling before + featurizing. + :type allow_upsampling: bool :return: Spectrogram audio feature in 2darray. :rtype: ndarray + :raises ValueError: If audio sample rate is not supported. """ + # upsampling or downsampling + if ((audio_segment.sample_rate > self._target_sample_rate and + allow_downsampling) or + (audio_segment.sample_rate < self._target_sample_rate and + allow_upsampling)): + audio_segment.resample(self._target_sample_rate) + if audio_segment.sample_rate != self._target_sample_rate: + raise ValueError("Audio sample rate is not supported. " + "Turn allow_downsampling or allow up_sampling on.") + # decibel normalization + if self._use_dB_normalization: + audio_segment.normalize(target_db=self._target_dB) + # extract spectrogram return self._compute_specgram(audio_segment.samples, audio_segment.sample_rate) diff --git a/deep_speech_2/data_utils/featurizer/speech_featurizer.py b/deep_speech_2/data_utils/featurizer/speech_featurizer.py index 7702045597fb8379bffee2c31029ace4b2453f92..26283892e85beb8b41351fb2d1b876c6284da887 100644 --- a/deep_speech_2/data_utils/featurizer/speech_featurizer.py +++ b/deep_speech_2/data_utils/featurizer/speech_featurizer.py @@ -29,6 +29,15 @@ class SpeechFeaturizer(object): corresponding to frequencies between [0, max_freq] are returned. :types max_freq: None|float + :param target_sample_rate: Speech are resampled (if upsampling or + downsampling is allowed) to this before + extracting spectrogram features. + :type target_sample_rate: float + :param use_dB_normalization: Whether to normalize the audio to a certain + decibels before extracting the features. + :type use_dB_normalization: bool + :param target_dB: Target audio decibels for normalization. + :type target_dB: float """ def __init__(self, @@ -36,9 +45,18 @@ class SpeechFeaturizer(object): specgram_type='linear', stride_ms=10.0, window_ms=20.0, - max_freq=None): - self._audio_featurizer = AudioFeaturizer(specgram_type, stride_ms, - window_ms, max_freq) + max_freq=None, + target_sample_rate=16000, + use_dB_normalization=True, + target_dB=-20): + self._audio_featurizer = AudioFeaturizer( + specgram_type=specgram_type, + stride_ms=stride_ms, + window_ms=window_ms, + max_freq=max_freq, + target_sample_rate=target_sample_rate, + use_dB_normalization=use_dB_normalization, + target_dB=target_dB) self._text_featurizer = TextFeaturizer(vocab_filepath) def featurize(self, speech_segment): diff --git a/deep_speech_2/infer.py b/deep_speech_2/infer.py index 71518133a347c459bbcf2670fa5d1dc226a619c8..9037a108e2c5cbf8f5d8544b6fa07057067c9340 100644 --- a/deep_speech_2/infer.py +++ b/deep_speech_2/infer.py @@ -56,7 +56,7 @@ parser.add_argument( help="Manifest path for decoding. (default: %(default)s)") parser.add_argument( "--model_filepath", - default='./params.tar.gz', + default='checkpoints/params.latest.tar.gz', type=str, help="Model filepath. (default: %(default)s)") parser.add_argument( diff --git a/deep_speech_2/setup.sh b/deep_speech_2/setup.sh index 1ae2a5eee0f9cfd5b4318b29cf037165f78f2b73..cdec34ff07048a691f19658711a855ada40db9f0 100644 --- a/deep_speech_2/setup.sh +++ b/deep_speech_2/setup.sh @@ -27,4 +27,7 @@ if [ $? != 0 ]; then exit 1 fi +# prepare ./checkpoints +mkdir checkpoints + echo "Install all dependencies successfully." diff --git a/deep_speech_2/train.py b/deep_speech_2/train.py index fc23ec72692f319b556a75004a7508990df5357e..3a2d0cad9ec9635c7e44e0149e426842a5e892b6 100644 --- a/deep_speech_2/train.py +++ b/deep_speech_2/train.py @@ -17,10 +17,10 @@ import utils parser = argparse.ArgumentParser(description=__doc__) parser.add_argument( - "--batch_size", default=32, type=int, help="Minibatch size.") + "--batch_size", default=256, type=int, help="Minibatch size.") parser.add_argument( "--num_passes", - default=20, + default=200, type=int, help="Training pass number. (default: %(default)s)") parser.add_argument( @@ -55,7 +55,7 @@ parser.add_argument( help="Use sortagrad or not. (default: %(default)s)") parser.add_argument( "--max_duration", - default=100.0, + default=27.0, type=float, help="Audios with duration larger than this will be discarded. " "(default: %(default)s)") @@ -67,13 +67,13 @@ parser.add_argument( "(default: %(default)s)") parser.add_argument( "--shuffle_method", - default='instance_shuffle', + default='batch_shuffle_clipped', type=str, help="Shuffle method: 'instance_shuffle', 'batch_shuffle', " "'batch_shuffle_batch'. (default: %(default)s)") parser.add_argument( "--trainer_count", - default=4, + default=8, type=int, help="Trainer number. (default: %(default)s)") parser.add_argument( @@ -110,7 +110,9 @@ parser.add_argument( "the existing model of this path. (default: %(default)s)") parser.add_argument( "--augmentation_config", - default='{}', + default='[{"type": "shift", ' + '"params": {"min_shift_ms": -5, "max_shift_ms": 5},' + '"prob": 1.0}]', type=str, help="Augmentation configuration in json-format. " "(default: %(default)s)") @@ -189,7 +191,7 @@ def train(): print("\nPass: %d, Batch: %d, TrainCost: %f" % ( event.pass_id, event.batch_id + 1, cost_sum / cost_counter)) cost_sum, cost_counter = 0.0, 0 - with gzip.open("params.tar.gz", 'w') as f: + with gzip.open("checkpoints/params.latest.tar.gz", 'w') as f: parameters.to_tar(f) else: sys.stdout.write('.') @@ -202,6 +204,9 @@ def train(): reader=test_batch_reader, feeding=test_generator.feeding) print("\n------- Time: %d sec, Pass: %d, ValidationCost: %s" % (time.time() - start_time, event.pass_id, result.cost)) + with gzip.open("checkpoints/params.pass-%d.tar.gz" % event.pass_id, + 'w') as f: + parameters.to_tar(f) # run train trainer.train(