diff --git a/compute_mean_std.py b/compute_mean_std.py
index b3015df73ee9e991a3e7268659451da3737abc98..9c301c93f6d2ce3ae099caa96830912f76ce6c58 100755
--- a/compute_mean_std.py
+++ b/compute_mean_std.py
@@ -1,3 +1,4 @@
+"""Compute mean and std for feature normalizer, and save to file."""
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
@@ -17,7 +18,7 @@ parser.add_argument(
     "(default: %(default)s)")
 parser.add_argument(
     "--num_samples",
-    default=500,
+    default=2000,
     type=int,
     help="Number of samples for computing mean and stddev. "
     "(default: %(default)s)")
diff --git a/data_utils/audio.py b/data_utils/audio.py
index 46b2412014387b74043969931b8f18e91a23e979..916c8ac1ae781bcb0ec6e1ed2ad1b3574dc6fe65 100755
--- a/data_utils/audio.py
+++ b/data_utils/audio.py
@@ -1,3 +1,8 @@
+"""Contains the audio segment class."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
 import numpy as np
 import io
 import soundfile
@@ -5,64 +10,243 @@ import soundfile
 
 
 class AudioSegment(object):
     """Monaural audio segment abstraction.
+
+    :param samples: Audio samples [num_samples x num_channels].
+    :type samples: ndarray.float32
+    :param sample_rate: Audio sample rate.
+    :type sample_rate: int
+    :raises TypeError: If the sample data type is not float or int.
     """
 
     def __init__(self, samples, sample_rate):
-        if not samples.dtype == np.float32:
-            raise ValueError("Sample data type of [%s] is not supported.")
-        self._samples = samples
+        """Create audio segment from samples.
+
+        Samples are converted to float32 internally, with ints scaled to [-1, 1].
+        """
+        self._samples = self._convert_samples_to_float32(samples)
         self._sample_rate = sample_rate
         if self._samples.ndim >= 2:
             self._samples = np.mean(self._samples, 1)
 
+    def __eq__(self, other):
+        """Return whether two objects are equal."""
+        if type(other) is not type(self):
+            return False
+        if self._sample_rate != other._sample_rate:
+            return False
+        if self._samples.shape != other._samples.shape:
+            return False
+        if np.any(self.samples != other._samples):
+            return False
+        return True
+
+    def __ne__(self, other):
+        """Return whether two objects are unequal."""
+        return not self.__eq__(other)
+
+    def __str__(self):
+        """Return human-readable representation of segment."""
+        return ("%s: num_samples=%d, sample_rate=%d, duration=%.2fsec, "
+                "rms=%.2fdB" % (type(self), self.num_samples, self.sample_rate,
+                                self.duration, self.rms_db))
+
     @classmethod
-    def from_file(cls, filepath):
-        samples, sample_rate = soundfile.read(filepath, dtype='float32')
+    def from_file(cls, file):
+        """Create audio segment from audio file.
+
+        :param file: Filepath or file object to audio file.
+        :type file: basestring|file
+        :return: Audio segment instance.
+        :rtype: AudioSegment
+        """
+        samples, sample_rate = soundfile.read(file, dtype='float32')
         return cls(samples, sample_rate)
 
     @classmethod
     def from_bytes(cls, bytes):
+        """Create audio segment from a byte string containing audio samples.
+
+        :param bytes: Byte string containing audio samples.
+        :type bytes: str
+        :return: Audio segment instance.
+        :rtype: AudioSegment
+        """
         samples, sample_rate = soundfile.read(
             io.BytesIO(bytes), dtype='float32')
         return cls(samples, sample_rate)
 
+    def to_wav_file(self, filepath, dtype='float32'):
+        """Save audio segment to disk as wav file.
+
+        :param filepath: WAV filepath or file object to save the
+                         audio segment.
+        :type filepath: basestring|file
+        :param dtype: Subtype for audio file. Options: 'int16', 'int32',
+                      'float32', 'float64'. Default is 'float32'.
+        :type dtype: str
+        :raises TypeError: If dtype is not supported.
+        """
+        samples = self._convert_samples_from_float32(self._samples, dtype)
+        subtype_map = {
+            'int16': 'PCM_16',
+            'int32': 'PCM_32',
+            'float32': 'FLOAT',
+            'float64': 'DOUBLE'
+        }
+        soundfile.write(
+            filepath,
+            samples,
+            self._sample_rate,
+            format='WAV',
+            subtype=subtype_map[dtype])
+
+    def to_bytes(self, dtype='float32'):
+        """Create a byte string containing the audio content.
+
+        :param dtype: Data type for export samples. Options: 'int16', 'int32',
+                      'float32', 'float64'. Default is 'float32'.
+        :type dtype: str
+        :return: Byte string containing audio content.
+        :rtype: str
+        """
+        samples = self._convert_samples_from_float32(self._samples, dtype)
+        return samples.tostring()
+
     def apply_gain(self, gain):
-        self.samples *= 10.**(gain / 20.)
+        """Apply gain in decibels to samples.
+
+        Note that this is an in-place transformation.
+
+        :param gain: Gain in decibels to apply to samples.
+        :type gain: float
+        """
+        self._samples *= 10.**(gain / 20.)
+
+    def change_speed(self, speed_rate):
+        """Change the audio speed by linear interpolation.
+
+        Note that this is an in-place transformation.
+
+        :param speed_rate: Rate of speed change:
+                           speed_rate > 1.0, speed up the audio;
+                           speed_rate = 1.0, unchanged;
+                           speed_rate < 1.0, slow down the audio;
+                           speed_rate <= 0.0, not allowed, raise ValueError.
+        :type speed_rate: float
+        :raises ValueError: If speed_rate <= 0.0.
+        """
+        if speed_rate <= 0:
+            raise ValueError("speed_rate should be greater than zero.")
+        old_length = self._samples.shape[0]
+        new_length = int(old_length / speed_rate)
+        old_indices = np.arange(old_length)
+        new_indices = np.linspace(start=0, stop=old_length, num=new_length)
+        self._samples = np.interp(new_indices, old_indices, self._samples)
+
+    def normalize(self, target_sample_rate):
+        raise NotImplementedError()
 
     def resample(self, target_sample_rate):
         raise NotImplementedError()
 
-    def change_speed(self, rate):
+    def pad_silence(self, duration, sides='both'):
+        raise NotImplementedError()
+
+    def subsegment(self, start_sec=None, end_sec=None):
+        raise NotImplementedError()
+
+    def convolve(self, filter, allow_resample=False):
+        raise NotImplementedError()
+
+    def convolve_and_normalize(self, filter, allow_resample=False):
         raise NotImplementedError()
 
     @property
     def samples(self):
+        """Return audio samples.
+
+        :return: Audio samples.
+        :rtype: ndarray
+        """
         return self._samples.copy()
 
     @property
     def sample_rate(self):
+        """Return audio sample rate.
+
+        :return: Audio sample rate.
+        :rtype: int
+        """
         return self._sample_rate
 
     @property
-    def duration(self):
-        return self._samples.shape[0] / float(self._sample_rate)
-
-
-class SpeechSegment(AudioSegment):
-    def __init__(self, samples, sample_rate, transcript):
-        AudioSegment.__init__(self, samples, sample_rate)
-        self._transcript = transcript
+    def num_samples(self):
+        """Return number of samples.
 
-    @classmethod
-    def from_file(cls, filepath, transcript):
-        audio = AudioSegment.from_file(filepath)
-        return cls(audio.samples, audio.sample_rate, transcript)
+        :return: Number of samples.
+        :rtype: int
+        """
+        return self._samples.shape[0]
 
-    @classmethod
-    def from_bytes(cls, bytes, transcript):
-        audio = AudioSegment.from_bytes(bytes)
-        return cls(audio.samples, audio.sample_rate, transcript)
+    @property
+    def duration(self):
+        """Return audio duration.
 
+        :return: Audio duration in seconds.
+        :rtype: float
+        """
+        return self._samples.shape[0] / float(self._sample_rate)
 
     @property
-    def transcript(self):
-        return self._transcript
+    def rms_db(self):
+        """Return root mean square energy of the audio in decibels.
+
+        :return: Root mean square energy in decibels.
+        :rtype: float
+        """
+        # square root => multiply by 10 instead of 20 for dBs
+        mean_square = np.mean(self._samples**2)
+        return 10 * np.log10(mean_square)
+
+    def _convert_samples_to_float32(self, samples):
+        """Convert sample type to float32.
+
+        Audio sample type is usually integer or floating-point.
+        Integers will be scaled to [-1, 1] in float32.
+        """
+        float32_samples = samples.astype('float32')
+        if samples.dtype in np.sctypes['int']:
+            bits = np.iinfo(samples.dtype).bits
+            float32_samples *= (1. / 2**(bits - 1))
+        elif samples.dtype in np.sctypes['float']:
+            pass
+        else:
+            raise TypeError("Unsupported sample type: %s." % samples.dtype)
+        return float32_samples
+
+    def _convert_samples_from_float32(self, samples, dtype):
+        """Convert sample type from float32 to dtype.
+
+        Audio sample type is usually integer or floating-point. For integer
+        type, float32 will be rescaled from [-1, 1] to the maximum range
+        supported by the integer type.
+
+        This is for writing an audio file.
+        """
+        dtype = np.dtype(dtype)
+        output_samples = samples.copy()
+        if dtype in np.sctypes['int']:
+            bits = np.iinfo(dtype).bits
+            output_samples *= (2**(bits - 1) / 1.)
+            min_val = np.iinfo(dtype).min
+            max_val = np.iinfo(dtype).max
+            output_samples[output_samples > max_val] = max_val
+            output_samples[output_samples < min_val] = min_val
+        elif dtype in np.sctypes['float']:
+            min_val = np.finfo(dtype).min
+            max_val = np.finfo(dtype).max
+            output_samples[output_samples > max_val] = max_val
+            output_samples[output_samples < min_val] = min_val
+        else:
+            raise TypeError("Unsupported sample type: %s." % dtype)
+        return output_samples.astype(dtype)
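Review note: a quick sanity check of the new AudioSegment API (a minimal sketch; the synthesized tone and output filename are illustrative and not part of this PR):

```python
import numpy as np
from data_utils.audio import AudioSegment

# Synthesize one second of a 440 Hz tone at 16 kHz instead of reading a file.
sample_rate = 16000
t = np.linspace(0, 1, sample_rate, endpoint=False)
tone = (0.1 * np.sin(2 * np.pi * 440 * t)).astype('float32')

segment = AudioSegment(tone, sample_rate)
print(segment.duration)      # 1.0 (seconds)
print(segment.num_samples)   # 16000
print(segment.rms_db)        # ~ -23 dB for a 0.1-amplitude sine

# In-place transformations added by this PR.
segment.apply_gain(-6.0)     # attenuate by 6 dB
segment.change_speed(1.1)    # speed up by 10%, shortening the segment
segment.to_wav_file('tone.wav', dtype='int16')
```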
"prob" indicates the probability of the current + augmentor to take effect. + + :param augmentation_config: Augmentation configuration in json string. + :type augmentation_config: str + :param random_seed: Random seed. + :type random_seed: int + :raises ValueError: If the augmentation json config is in incorrect format". + """ + def __init__(self, augmentation_config, random_seed=0): self._rng = random.Random(random_seed) self._augmentors, self._rates = self._parse_pipeline_from( augmentation_config) def transform_audio(self, audio_segment): + """Run the pre-processing pipeline for data augmentation. + + Note that this is an in-place transformation. + + :param audio_segment: Audio segment to process. + :type audio_segment: AudioSegmenet|SpeechSegment + """ for augmentor, rate in zip(self._augmentors, self._rates): if self._rng.uniform(0., 1.) <= rate: augmentor.transform_audio(audio_segment) def _parse_pipeline_from(self, config_json): + """Parse the config json to build a augmentation pipelien.""" try: configs = json.loads(config_json) + augmentors = [ + self._get_augmentor(config["type"], config["params"]) + for config in configs + ] + rates = [config["prob"] for config in configs] except Exception as e: - raise ValueError("Augmentation config json format error: " + raise ValueError("Failed to parse the augmentation config json: " "%s" % str(e)) - augmentors = [ - self._get_augmentor(config["type"], config["params"]) - for config in configs - ] - rates = [config["rate"] for config in configs] return augmentors, rates def _get_augmentor(self, augmentor_type, params): - if augmentor_type == "volumn": - return VolumnPerturbAugmentor(self._rng, **params) + """Return an augmentation model by the type name, and pass in params.""" + if augmentor_type == "volume": + return VolumePerturbAugmentor(self._rng, **params) else: raise ValueError("Unknown augmentor type [%s]." % augmentor_type) diff --git a/data_utils/augmentor/base.py b/data_utils/augmentor/base.py index e801b9b1893c3b23c4c8ee92df88ac001cf9b2eb..a323165aaeefb8135e7189a47a388a565afd8c8a 100755 --- a/data_utils/augmentor/base.py +++ b/data_utils/augmentor/base.py @@ -1,3 +1,4 @@ +"""Contains the abstract base class for augmentation models.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -6,6 +7,11 @@ from abc import ABCMeta, abstractmethod class AugmentorBase(object): + """Abstract base class for augmentation model (augmentor) class. + All augmentor classes should inherit from this class, and implement the + following abstract methods. + """ + __metaclass__ = ABCMeta @abstractmethod @@ -14,4 +20,14 @@ class AugmentorBase(object): @abstractmethod def transform_audio(self, audio_segment): + """Adds various effects to the input audio segment. Such effects + will augment the training data to make the model invariant to certain + types of perturbations in the real world, improving model's + generalization ability. + + Note that this is an in-place transformation. + + :param audio_segment: Audio segment to add effects to. 
diff --git a/data_utils/augmentor/base.py b/data_utils/augmentor/base.py
index e801b9b1893c3b23c4c8ee92df88ac001cf9b2eb..a323165aaeefb8135e7189a47a388a565afd8c8a 100755
--- a/data_utils/augmentor/base.py
+++ b/data_utils/augmentor/base.py
@@ -1,3 +1,4 @@
+"""Contains the abstract base class for augmentation models."""
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
@@ -6,6 +7,11 @@ from abc import ABCMeta, abstractmethod
 
 
 class AugmentorBase(object):
+    """Abstract base class for augmentation model (augmentor) class.
+    All augmentor classes should inherit from this class, and implement the
+    following abstract methods.
+    """
+
     __metaclass__ = ABCMeta
 
     @abstractmethod
@@ -14,4 +20,14 @@ class AugmentorBase(object):
 
     @abstractmethod
     def transform_audio(self, audio_segment):
+        """Adds various effects to the input audio segment. Such effects
+        will augment the training data to make the model invariant to certain
+        types of perturbations in the real world, improving the model's
+        generalization ability.
+
+        Note that this is an in-place transformation.
+
+        :param audio_segment: Audio segment to add effects to.
+        :type audio_segment: AudioSegment|SpeechSegment
+        """
         pass
diff --git a/data_utils/augmentor/volume_perturb.py b/data_utils/augmentor/volume_perturb.py
new file mode 100755
index 0000000000000000000000000000000000000000..a5a9f6cadac13e651dd6902d68d0efdaa9a61dc4
--- /dev/null
+++ b/data_utils/augmentor/volume_perturb.py
@@ -0,0 +1,40 @@
+"""Contains the volume perturb augmentation model."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from data_utils.augmentor.base import AugmentorBase
+
+
+class VolumePerturbAugmentor(AugmentorBase):
+    """Augmentation model for adding random volume perturbation.
+
+    This is used for multi-loudness training of PCEN. See
+
+    https://arxiv.org/pdf/1607.05666v1.pdf
+
+    for more details.
+
+    :param rng: Random generator object.
+    :type rng: random.Random
+    :param min_gain_dBFS: Minimal gain in dBFS.
+    :type min_gain_dBFS: float
+    :param max_gain_dBFS: Maximal gain in dBFS.
+    :type max_gain_dBFS: float
+    """
+
+    def __init__(self, rng, min_gain_dBFS, max_gain_dBFS):
+        self._min_gain_dBFS = min_gain_dBFS
+        self._max_gain_dBFS = max_gain_dBFS
+        self._rng = rng
+
+    def transform_audio(self, audio_segment):
+        """Change audio loudness.
+
+        Note that this is an in-place transformation.
+
+        :param audio_segment: Audio segment to add effects to.
+        :type audio_segment: AudioSegment|SpeechSegment
+        """
+        gain = self._rng.uniform(self._min_gain_dBFS, self._max_gain_dBFS)
+        audio_segment.apply_gain(gain)
diff --git a/data_utils/augmentor/volumn_perturb.py b/data_utils/augmentor/volumn_perturb.py
deleted file mode 100755
index dd1ba53a7c2d6855214673b2026ffd25f333267f..0000000000000000000000000000000000000000
--- a/data_utils/augmentor/volumn_perturb.py
+++ /dev/null
@@ -1,17 +0,0 @@
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import random
-from data_utils.augmentor.base import AugmentorBase
-
-
-class VolumnPerturbAugmentor(AugmentorBase):
-    def __init__(self, rng, min_gain_dBFS, max_gain_dBFS):
-        self._min_gain_dBFS = min_gain_dBFS
-        self._max_gain_dBFS = max_gain_dBFS
-        self._rng = rng
-
-    def transform_audio(self, audio_segment):
-        gain = self._rng.uniform(min_gain_dBFS, max_gain_dBFS)
-        audio_segment.apply_gain(gain)
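Review note: a new augmentor plugs in by subclassing AugmentorBase, mirroring VolumePerturbAugmentor above. A hypothetical sketch — a SpeedPerturbAugmentor like this is referenced in the pipeline docstring but is not added by this PR, and would also need a branch in _get_augmentor:

```python
from data_utils.augmentor.base import AugmentorBase


class SpeedPerturbAugmentor(AugmentorBase):
    """Hypothetical augmentor for random speed perturbation (not in this PR)."""

    def __init__(self, rng, min_speed_rate, max_speed_rate):
        self._min_speed_rate = min_speed_rate
        self._max_speed_rate = max_speed_rate
        self._rng = rng

    def transform_audio(self, audio_segment):
        # In-place, like all augmentors.
        speed_rate = self._rng.uniform(self._min_speed_rate,
                                       self._max_speed_rate)
        audio_segment.change_speed(speed_rate)
```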
diff --git a/data_utils/data.py b/data_utils/data.py
index 630007932b6fce7699e2d3169cef1754082ac1be..48e03fe85d70b61686d189110154c42df1374f91 100644
--- a/data_utils/data.py
+++ b/data_utils/data.py
@@ -1,8 +1,6 @@
+"""Contains data generator for organizing various audio data preprocessing
+pipeline and offering data reader interfaces required by PaddlePaddle.
 """
-    Providing basic audio data preprocessing pipeline, and offering
-    both instance-level and batch-level data reader interfaces.
-"""
-
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
@@ -13,42 +11,41 @@
 import paddle.v2 as paddle
 from data_utils import utils
 from data_utils.augmentor.augmentation import AugmentationPipeline
 from data_utils.featurizer.speech_featurizer import SpeechFeaturizer
-from data_utils.audio import SpeechSegment
+from data_utils.speech import SpeechSegment
 from data_utils.normalizer import FeatureNormalizer
 
 
 class DataGenerator(object):
     """
     DataGenerator provides basic audio data preprocessing pipeline, and offers
-    both instance-level and batch-level data reader interfaces.
-    Normalized FFT are used as audio features here.
+    data reader interfaces required by PaddlePaddle.
 
-    :param vocab_filepath: Vocabulary file path for indexing tokenized
-                           transcriptions.
+    :param vocab_filepath: Vocabulary filepath for indexing tokenized
+                           transcripts.
     :type vocab_filepath: basestring
-    :param normalizer_manifest_path: Manifest filepath for collecting feature
-                                     normalization statistics, e.g. mean, std.
-    :type normalizer_manifest_path: basestring
-    :param normalizer_num_samples: Number of instances sampled for collecting
-                                   feature normalization statistics.
-                                   Default is 100.
-    :type normalizer_num_samples: int
-    :param max_duration: Audio clips with duration (in seconds) greater than
-                         this will be discarded. Default is 20.0.
+    :param mean_std_filepath: File containing the pre-computed mean and stddev.
+    :type mean_std_filepath: None|basestring
+    :param augmentation_config: Augmentation configuration in json string.
+                                For details see AugmentationPipeline.__doc__.
+    :type augmentation_config: str
+    :param max_duration: Audio with duration (in seconds) greater than
+                         this will be discarded.
     :type max_duration: float
-    :param min_duration: Audio clips with duration (in seconds) smaller than
-                         this will be discarded. Default is 0.0.
+    :param min_duration: Audio with duration (in seconds) smaller than
+                         this will be discarded.
     :type min_duration: float
     :param stride_ms: Striding size (in milliseconds) for generating frames.
-                      Default is 10.0.
     :type stride_ms: float
-    :param window_ms: Window size (in milliseconds) for frames. Default is 20.0.
+    :param window_ms: Window size (in milliseconds) for generating frames.
     :type window_ms: float
-    :param max_frequency: Maximun frequency for FFT features. FFT features of
-                          frequency larger than this will be discarded.
-                          If set None, all features will be kept.
-                          Default is None.
-    :type max_frequency: float
+    :param max_freq: Used when specgram_type is 'linear', only FFT bins
+                     corresponding to frequencies between [0, max_freq] are
+                     returned.
+    :type max_freq: None|float
+    :param specgram_type: Specgram feature type. Options: 'linear'.
+    :type specgram_type: str
+    :param random_seed: Random seed.
+    :type random_seed: int
     """
 
     def __init__(self,
@@ -60,6 +57,7 @@
                  stride_ms=10.0,
                  window_ms=20.0,
                  max_freq=None,
+                 specgram_type='linear',
                  random_seed=0):
         self._max_duration = max_duration
         self._min_duration = min_duration
@@ -68,46 +66,49 @@
             augmentation_config=augmentation_config, random_seed=random_seed)
         self._speech_featurizer = SpeechFeaturizer(
             vocab_filepath=vocab_filepath,
+            specgram_type=specgram_type,
             stride_ms=stride_ms,
             window_ms=window_ms,
-            max_freq=max_freq,
-            random_seed=random_seed)
+            max_freq=max_freq)
         self._rng = random.Random(random_seed)
         self._epoch = 0
 
     def batch_reader_creator(self,
                              manifest_path,
                              batch_size,
+                             min_batch_size=1,
                              padding_to=-1,
                              flatten=False,
                              sortagrad=False,
                              batch_shuffle=False):
         """
-        Batch data reader creator for audio data. Creat a callable function to
-        produce batches of data.
+        Batch data reader creator for audio data. Return a callable generator
+        function to produce batches of data.
 
-        Audio features will be padded with zeros to make each instance in the
-        batch to share the same audio feature shape.
+        Audio features within one batch will be padded with zeros to have the
+        same shape, or a user-defined shape.
 
-        :param manifest_path: Filepath of manifest for audio clip files.
+        :param manifest_path: Filepath of manifest for audio files.
         :type manifest_path: basestring
-        :param batch_size: Instance number in a batch.
+        :param batch_size: Number of instances in a batch.
         :type batch_size: int
-        :param padding_to: If set -1, the maximun column numbers in the batch
-                           will be used as the target size for padding.
-                           Otherwise, `padding_to` will be the target size.
-                           Default is -1.
+        :param min_batch_size: Any batch with batch size smaller than this will
+                               be discarded. (To be deprecated in the future.)
+        :type min_batch_size: int
+        :param padding_to: If set -1, the maximum shape in the batch
+                           will be used as the target shape for padding.
+                           Otherwise, ``padding_to`` will be the target shape.
        :type padding_to: int
-        :param flatten: If set True, audio data will be flatten to be a 1-dim
-                        ndarray. Otherwise, 2-dim ndarray. Default is False.
+        :param flatten: If set True, audio features will be flattened into a
+                        1-dim ndarray.
         :type flatten: bool
-        :param sortagrad: Sort the audio clips by duration in the first epoc
-                          if set True.
+        :param sortagrad: If set True, sort the instances by audio duration
+                          in the first epoch to speed up training.
         :type sortagrad: bool
-        :param batch_shuffle: Shuffle the audio clips if set True. It is
-                              not a thorough instance-wise shuffle, but a
-                              specific batch-wise shuffle. For more details,
-                              please see `_batch_shuffle` function.
+        :param batch_shuffle: If set True, instances are batch-wise shuffled.
+                              For more details, please see
+                              ``_batch_shuffle.__doc__``.
+                              If sortagrad is True, batch_shuffle is disabled
+                              for the first epoch.
         :type batch_shuffle: bool
         :return: Batch reader function, producing batches of data when called.
         :rtype: callable
@@ -132,7 +133,7 @@
                 if len(batch) == batch_size:
                     yield self._padding_batch(batch, padding_to, flatten)
                     batch = []
-            if len(batch) > 0:
+            if len(batch) >= min_batch_size:
                 yield self._padding_batch(batch, padding_to, flatten)
             self._epoch += 1
 
@@ -140,20 +141,33 @@
 
     @property
     def feeding(self):
-        """Returns data_reader's feeding dict."""
+        """Returns data reader's feeding dict.
+
+        :return: Data feeding dict.
+        :rtype: dict
+        """
         return {"audio_spectrogram": 0, "transcript_text": 1}
 
     @property
     def vocab_size(self):
-        """Returns vocabulary size."""
+        """Return the vocabulary size.
+
+        :return: Vocabulary size.
+        :rtype: int
+        """
        return self._speech_featurizer.vocab_size
 
     @property
     def vocab_list(self):
-        """Returns vocabulary list."""
+        """Return the vocabulary in list.
+
+        :return: Vocabulary in list.
+        :rtype: list
+        """
         return self._speech_featurizer.vocab_list
 
     def _process_utterance(self, filename, transcript):
+        """Load, augment, featurize and normalize for speech data."""
         speech_segment = SpeechSegment.from_file(filename, transcript)
         self._augmentation_pipeline.transform_audio(speech_segment)
         specgram, text_ids = self._speech_featurizer.featurize(speech_segment)
@@ -162,16 +176,11 @@
 
     def _instance_reader_creator(self, manifest):
         """
-        Instance reader creator for audio data. Creat a callable function to
-        produce instances of data.
+        Instance reader creator. Create a callable function to produce
+        instances of data.
 
-        Instance: a tuple of a numpy ndarray of audio spectrogram and a list of
-        tokenized and indexed transcription text.
-
-        :param manifest: Filepath of manifest for audio clip files.
-        :type manifest: basestring
-        :return: Data reader function.
-        :rtype: callable
+        Instance: a tuple of ndarray of audio spectrogram and a list of
+        token indices for transcript.
         """
 
         def reader():
@@ -183,24 +192,22 @@
 
     def _padding_batch(self, batch, padding_to=-1, flatten=False):
         """
-        Padding audio part of features (only in the time axis -- column axis)
-        with zeros, to make each instance in the batch share the same
-        audio feature shape.
+        Padding audio features with zeros to make them have the same shape (or
+        a user-defined shape) within one batch.
 
-        If `padding_to` is set -1, the maximun column numbers in the batch will
-        be used as the target size. Otherwise, `padding_to` will be the target
-        size. Default is -1.
+        If ``padding_to`` is -1, the maximum shape in the batch will be used
+        as the target shape for padding. Otherwise, ``padding_to`` will be the
+        target shape (only refers to the second axis).
 
-        If `flatten` is set True, audio data will be flatten to be a 1-dim
-        ndarray. Default is False.
+        If ``flatten`` is True, features will be flattened into a 1-dim
+        ndarray.
         """
         new_batch = []
         # get target shape
         max_length = max([audio.shape[1] for audio, text in batch])
         if padding_to != -1:
             if padding_to < max_length:
-                raise ValueError("If padding_to is not -1, it should be greater"
-                                 " or equal to the original instance length.")
+                raise ValueError("If padding_to is not -1, it should be larger "
+                                 "than any instance's shape in the batch.")
             max_length = padding_to
         # padding
         for audio, text in batch:
@@ -212,28 +219,21 @@
         return new_batch
 
     def _batch_shuffle(self, manifest, batch_size):
-        """
-        The instances have different lengths and they cannot be
-        combined into a single matrix multiplication. It usually
-        sorts the training examples by length and combines only
-        similarly-sized instances into minibatches, pads with
-        silence when necessary so that all instances in a batch
-        have the same length. This batch shuffle fuction is used
-        to make similarly-sized instances into minibatches and
-        make a batch-wise shuffle.
+        """Put similarly-sized instances into minibatches for better efficiency
+        and make a batch-wise shuffle.
 
         1. Sort the audio clips by duration.
        2. Generate a random number `k`, k in [0, batch_size).
-        3. Randomly remove `k` instances in order to make different mini-batches,
-           then make minibatches and each minibatch size is batch_size.
+        3. Randomly shift `k` instances in order to create different batches
+           for different epochs. Create minibatches.
         4. Shuffle the minibatches.
 
-        :param manifest: manifest file.
+        :param manifest: Manifest contents. List of dict.
         :type manifest: list
         :param batch_size: Batch size. This size is also used for generate a
                            random number for batch shuffle.
         :type batch_size: int
-        :return: batch shuffled mainifest.
+        :return: Batch shuffled manifest.
         :rtype: list
         """
         manifest.sort(key=lambda x: x["duration"])
diff --git a/data_utils/featurizer/audio_featurizer.py b/data_utils/featurizer/audio_featurizer.py
index 5d9c6883662993b1a3f915cf076e4f4ced70d475..9f9d4e505d13b4fcaf1c1411821163caa4b73bc8 100755
--- a/data_utils/featurizer/audio_featurizer.py
+++ b/data_utils/featurizer/audio_featurizer.py
@@ -1,30 +1,54 @@
+"""Contains the audio featurizer class."""
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
 import numpy as np
-import random
 from data_utils import utils
 from data_utils.audio import AudioSegment
 
 
 class AudioFeaturizer(object):
+    """Audio featurizer, for extracting features from audio contents of
+    AudioSegment or SpeechSegment.
+
+    Currently, it only supports feature type of linear spectrogram.
+
+    :param specgram_type: Specgram feature type. Options: 'linear'.
+    :type specgram_type: str
+    :param stride_ms: Striding size (in milliseconds) for generating frames.
+    :type stride_ms: float
+    :param window_ms: Window size (in milliseconds) for generating frames.
+    :type window_ms: float
+    :param max_freq: Used when specgram_type is 'linear', only FFT bins
+                     corresponding to frequencies between [0, max_freq] are
+                     returned.
+    :type max_freq: None|float
+    """
+
     def __init__(self,
                  specgram_type='linear',
                  stride_ms=10.0,
                  window_ms=20.0,
-                 max_freq=None,
-                 random_seed=0):
         self._specgram_type = specgram_type
+                 max_freq=None):
         self._stride_ms = stride_ms
         self._window_ms = window_ms
         self._max_freq = max_freq
 
     def featurize(self, audio_segment):
+        """Extract audio features from AudioSegment or SpeechSegment.
+
+        :param audio_segment: Audio/speech segment to extract features from.
+        :type audio_segment: AudioSegment|SpeechSegment
+        :return: Spectrogram audio feature in 2darray.
+        :rtype: ndarray
+        """
         return self._compute_specgram(audio_segment.samples,
                                       audio_segment.sample_rate)
 
     def _compute_specgram(self, samples, sample_rate):
+        """Extract various audio features."""
         if self._specgram_type == 'linear':
             return self._compute_linear_specgram(
                 samples, sample_rate, self._stride_ms, self._window_ms,
@@ -40,9 +64,7 @@
                                  window_ms=20.0,
                                  max_freq=None,
                                  eps=1e-14):
-        """Laod audio data and calculate the log of spectrogram by FFT.
-        Refer to utils.py in https://github.com/baidu-research/ba-dls-deepspeech
-        """
+        """Compute the linear spectrogram from FFT energy."""
         if max_freq is None:
             max_freq = sample_rate / 2
         if max_freq > sample_rate / 2:
@@ -62,9 +84,7 @@
         return np.log(specgram[:ind, :] + eps)
 
     def _specgram_real(self, samples, window_size, stride_size, sample_rate):
-        """Compute the spectrogram by FFT for a discrete real signal.
-        Refer to utils.py in https://github.com/baidu-research/ba-dls-deepspeech
-        """
+        """Compute the spectrogram for samples from a real signal."""
         # extract strided windows
         truncate_size = (len(samples) - window_size) % stride_size
         samples = samples[:len(samples) - truncate_size]
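Review note: the stride/window parameters translate to sample counts inside _compute_linear_specgram. A sketch of the arithmetic only, assuming the usual `0.001 * sample_rate * ms` conversion retained from ba-dls-deepspeech:

```python
sample_rate = 16000
stride_ms, window_ms = 10.0, 20.0

stride_size = int(0.001 * sample_rate * stride_ms)   # 160 samples per hop
window_size = int(0.001 * sample_rate * window_ms)   # 320 samples per frame

# A 10-second clip then yields one frame per hop:
num_samples = 10 * sample_rate
num_frames = (num_samples - window_size) // stride_size + 1  # 999 frames
```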
diff --git a/data_utils/featurizer/speech_featurizer.py b/data_utils/featurizer/speech_featurizer.py
index 06af7a026de4a5c991b2470603e7c465cbd87cb5..7702045597fb8379bffee2c31029ace4b2453f92 100755
--- a/data_utils/featurizer/speech_featurizer.py
+++ b/data_utils/featurizer/speech_featurizer.py
@@ -1,3 +1,4 @@
+"""Contains the speech featurizer class."""
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
@@ -7,26 +8,70 @@ from data_utils.featurizer.text_featurizer import TextFeaturizer
 
 
 class SpeechFeaturizer(object):
+    """Speech featurizer, for extracting features from both audio and transcript
+    contents of SpeechSegment.
+
+    Currently, for audio parts, it only supports feature type of linear
+    spectrogram; for transcript parts, it only supports char-level tokenizing
+    and conversion into a list of token indices. Note that the token indexing
+    order follows the given vocabulary file.
+
+    :param vocab_filepath: Filepath to load vocabulary for token indices
+                           conversion.
+    :type vocab_filepath: basestring
+    :param specgram_type: Specgram feature type. Options: 'linear'.
+    :type specgram_type: str
+    :param stride_ms: Striding size (in milliseconds) for generating frames.
+    :type stride_ms: float
+    :param window_ms: Window size (in milliseconds) for generating frames.
+    :type window_ms: float
+    :param max_freq: Used when specgram_type is 'linear', only FFT bins
+                     corresponding to frequencies between [0, max_freq] are
+                     returned.
+    :type max_freq: None|float
+    """
+
     def __init__(self,
                  vocab_filepath,
                  specgram_type='linear',
                  stride_ms=10.0,
                  window_ms=20.0,
-                 max_freq=None,
-                 random_seed=0):
-        self._audio_featurizer = AudioFeaturizer(
-            specgram_type, stride_ms, window_ms, max_freq, random_seed)
+                 max_freq=None):
+        self._audio_featurizer = AudioFeaturizer(specgram_type, stride_ms,
+                                                 window_ms, max_freq)
         self._text_featurizer = TextFeaturizer(vocab_filepath)
 
     def featurize(self, speech_segment):
+        """Extract features for speech segment.
+
+        1. For audio parts, extract the audio features.
+        2. For transcript parts, convert text string to a list of token indices
+           in char-level.
+
+        :param speech_segment: Speech segment to extract features from.
+        :type speech_segment: SpeechSegment
+        :return: A tuple of 1) spectrogram audio feature in 2darray, 2) list of
+                 char-level token indices.
+        :rtype: tuple
+        """
         audio_feature = self._audio_featurizer.featurize(speech_segment)
-        text_ids = self._text_featurizer.text2ids(speech_segment.transcript)
+        text_ids = self._text_featurizer.featurize(speech_segment.transcript)
         return audio_feature, text_ids
 
     @property
     def vocab_size(self):
+        """Return the vocabulary size.
+
+        :return: Vocabulary size.
+        :rtype: int
+        """
         return self._text_featurizer.vocab_size
 
     @property
     def vocab_list(self):
+        """Return the vocabulary in list.
+
+        :return: Vocabulary in list.
+        :rtype: list
+        """
         return self._text_featurizer.vocab_list
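Review note: combined audio+text featurization then looks like the following sketch (paths are placeholders, and the vocab file is assumed to cover the transcript's characters):

```python
from data_utils.speech import SpeechSegment
from data_utils.featurizer.speech_featurizer import SpeechFeaturizer

featurizer = SpeechFeaturizer(vocab_filepath='datasets/vocab/eng_vocab.txt')
segment = SpeechSegment.from_file('sample.wav', transcript='hello world')
specgram, token_ids = featurizer.featurize(segment)
# specgram: 2-D ndarray [freq_bins x frames]; token_ids: list of ints
```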
diff --git a/data_utils/featurizer/text_featurizer.py b/data_utils/featurizer/text_featurizer.py
index 7e4b69d7b302009a766e2c53f8597e690858aab6..4f9a49b594010f91a64797b9a4b7e9054d4749d5 100755
--- a/data_utils/featurizer/text_featurizer.py
+++ b/data_utils/featurizer/text_featurizer.py
@@ -1,3 +1,4 @@
+"""Contains the text featurizer class."""
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
@@ -6,26 +7,53 @@ import os
 
 
 class TextFeaturizer(object):
+    """Text featurizer, for processing or extracting features from text.
+
+    Currently, it only supports char-level tokenizing and conversion into
+    a list of token indices. Note that the token indexing order follows the
+    given vocabulary file.
+
+    :param vocab_filepath: Filepath to load vocabulary for token indices
+                           conversion.
+    :type vocab_filepath: basestring
+    """
+
     def __init__(self, vocab_filepath):
         self._vocab_dict, self._vocab_list = self._load_vocabulary_from_file(
             vocab_filepath)
 
-    def text2ids(self, text):
+    def featurize(self, text):
+        """Convert text string to a list of char-level token indices. Note
+        that the token indexing order follows the given vocabulary file.
+
+        :param text: Text to process.
+        :type text: basestring
+        :return: List of char-level token indices.
+        :rtype: list
+        """
         tokens = self._char_tokenize(text)
         return [self._vocab_dict[token] for token in tokens]
 
-    def ids2text(self, ids):
-        return ''.join([self._vocab_list[id] for id in ids])
-
     @property
     def vocab_size(self):
+        """Return the vocabulary size.
+
+        :return: Vocabulary size.
+        :rtype: int
+        """
         return len(self._vocab_list)
 
     @property
     def vocab_list(self):
+        """Return the vocabulary in list.
+
+        :return: Vocabulary in list.
+        :rtype: list
+        """
         return self._vocab_list
 
     def _char_tokenize(self, text):
+        """Character tokenizer."""
         return list(text.strip())
 
     def _load_vocabulary_from_file(self, vocab_filepath):
diff --git a/data_utils/normalizer.py b/data_utils/normalizer.py
index 364600af8b10f64bd566213cae5dffb8661b564f..c123d25d20600140b47da1e93655b15c0053dfea 100755
--- a/data_utils/normalizer.py
+++ b/data_utils/normalizer.py
@@ -1,3 +1,4 @@
+"""Contains feature normalizers."""
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
@@ -9,6 +10,28 @@ from data_utils.audio import AudioSegment
 
 
 class FeatureNormalizer(object):
+    """Feature normalizer. Normalize features to be of zero mean and unit
+    stddev.
+
+    If mean_std_filepath is provided (not None), the normalizer will directly
+    initialize from the file. Otherwise, both manifest_path and featurize_func
+    should be given for on-the-fly mean and stddev computing.
+
+    :param mean_std_filepath: File containing the pre-computed mean and stddev.
+    :type mean_std_filepath: None|basestring
+    :param manifest_path: Manifest of instances for computing mean and stddev.
+    :type manifest_path: None|basestring
+    :param featurize_func: Function to extract features. It should be callable
+                           with ``featurize_func(audio_segment)``.
+    :type featurize_func: None|callable
+    :param num_samples: Number of random samples for computing mean and stddev.
+    :type num_samples: int
+    :param random_seed: Random seed for sampling instances.
+    :type random_seed: int
+    :raises ValueError: If both mean_std_filepath and manifest_path
+                        (or both mean_std_filepath and featurize_func) are None.
+    """
+
     def __init__(self,
                  mean_std_filepath,
                  manifest_path=None,
@@ -25,18 +48,33 @@
             self._read_mean_std_from_file(mean_std_filepath)
 
     def apply(self, features, eps=1e-14):
-        """Normalize features to be of zero mean and unit stddev."""
+        """Normalize features to be of zero mean and unit stddev.
+
+        :param features: Input features to be normalized.
+        :type features: ndarray
+        :param eps: Added to stddev to provide numerical stability.
+        :type eps: float
+        :return: Normalized features.
+        :rtype: ndarray
+        """
         return (features - self._mean) / (self._std + eps)
 
     def write_to_file(self, filepath):
+        """Write the mean and stddev to the file.
+
+        :param filepath: File to write mean and stddev.
+        :type filepath: basestring
+        """
         np.savez(filepath, mean=self._mean, std=self._std)
 
     def _read_mean_std_from_file(self, filepath):
+        """Load mean and std from file."""
         npzfile = np.load(filepath)
         self._mean = npzfile["mean"]
         self._std = npzfile["std"]
 
     def _compute_mean_std(self, manifest_path, featurize_func, num_samples):
+        """Compute mean and std from randomly sampled instances."""
         manifest = utils.read_manifest(manifest_path)
         sampled_manifest = self._rng.sample(manifest, num_samples)
         features = []
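Review note: the two construction paths of FeatureNormalizer, as a sketch (filenames and manifest path are placeholders):

```python
from data_utils.featurizer.audio_featurizer import AudioFeaturizer
from data_utils.normalizer import FeatureNormalizer

# Path 1: compute mean/std on the fly from a manifest, then persist.
featurizer = AudioFeaturizer()
normalizer = FeatureNormalizer(
    mean_std_filepath=None,
    manifest_path='datasets/manifest.train',   # placeholder path
    featurize_func=lambda seg: featurizer.featurize(seg),
    num_samples=500)
normalizer.write_to_file('mean_std.npz')

# Path 2: later runs just reload the saved statistics.
normalizer = FeatureNormalizer(mean_std_filepath='mean_std.npz')
```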
diff --git a/data_utils/speech.py b/data_utils/speech.py
new file mode 100755
index 0000000000000000000000000000000000000000..48db595b41b82933f9b5c16cab7d2ee24f9a2ecc
--- /dev/null
+++ b/data_utils/speech.py
@@ -0,0 +1,75 @@
+"""Contains the speech segment class."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from data_utils.audio import AudioSegment
+
+
+class SpeechSegment(AudioSegment):
+    """Speech segment abstraction, a subclass of AudioSegment,
+    with an additional transcript.
+
+    :param samples: Audio samples [num_samples x num_channels].
+    :type samples: ndarray.float32
+    :param sample_rate: Audio sample rate.
+    :type sample_rate: int
+    :param transcript: Transcript text for the speech.
+    :type transcript: basestring
+    :raises TypeError: If the sample data type is not float or int.
+    """
+
+    def __init__(self, samples, sample_rate, transcript):
+        AudioSegment.__init__(self, samples, sample_rate)
+        self._transcript = transcript
+
+    def __eq__(self, other):
+        """Return whether two objects are equal.
+        """
+        if not AudioSegment.__eq__(self, other):
+            return False
+        if self._transcript != other._transcript:
+            return False
+        return True
+
+    def __ne__(self, other):
+        """Return whether two objects are unequal."""
+        return not self.__eq__(other)
+
+    @classmethod
+    def from_file(cls, filepath, transcript):
+        """Create speech segment from audio file and corresponding transcript.
+
+        :param filepath: Filepath or file object to audio file.
+        :type filepath: basestring|file
+        :param transcript: Transcript text for the speech.
+        :type transcript: basestring
+        :return: Speech segment instance.
+        :rtype: SpeechSegment
+        """
+        audio = AudioSegment.from_file(filepath)
+        return cls(audio.samples, audio.sample_rate, transcript)
+
+    @classmethod
+    def from_bytes(cls, bytes, transcript):
+        """Create speech segment from a byte string and corresponding
+        transcript.
+
+        :param bytes: Byte string containing audio samples.
+        :type bytes: str
+        :param transcript: Transcript text for the speech.
+        :type transcript: basestring
+        :return: Speech segment instance.
+        :rtype: SpeechSegment
+        """
+        audio = AudioSegment.from_bytes(bytes)
+        return cls(audio.samples, audio.sample_rate, transcript)
+
+    @property
+    def transcript(self):
+        """Return the transcript text.
+
+        :return: Transcript text for the speech.
+        :rtype: basestring
+        """
+        return self._transcript
diff --git a/data_utils/utils.py b/data_utils/utils.py
index 2a916b54fc78ff8a865da6e740729564cf1cfafb..3f1165718aa0e2a0bf0687b8a613a6447b964ee8 100755
--- a/data_utils/utils.py
+++ b/data_utils/utils.py
@@ -1,3 +1,4 @@
+"""Contains data helper functions."""
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
@@ -6,7 +7,21 @@ import json
 
 
 def read_manifest(manifest_path, max_duration=float('inf'), min_duration=0.0):
-    """Load and parse manifest file."""
+    """Load and parse manifest file.
+
+    Instances with durations outside [min_duration, max_duration] will be
+    filtered out.
+
+    :param manifest_path: Manifest file to load and parse.
+    :type manifest_path: basestring
+    :param max_duration: Maximal duration in seconds for instance filter.
+    :type max_duration: float
+    :param min_duration: Minimal duration in seconds for instance filter.
+    :type min_duration: float
+    :return: Manifest parsing results. List of dict.
+    :rtype: list
+    :raises IOError: If failed to parse the manifest.
+    """
     manifest = []
     for json_line in open(manifest_path):
         try:
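Review note: read_manifest plus the new SpeechSegment make the instance-loading path explicit. A sketch, assuming the manifest's usual "audio_filepath"/"text" keys (only "duration" is visible in this diff) and a placeholder manifest path:

```python
from data_utils import utils
from data_utils.speech import SpeechSegment

manifest = utils.read_manifest('datasets/manifest.dev', max_duration=20.0)
for instance in manifest[:3]:
    segment = SpeechSegment.from_file(instance["audio_filepath"],
                                      instance["text"])
    print(segment.duration, segment.transcript)
```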
diff --git a/datasets/librispeech/librispeech.py b/datasets/librispeech/librispeech.py
index 1ba2a442214136fb694f7002fbe677a10a3f2e51..faf038cc1919e3659e39d2f06b58816f3b72ba12 100644
--- a/datasets/librispeech/librispeech.py
+++ b/datasets/librispeech/librispeech.py
@@ -1,13 +1,14 @@
-"""
-    Download, unpack and create manifest json files for the Librespeech dataset.
+"""Prepare Librispeech ASR datasets.
 
-    A manifest is a json file summarizing filelist in a data set, with each line
-    containing the meta data (i.e. audio filepath, transcription text, audio
-    duration) of each audio file in the data set.
+Download, unpack and create manifest files.
+Manifest file is a json-format file with each line containing the
+meta data (i.e. audio filepath, transcript and audio duration)
+of each audio file in the data set.
 """
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
 
-import paddle.v2 as paddle
-from paddle.v2.dataset.common import md5file
 import distutils.util
 import os
 import wget
@@ -15,6 +16,7 @@ import tarfile
 import argparse
 import soundfile
 import json
+from paddle.v2.dataset.common import md5file
 
 DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset/speech')
diff --git a/decoder.py b/decoder.py
index 7c4b952636f3e94167bbd00880673a8dc5635803..8314885ce609f4e3da6814cc5831f2e1dd2029ff 100755
--- a/decoder.py
+++ b/decoder.py
@@ -1,9 +1,10 @@
-"""
-    CTC-like decoder utilitis.
-"""
+"""Contains various CTC decoders."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
 
-from itertools import groupby
 import numpy as np
+from itertools import groupby
 
 
 def ctc_best_path_decode(probs_seq, vocabulary):
diff --git a/infer.py b/infer.py
index eb31254cebe6c690de5987f39e02dcd60571b74f..f7c99df117985058eede89301b6339bbaf4f46c2 100644
--- a/infer.py
+++ b/infer.py
@@ -1,7 +1,4 @@
-"""
-    Inference for a simplifed version of Baidu DeepSpeech2 model.
-"""
-
+"""Inferer for DeepSpeech2 model."""
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
diff --git a/model.py b/model.py
index 13ff829b9a6b947253a40a1d3ea524de141bd9d1..cb0b4ecbba1a3fb435a5f625a54d6e5bebe689e0 100644
--- a/model.py
+++ b/model.py
@@ -1,11 +1,10 @@
-"""
-    A simplifed version of Baidu DeepSpeech2 model.
-"""
+"""Contains DeepSpeech2 model."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
 
 import paddle.v2 as paddle
 
-#TODO: add bidirectional rnn.
-
 
 def conv_bn_layer(input, filter_size, num_channels_in, num_channels_out,
                   stride, padding, act):
diff --git a/train.py b/train.py
index c6aa97527d1e182824536eabb62096e806ad9cf2..7ac4626f4c094a3774336f50ce57fb1e9f495296 100644
--- a/train.py
+++ b/train.py
@@ -1,7 +1,4 @@
-"""
-    Trainer for a simplifed version of Baidu DeepSpeech2 model.
-"""
-
+"""Trainer for DeepSpeech2 model."""
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
@@ -164,7 +161,7 @@ def train():
                 print("\nPass: %d, Batch: %d, TrainCost: %f" % (
                     event.pass_id, event.batch_id, cost_sum / cost_counter))
                 cost_sum, cost_counter = 0.0, 0
-                with gzip.open("params.tar.gz", 'w') as f:
+                with gzip.open("params_tmp.tar.gz", 'w') as f:
                     parameters.to_tar(f)
             else:
                 sys.stdout.write('.')
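Review note: for completeness, a toy run of the renamed decoder module. A sketch with made-up probabilities, assuming the decoder treats the last index as the CTC blank and collapses repeats, as in the ba-dls-deepspeech reference:

```python
import numpy as np
from decoder import ctc_best_path_decode

vocabulary = ['a', 'b']
# Each row: probabilities over [a, b, blank]; each row sums to 1.
probs_seq = np.array([
    [0.7, 0.2, 0.1],   # 'a'
    [0.7, 0.2, 0.1],   # 'a' (repeat, collapsed)
    [0.1, 0.1, 0.8],   # blank
    [0.2, 0.7, 0.1],   # 'b'
])
print(ctc_best_path_decode(probs_seq, vocabulary))  # expected: 'ab'
```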