diff --git a/README.md b/README.md index 0cdb203d21ef5fa854a011f2f0381078cabcb874..2912ff3143516ee21f21732f25992fadcd33c270 100644 --- a/README.md +++ b/README.md @@ -51,13 +51,13 @@ python compute_mean_std.py --help For GPU Training: ``` -CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --trainer_count 4 +CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python train.py ``` For CPU Training: ``` -python train.py --trainer_count 8 --use_gpu False +python train.py --use_gpu False ``` More help for arguments: diff --git a/data_utils/audio.py b/data_utils/audio.py old mode 100755 new mode 100644 index f80425eac76ded67d8501220aa6579e7b3794867..3d9b6c119278e3cf5241f3affeead0a88a735b61 --- a/data_utils/audio.py +++ b/data_utils/audio.py @@ -66,6 +66,54 @@ class AudioSegment(object): samples, sample_rate = soundfile.read(file, dtype='float32') return cls(samples, sample_rate) + @classmethod + def slice_from_file(cls, file, start=None, end=None): + """Loads a small section of an audio without having to load + the entire file into the memory which can be incredibly wasteful. + + :param file: Input audio filepath or file object. + :type file: basestring|file + :param start: Start time in seconds. If start is negative, it wraps + around from the end. If not provided, this function + reads from the very beginning. + :type start: float + :param end: End time in seconds. If end is negative, it wraps around + from the end. If not provided, the default behvaior is + to read to the end of the file. + :type end: float + :return: AudioSegment instance of the specified slice of the input + audio file. + :rtype: AudioSegment + :raise ValueError: If start or end is incorrectly set, e.g. out of + bounds in time. + """ + sndfile = soundfile.SoundFile(file) + sample_rate = sndfile.samplerate + duration = float(len(sndfile)) / sample_rate + start = 0. if start is None else start + end = 0. if end is None else end + if start < 0.0: + start += duration + if end < 0.0: + end += duration + if start < 0.0: + raise ValueError("The slice start position (%f s) is out of " + "bounds." % start) + if end < 0.0: + raise ValueError("The slice end position (%f s) is out of bounds." % + end) + if start > end: + raise ValueError("The slice start position (%f s) is later than " + "the slice end position (%f s)." % (start, end)) + if end > duration: + raise ValueError("The slice end position (%f s) is out of bounds " + "(> %f s)" % (end, duration)) + start_frame = int(start * sample_rate) + end_frame = int(end * sample_rate) + sndfile.seek(start_frame) + data = sndfile.read(frames=end_frame - start_frame, dtype='float32') + return cls(data, sample_rate) + @classmethod def from_bytes(cls, bytes): """Create audio segment from a byte string containing audio samples. @@ -105,6 +153,20 @@ class AudioSegment(object): samples = np.concatenate([seg.samples for seg in segments]) return cls(samples, sample_rate) + @classmethod + def make_silence(cls, duration, sample_rate): + """Creates a silent audio segment of the given duration and sample rate. + + :param duration: Length of silence in seconds. + :type duration: float + :param sample_rate: Sample rate. + :type sample_rate: float + :return: Silent AudioSegment instance of the given duration. + :rtype: AudioSegment + """ + samples = np.zeros(int(duration * sample_rate)) + return cls(samples, sample_rate) + def to_wav_file(self, filepath, dtype='float32'): """Save audio segment to disk as wav file. @@ -130,68 +192,6 @@ class AudioSegment(object): format='WAV', subtype=subtype_map[dtype]) - @classmethod - def slice_from_file(cls, file, start=None, end=None): - """Loads a small section of an audio without having to load - the entire file into the memory which can be incredibly wasteful. - - :param file: Input audio filepath or file object. - :type file: basestring|file - :param start: Start time in seconds. If start is negative, it wraps - around from the end. If not provided, this function - reads from the very beginning. - :type start: float - :param end: End time in seconds. If end is negative, it wraps around - from the end. If not provided, the default behvaior is - to read to the end of the file. - :type end: float - :return: AudioSegment instance of the specified slice of the input - audio file. - :rtype: AudioSegment - :raise ValueError: If start or end is incorrectly set, e.g. out of - bounds in time. - """ - sndfile = soundfile.SoundFile(file) - sample_rate = sndfile.samplerate - duration = float(len(sndfile)) / sample_rate - start = 0. if start is None else start - end = 0. if end is None else end - if start < 0.0: - start += duration - if end < 0.0: - end += duration - if start < 0.0: - raise ValueError("The slice start position (%f s) is out of " - "bounds." % start) - if end < 0.0: - raise ValueError("The slice end position (%f s) is out of bounds." % - end) - if start > end: - raise ValueError("The slice start position (%f s) is later than " - "the slice end position (%f s)." % (start, end)) - if end > duration: - raise ValueError("The slice end position (%f s) is out of bounds " - "(> %f s)" % (end, duration)) - start_frame = int(start * sample_rate) - end_frame = int(end * sample_rate) - sndfile.seek(start_frame) - data = sndfile.read(frames=end_frame - start_frame, dtype='float32') - return cls(data, sample_rate) - - @classmethod - def make_silence(cls, duration, sample_rate): - """Creates a silent audio segment of the given duration and sample rate. - - :param duration: Length of silence in seconds. - :type duration: float - :param sample_rate: Sample rate. - :type sample_rate: float - :return: Silent AudioSegment instance of the given duration. - :rtype: AudioSegment - """ - samples = np.zeros(int(duration * sample_rate)) - return cls(samples, sample_rate) - def superimpose(self, other): """Add samples from another segment to those of this segment (sample-wise addition, not segment concatenation). @@ -225,7 +225,7 @@ class AudioSegment(object): samples = self._convert_samples_from_float32(self._samples, dtype) return samples.tostring() - def apply_gain(self, gain): + def gain_db(self, gain): """Apply gain in decibels to samples. Note that this is an in-place transformation. @@ -278,7 +278,7 @@ class AudioSegment(object): "Unable to normalize segment to %f dB because the " "the probable gain have exceeds max_gain_db (%f dB)" % (target_db, max_gain_db)) - self.apply_gain(min(max_gain_db, target_db - self.rms_db)) + self.gain_db(min(max_gain_db, target_db - self.rms_db)) def normalize_online_bayesian(self, target_db, @@ -319,7 +319,7 @@ class AudioSegment(object): rms_estimate_db = 10 * np.log10(mean_squared_estimate) # Compute required time-varying gain. gain_db = target_db - rms_estimate_db - self.apply_gain(gain_db) + self.gain_db(gain_db) def resample(self, target_sample_rate, filter='kaiser_best'): """Resample the audio to a target sample rate. @@ -329,9 +329,10 @@ class AudioSegment(object): :param target_sample_rate: Target sample rate. :type target_sample_rate: int :param filter: The resampling filter to use one of {'kaiser_best', - 'kaiser_fast'}. + 'kaiser_fast'}. :type filter: str """ + resample_ratio = target_sample_rate / self._sample_rate self._samples = resampy.resample( self.samples, self.sample_rate, target_sample_rate, filter=filter) self._sample_rate = target_sample_rate @@ -364,6 +365,31 @@ class AudioSegment(object): raise ValueError("Unknown value for the sides %s" % sides) self._samples = padded._samples + def shift(self, shift_ms): + """Shift the audio in time. If `shift_ms` is positive, shift with time + advance; if negative, shift with time delay. Silence are padded to + keep the duration unchanged. + + Note that this is an in-place transformation. + + :param shift_ms: Shift time in millseconds. If positive, shift with + time advance; if negative; shift with time delay. + :type shift_ms: float + :raises ValueError: If shift_ms is longer than audio duration. + """ + if abs(shift_ms) / 1000.0 > self.duration: + raise ValueError("Absolute value of shift_ms should be smaller " + "than audio duration.") + shift_samples = int(shift_ms * self._sample_rate / 1000) + if shift_samples > 0: + # time advance + self._samples[:-shift_samples] = self._samples[shift_samples:] + self._samples[-shift_samples:] = 0 + elif shift_samples < 0: + # time delay + self._samples[-shift_samples:] = self._samples[:shift_samples] + self._samples[:-shift_samples] = 0 + def subsegment(self, start_sec=None, end_sec=None): """Cut the AudioSegment between given boundaries. @@ -503,7 +529,7 @@ class AudioSegment(object): noise_gain_db = min(self.rms_db - noise.rms_db - snr_dB, max_gain_db) noise_new = copy.deepcopy(noise) noise_new.random_subsegment(self.duration, rng=rng) - noise_new.apply_gain(noise_gain_db) + noise_new.gain_db(noise_gain_db) self.superimpose(noise_new) @property diff --git a/data_utils/augmentor/augmentation.py b/data_utils/augmentor/augmentation.py old mode 100755 new mode 100644 index 087880086cae185e00a9979c2a22409590b9dbcb..f8fd214a0797e25174a5dca9ba6186ae2355d3ef --- a/data_utils/augmentor/augmentation.py +++ b/data_utils/augmentor/augmentation.py @@ -6,6 +6,7 @@ from __future__ import print_function import json import random from data_utils.augmentor.volume_perturb import VolumePerturbAugmentor +from data_utils.augmentor.shift_perturb import ShiftPerturbAugmentor from data_utils.augmentor.speed_perturb import SpeedPerturbAugmentor from data_utils.augmentor.resample import ResampleAugmentor from data_utils.augmentor.online_bayesian_normalization import OnlineBayesianNormalizationAugmentor @@ -79,11 +80,13 @@ class AugmentationPipeline(object): """Return an augmentation model by the type name, and pass in params.""" if augmentor_type == "volume": return VolumePerturbAugmentor(self._rng, **params) - if augmentor_type == "speed": + elif augmentor_type == "shift": + return ShiftPerturbAugmentor(self._rng, **params) + elif augmentor_type == "speed": return SpeedPerturbAugmentor(self._rng, **params) - if augmentor_type == "resample": + elif augmentor_type == "resample": return ResampleAugmentor(self._rng, **params) - if augmentor_type == "bayesian_normal": + elif augmentor_type == "bayesian_normal": return OnlineBayesianNormalizationAugmentor(self._rng, **params) else: raise ValueError("Unknown augmentor type [%s]." % augmentor_type) diff --git a/data_utils/augmentor/resample.py b/data_utils/augmentor/resample.py index 529b5fec15220da7aed730e82d1003749b4d74e0..8df17f3a869420bca1e4e6c0ae9b4035f7d50d8d 100755 --- a/data_utils/augmentor/resample.py +++ b/data_utils/augmentor/resample.py @@ -30,4 +30,4 @@ class ResampleAugmentor(AugmentorBase): :param audio: Audio segment to add effects to. :type audio: AudioSegment|SpeechSegment """ - audio_segment.resample(self._new_sample_rate) \ No newline at end of file + audio_segment.resample(self._new_sample_rate) diff --git a/data_utils/augmentor/shift_perturb.py b/data_utils/augmentor/shift_perturb.py new file mode 100644 index 0000000000000000000000000000000000000000..c4cbe3e172f6b291f3b778b748affda0341a3181 --- /dev/null +++ b/data_utils/augmentor/shift_perturb.py @@ -0,0 +1,34 @@ +"""Contains the volume perturb augmentation model.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from data_utils.augmentor.base import AugmentorBase + + +class ShiftPerturbAugmentor(AugmentorBase): + """Augmentation model for adding random shift perturbation. + + :param rng: Random generator object. + :type rng: random.Random + :param min_shift_ms: Minimal shift in milliseconds. + :type min_shift_ms: float + :param max_shift_ms: Maximal shift in milliseconds. + :type max_shift_ms: float + """ + + def __init__(self, rng, min_shift_ms, max_shift_ms): + self._min_shift_ms = min_shift_ms + self._max_shift_ms = max_shift_ms + self._rng = rng + + def transform_audio(self, audio_segment): + """Shift audio. + + Note that this is an in-place transformation. + + :param audio_segment: Audio segment to add effects to. + :type audio_segment: AudioSegmenet|SpeechSegment + """ + shift_ms = self._rng.uniform(self._min_shift_ms, self._max_shift_ms) + audio_segment.shift(shift_ms) diff --git a/data_utils/augmentor/speed_perturb.py b/data_utils/augmentor/speed_perturb.py old mode 100755 new mode 100644 index 3f880fbba2a73223d85841e79231b4e73c789d0e..8c6c8b63cdb0eb2aee14383e4c47f94191d43b8e --- a/data_utils/augmentor/speed_perturb.py +++ b/data_utils/augmentor/speed_perturb.py @@ -14,20 +14,21 @@ class SpeedPerturbAugmentor(AugmentorBase): :param rng: Random generator object. :type rng: random.Random - :param min_speed_rate: Lower bound of new speed rate to sample. + :param min_speed_rate: Lower bound of new speed rate to sample and should + not below 0.9. :type min_speed_rate: float - :param max_speed_rate: Upper bound of new speed rate to sample. + :param max_speed_rate: Upper bound of new speed rate to sample and should + not above 1.1. :type max_speed_rate: float """ def __init__(self, rng, min_speed_rate, max_speed_rate): - - if (min_speed_rate < 0.5): - raise ValueError("Sampling speed below 0.9 can cause unnatural "\ - "effects") - if (max_speed_rate > 1.5): - raise ValueError("Sampling speed above 1.1 can cause unnatural "\ - "effects") + if min_speed_rate < 0.9: + raise ValueError( + "Sampling speed below 0.9 can cause unnatural effects") + if max_speed_rate > 1.1: + raise ValueError( + "Sampling speed above 1.1 can cause unnatural effects") self._min_speed_rate = min_speed_rate self._max_speed_rate = max_speed_rate self._rng = rng diff --git a/data_utils/augmentor/volume_perturb.py b/data_utils/augmentor/volume_perturb.py old mode 100755 new mode 100644 index 62631fb041c45350811b2cd2dd78d6758a622db8..758676d558d8e4d77191504d0d7b75cefe020549 --- a/data_utils/augmentor/volume_perturb.py +++ b/data_utils/augmentor/volume_perturb.py @@ -37,4 +37,4 @@ class VolumePerturbAugmentor(AugmentorBase): :type audio_segment: AudioSegmenet|SpeechSegment """ gain = self._rng.uniform(self._min_gain_dBFS, self._max_gain_dBFS) - audio_segment.apply_gain(gain) + audio_segment.gain_db(gain) diff --git a/data_utils/data.py b/data_utils/data.py index 424343a48ffa579a8ab465794987f957de36abdb..d01ca8cc7a9c08bcbe615e7ea2800751193d1a6e 100644 --- a/data_utils/data.py +++ b/data_utils/data.py @@ -7,6 +7,7 @@ from __future__ import print_function import random import numpy as np +import multiprocessing import paddle.v2 as paddle from data_utils import utils from data_utils.augmentor.augmentation import AugmentationPipeline @@ -44,6 +45,11 @@ class DataGenerator(object): :types max_freq: None|float :param specgram_type: Specgram feature type. Options: 'linear'. :type specgram_type: str + :param use_dB_normalization: Whether to normalize the audio to -20 dB + before extracting the features. + :type use_dB_normalization: bool + :param num_threads: Number of CPU threads for processing data. + :type num_threads: int :param random_seed: Random seed. :type random_seed: int """ @@ -58,6 +64,8 @@ class DataGenerator(object): window_ms=20.0, max_freq=None, specgram_type='linear', + use_dB_normalization=True, + num_threads=multiprocessing.cpu_count(), random_seed=0): self._max_duration = max_duration self._min_duration = min_duration @@ -69,7 +77,9 @@ class DataGenerator(object): specgram_type=specgram_type, stride_ms=stride_ms, window_ms=window_ms, - max_freq=max_freq) + max_freq=max_freq, + use_dB_normalization=use_dB_normalization) + self._num_threads = num_threads self._rng = random.Random(random_seed) self._epoch = 0 @@ -207,10 +217,14 @@ class DataGenerator(object): def reader(): for instance in manifest: - yield self._process_utterance(instance["audio_filepath"], - instance["text"]) + yield instance - return reader + def mapper(instance): + return self._process_utterance(instance["audio_filepath"], + instance["text"]) + + return paddle.reader.xmap_readers( + mapper, reader, self._num_threads, 1024, order=True) def _padding_batch(self, batch, padding_to=-1, flatten=False): """ diff --git a/data_utils/featurizer/audio_featurizer.py b/data_utils/featurizer/audio_featurizer.py index 9f9d4e505d13b4fcaf1c1411821163caa4b73bc8..4b4d02c60f4193d753badae1aaa3b17ab3b7ea43 100644 --- a/data_utils/featurizer/audio_featurizer.py +++ b/data_utils/featurizer/audio_featurizer.py @@ -24,26 +24,64 @@ class AudioFeaturizer(object): corresponding to frequencies between [0, max_freq] are returned. :types max_freq: None|float + :param target_sample_rate: Audio are resampled (if upsampling or + downsampling is allowed) to this before + extracting spectrogram features. + :type target_sample_rate: float + :param use_dB_normalization: Whether to normalize the audio to a certain + decibels before extracting the features. + :type use_dB_normalization: bool + :param target_dB: Target audio decibels for normalization. + :type target_dB: float """ def __init__(self, specgram_type='linear', stride_ms=10.0, window_ms=20.0, - max_freq=None): + max_freq=None, + target_sample_rate=16000, + use_dB_normalization=True, + target_dB=-20): self._specgram_type = specgram_type self._stride_ms = stride_ms self._window_ms = window_ms self._max_freq = max_freq + self._target_sample_rate = target_sample_rate + self._use_dB_normalization = use_dB_normalization + self._target_dB = target_dB - def featurize(self, audio_segment): + def featurize(self, + audio_segment, + allow_downsampling=True, + allow_upsamplling=True): """Extract audio features from AudioSegment or SpeechSegment. :param audio_segment: Audio/speech segment to extract features from. :type audio_segment: AudioSegment|SpeechSegment + :param allow_downsampling: Whether to allow audio downsampling before + featurizing. + :type allow_downsampling: bool + :param allow_upsampling: Whether to allow audio upsampling before + featurizing. + :type allow_upsampling: bool :return: Spectrogram audio feature in 2darray. :rtype: ndarray + :raises ValueError: If audio sample rate is not supported. """ + # upsampling or downsampling + if ((audio_segment.sample_rate > self._target_sample_rate and + allow_downsampling) or + (audio_segment.sample_rate < self._target_sample_rate and + allow_upsampling)): + audio_segment.resample(self._target_sample_rate) + if audio_segment.sample_rate != self._target_sample_rate: + raise ValueError("Audio sample rate is not supported. " + "Turn allow_downsampling or allow up_sampling on.") + # decibel normalization + if self._use_dB_normalization: + audio_segment.normalize(target_db=self._target_dB) + # extract spectrogram return self._compute_specgram(audio_segment.samples, audio_segment.sample_rate) diff --git a/data_utils/featurizer/speech_featurizer.py b/data_utils/featurizer/speech_featurizer.py index 7702045597fb8379bffee2c31029ace4b2453f92..26283892e85beb8b41351fb2d1b876c6284da887 100644 --- a/data_utils/featurizer/speech_featurizer.py +++ b/data_utils/featurizer/speech_featurizer.py @@ -29,6 +29,15 @@ class SpeechFeaturizer(object): corresponding to frequencies between [0, max_freq] are returned. :types max_freq: None|float + :param target_sample_rate: Speech are resampled (if upsampling or + downsampling is allowed) to this before + extracting spectrogram features. + :type target_sample_rate: float + :param use_dB_normalization: Whether to normalize the audio to a certain + decibels before extracting the features. + :type use_dB_normalization: bool + :param target_dB: Target audio decibels for normalization. + :type target_dB: float """ def __init__(self, @@ -36,9 +45,18 @@ class SpeechFeaturizer(object): specgram_type='linear', stride_ms=10.0, window_ms=20.0, - max_freq=None): - self._audio_featurizer = AudioFeaturizer(specgram_type, stride_ms, - window_ms, max_freq) + max_freq=None, + target_sample_rate=16000, + use_dB_normalization=True, + target_dB=-20): + self._audio_featurizer = AudioFeaturizer( + specgram_type=specgram_type, + stride_ms=stride_ms, + window_ms=window_ms, + max_freq=max_freq, + target_sample_rate=target_sample_rate, + use_dB_normalization=use_dB_normalization, + target_dB=target_dB) self._text_featurizer = TextFeaturizer(vocab_filepath) def featurize(self, speech_segment): diff --git a/data_utils/speech.py b/data_utils/speech.py index fc031ff46f4a3820e3b13f7804c91b33948712d1..568e4443ba557149505dfb4de6f230b4962e332a 100644 --- a/data_utils/speech.py +++ b/data_utils/speech.py @@ -94,7 +94,7 @@ class SpeechSegment(AudioSegment): return cls(samples, sample_rate, transcripts) @classmethod - def slice_from_file(cls, filepath, start=None, end=None, transcript): + def slice_from_file(cls, filepath, transcript, start=None, end=None): """Loads a small section of an speech without having to load the entire file into the memory which can be incredibly wasteful. diff --git a/infer.py b/infer.py index 06449ab05c7960ec78acc9ce5bb664cf1058a845..9037a108e2c5cbf8f5d8544b6fa07057067c9340 100644 --- a/infer.py +++ b/infer.py @@ -6,6 +6,7 @@ from __future__ import print_function import argparse import gzip import distutils.util +import multiprocessing import paddle.v2 as paddle from data_utils.data import DataGenerator from model import deep_speech2 @@ -38,6 +39,11 @@ parser.add_argument( default=True, type=distutils.util.strtobool, help="Use gpu or not. (default: %(default)s)") +parser.add_argument( + "--num_threads_data", + default=multiprocessing.cpu_count(), + type=int, + help="Number of cpu threads for preprocessing data. (default: %(default)s)") parser.add_argument( "--mean_std_filepath", default='mean_std.npz', @@ -50,7 +56,7 @@ parser.add_argument( help="Manifest path for decoding. (default: %(default)s)") parser.add_argument( "--model_filepath", - default='./params.tar.gz', + default='checkpoints/params.latest.tar.gz', type=str, help="Model filepath. (default: %(default)s)") parser.add_argument( @@ -67,7 +73,8 @@ def infer(): data_generator = DataGenerator( vocab_filepath=args.vocab_filepath, mean_std_filepath=args.mean_std_filepath, - augmentation_config='{}') + augmentation_config='{}', + num_threads=args.num_threads_data) # create network config # paddle.data_type.dense_array is used for variable batch input. diff --git a/tests/test_augmentor.py b/tests/test_augmentor.py new file mode 100644 index 0000000000000000000000000000000000000000..ee1f5439c6a419728a845f21b42be7e955491aa6 --- /dev/null +++ b/tests/test_augmentor.py @@ -0,0 +1,65 @@ +"""Test augmentor class.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import unittest +from data_utils import audio +from data_utils.augmentor.augmentation import AugmentationPipeline +import random +import numpy as np + +random_seed = 0 +#audio instance +audio_data = [3.0517571e-05, -8.54492188e-04, -1.09863281e-03, -9.4604492e-04,\ + -1.31225586e-03, -1.09863281e-03, -1.73950195e-03, -2.1057189e-03,\ + -2.04467773e-03, -1.46484375e-03, -1.43432617e-03, -9.4604492e-04,\ + -1.95312500e-03, -1.86157227e-03, -2.10571289e-03, -2.3193354e-03,\ + -2.01416016e-03, -2.62451172e-03, -2.07519531e-03, -2.3803719e-03] +audio_data = np.array(audio_data) +samplerate = 10 + + +class TestAugmentor(unittest.TestCase): + def test_volume(self): + config_json = '[{"type": "volume","params": {"min_gain_dBFS": -15, '\ + '"max_gain_dBFS": 15},"prob": 1.0}]' + aug_pipeline = AugmentationPipeline( + augmentation_config=config_json, random_seed=random_seed) + audio_seg = audio.AudioSegment(audio_data, samplerate) + aug_pipeline.transform_audio(audio_seg) + orig_audio = audio.AudioSegment(audio_data, samplerate) + self.assertFalse(np.any(audio_seg.samples == orig_audio.samples)) + + def test_speed(self): + config_json = '[{"type":"speed","params": {"min_speed_rate": 0.9,' \ + '"max_speed_rate": 1.1},"prob": 1.0}]' + aug_pipeline = AugmentationPipeline( + augmentation_config=config_json, random_seed=random_seed) + audio_seg = audio.AudioSegment(audio_data, samplerate) + aug_pipeline.transform_audio(audio_seg) + orig_audio = audio.AudioSegment(audio_data, samplerate) + self.assertFalse(np.any(audio_seg.samples == orig_audio.samples)) + + def test_resample(self): + config_json = '[{"type":"resample","params": {"new_sample_rate":5},'\ + '"prob": 1.0}]' + aug_pipeline = AugmentationPipeline( + augmentation_config=config_json, random_seed=random_seed) + audio_seg = audio.AudioSegment(audio_data, samplerate) + aug_pipeline.transform_audio(audio_seg) + self.assertTrue(audio_seg.sample_rate == 5) + + def test_bayesial(self): + config_json = '[{"type":"bayesian_normal","params":{"target_db":-20,' \ + '"prior_db":-4, "prior_samples": -8, "startup_delay": 0.0},"prob":1.0}]' + aug_pipeline = AugmentationPipeline( + augmentation_config=config_json, random_seed=random_seed) + audio_seg = audio.AudioSegment(audio_data, samplerate) + aug_pipeline.transform_audio(audio_seg) + orig_audio = audio.AudioSegment(audio_data, samplerate) + self.assertFalse(np.any(audio_seg.samples == orig_audio.samples)) + + +if __name__ == '__main__': + unittest.main() diff --git a/train.py b/train.py index c60a039b69d91a89eb20e83ec1e090c8600d47a3..3a2d0cad9ec9635c7e44e0149e426842a5e892b6 100644 --- a/train.py +++ b/train.py @@ -9,6 +9,7 @@ import argparse import gzip import time import distutils.util +import multiprocessing import paddle.v2 as paddle from model import deep_speech2 from data_utils.data import DataGenerator @@ -16,10 +17,10 @@ import utils parser = argparse.ArgumentParser(description=__doc__) parser.add_argument( - "--batch_size", default=32, type=int, help="Minibatch size.") + "--batch_size", default=256, type=int, help="Minibatch size.") parser.add_argument( "--num_passes", - default=20, + default=200, type=int, help="Training pass number. (default: %(default)s)") parser.add_argument( @@ -52,17 +53,34 @@ parser.add_argument( default=True, type=distutils.util.strtobool, help="Use sortagrad or not. (default: %(default)s)") +parser.add_argument( + "--max_duration", + default=27.0, + type=float, + help="Audios with duration larger than this will be discarded. " + "(default: %(default)s)") +parser.add_argument( + "--min_duration", + default=0.0, + type=float, + help="Audios with duration smaller than this will be discarded. " + "(default: %(default)s)") parser.add_argument( "--shuffle_method", - default='instance_shuffle', + default='batch_shuffle_clipped', type=str, help="Shuffle method: 'instance_shuffle', 'batch_shuffle', " "'batch_shuffle_batch'. (default: %(default)s)") parser.add_argument( "--trainer_count", - default=4, + default=8, type=int, help="Trainer number. (default: %(default)s)") +parser.add_argument( + "--num_threads_data", + default=multiprocessing.cpu_count(), + type=int, + help="Number of cpu threads for preprocessing data. (default: %(default)s)") parser.add_argument( "--mean_std_filepath", default='mean_std.npz', @@ -92,7 +110,9 @@ parser.add_argument( "the existing model of this path. (default: %(default)s)") parser.add_argument( "--augmentation_config", - default='{}', + default='[{"type": "shift", ' + '"params": {"min_shift_ms": -5, "max_shift_ms": 5},' + '"prob": 1.0}]', type=str, help="Augmentation configuration in json-format. " "(default: %(default)s)") @@ -107,7 +127,10 @@ def train(): return DataGenerator( vocab_filepath=args.vocab_filepath, mean_std_filepath=args.mean_std_filepath, - augmentation_config=args.augmentation_config) + augmentation_config=args.augmentation_config, + max_duration=args.max_duration, + min_duration=args.min_duration, + num_threads=args.num_threads_data) train_generator = data_generator() test_generator = data_generator() @@ -168,7 +191,7 @@ def train(): print("\nPass: %d, Batch: %d, TrainCost: %f" % ( event.pass_id, event.batch_id + 1, cost_sum / cost_counter)) cost_sum, cost_counter = 0.0, 0 - with gzip.open("params.tar.gz", 'w') as f: + with gzip.open("checkpoints/params.latest.tar.gz", 'w') as f: parameters.to_tar(f) else: sys.stdout.write('.') @@ -181,6 +204,9 @@ def train(): reader=test_batch_reader, feeding=test_generator.feeding) print("\n------- Time: %d sec, Pass: %d, ValidationCost: %s" % (time.time() - start_time, event.pass_id, result.cost)) + with gzip.open("checkpoints/params.pass-%d.tar.gz" % event.pass_id, + 'w') as f: + parameters.to_tar(f) # run train trainer.train(