diff --git a/deep_speech_2/audio_data_utils.py b/deep_speech_2/audio_data_utils.py
index a3a397e9453c95884aceaec6904c472a387164d8..7d09d612a0b09d49253f3e026b686f9009c2dc58 100644
--- a/deep_speech_2/audio_data_utils.py
+++ b/deep_speech_2/audio_data_utils.py
@@ -1,5 +1,6 @@
 """
-    Audio data preprocessing tools and reader creators.
+    Provides a basic audio data preprocessing pipeline, and offers both
+    instance-level and batch-level data reader interfaces.
 """
 import paddle.v2 as paddle
 import logging
@@ -9,143 +10,201 @@ import soundfile
 import numpy as np
 import os
 
-# TODO: add z-score normalization.
-
-ENGLISH_CHAR_VOCAB_FILEPATH = "eng_vocab.txt"
-
+RANDOM_SEED = 0
 logger = logging.getLogger(__name__)
 
 
-def spectrogram_from_file(filename,
-                          stride_ms=10,
-                          window_ms=20,
-                          max_freq=None,
-                          eps=1e-14):
-    """
-    Calculate the log of linear spectrogram from FFT energy
-    Refer to utils.py in https://github.com/baidu-research/ba-dls-deepspeech
-    """
-    audio, sample_rate = soundfile.read(filename)
-    if audio.ndim >= 2:
-        audio = np.mean(audio, 1)
-    if max_freq is None:
-        max_freq = sample_rate / 2
-    if max_freq > sample_rate / 2:
-        raise ValueError("max_freq must be greater than half of "
-                         "sample rate.")
-    if stride_ms > window_ms:
-        raise ValueError("Stride size must not be greater than window size.")
-    stride_size = int(0.001 * sample_rate * stride_ms)
-    window_size = int(0.001 * sample_rate * window_ms)
-    spectrogram, freqs = extract_spectrogram(
-        audio,
-        window_size=window_size,
-        stride_size=stride_size,
-        sample_rate=sample_rate)
-    ind = np.where(freqs <= max_freq)[0][-1] + 1
-    return np.log(spectrogram[:ind, :] + eps)
-
-
-def extract_spectrogram(samples, window_size, stride_size, sample_rate):
+class DataGenerator(object):
     """
-    Compute the spectrogram for a real discrete signal.
-    Refer to utils.py in https://github.com/baidu-research/ba-dls-deepspeech
-    """
-    # extract strided windows
-    truncate_size = (len(samples) - window_size) % stride_size
-    samples = samples[:len(samples) - truncate_size]
-    nshape = (window_size, (len(samples) - window_size) // stride_size + 1)
-    nstrides = (samples.strides[0], samples.strides[0] * stride_size)
-    windows = np.lib.stride_tricks.as_strided(
-        samples, shape=nshape, strides=nstrides)
-    assert np.all(
-        windows[:, 1] == samples[stride_size:(stride_size + window_size)])
-    # window weighting, compute squared Fast Fourier Transform (fft), scaling
-    weighting = np.hanning(window_size)[:, None]
-    fft = np.fft.rfft(windows * weighting, axis=0)
-    fft = np.absolute(fft)**2
-    scale = np.sum(weighting**2) * sample_rate
-    fft[1:-1, :] *= (2.0 / scale)
-    fft[(0, -1), :] /= scale
-    # prepare fft frequency list
-    freqs = float(sample_rate) / window_size * np.arange(fft.shape[0])
-    return fft, freqs
-
-
-def vocabulary_from_file(vocabulary_path):
-    """
-    Load vocabulary from file.
+    DataGenerator provides a basic audio data preprocessing pipeline, and
+    offers both instance-level and batch-level data reader interfaces.
+    Normalized FFT features are used as the audio features here.
+
+    :param vocab_filepath: Vocabulary file path for indexing tokenized
+                           transcriptions.
+    :type vocab_filepath: basestring
+    :param normalizer_manifest_path: Manifest filepath for collecting feature
+                                     normalization statistics, e.g. mean, std.
+    :type normalizer_manifest_path: basestring
+    :param normalizer_num_samples: Number of instances sampled for collecting
+                                   feature normalization statistics.
+                                   Default is 100.
+    :type normalizer_num_samples: int
+    :param max_duration: Audio clips with duration (in seconds) greater than
+                         this will be discarded. Default is 20.0.
+    :type max_duration: float
+    :param min_duration: Audio clips with duration (in seconds) smaller than
+                         this will be discarded. Default is 0.0.
+    :type min_duration: float
+    :param stride_ms: Striding size (in milliseconds) for generating frames.
+                      Default is 10.0.
+    :type stride_ms: float
+    :param window_ms: Window size (in milliseconds) for frames. Default is 20.0.
+    :type window_ms: float
+    :param max_frequency: Maximum frequency for FFT features. FFT features of
+                          frequency larger than this will be discarded.
+                          If set None, all features will be kept.
+                          Default is None.
+    :type max_frequency: float
     """
-    if os.path.exists(vocabulary_path):
-        vocab_lines = []
-        with open(vocabulary_path, 'r') as file:
-            vocab_lines.extend(file.readlines())
-        vocab_list = [line[:-1] for line in vocab_lines]
-        vocab_dict = dict(
-            [(token, id) for (id, token) in enumerate(vocab_list)])
-        return vocab_dict, vocab_list
-    else:
-        raise ValueError("Vocabulary file %s not found.", vocabulary_path)
 
+    def __init__(self,
+                 vocab_filepath,
+                 normalizer_manifest_path,
+                 normalizer_num_samples=100,
+                 max_duration=20.0,
+                 min_duration=0.0,
+                 stride_ms=10.0,
+                 window_ms=20.0,
+                 max_frequency=None):
+        self.__max_duration__ = max_duration
+        self.__min_duration__ = min_duration
+        self.__stride_ms__ = stride_ms
+        self.__window_ms__ = window_ms
+        self.__max_frequency__ = max_frequency
+        self.__random__ = random.Random(RANDOM_SEED)
+        # load vocabulary (dictionary)
+        self.__vocab_dict__, self.__vocab_list__ = \
+            self.__load_vocabulary_from_file__(vocab_filepath)
+        # collect normalizer statistics
+        self.__mean__, self.__std__ = self.__collect_normalizer_statistics__(
+            manifest_path=normalizer_manifest_path,
+            num_samples=normalizer_num_samples)
 
-def get_vocabulary_size():
-    """
-    Get vocabulary size.
-    """
-    vocab_dict, _ = vocabulary_from_file(ENGLISH_CHAR_VOCAB_FILEPATH)
-    return len(vocab_dict)
+    def __audio_featurize__(self, audio_filename):
+        """
+        Preprocess audio data, including feature extraction, normalization, etc.
+        """
+        features = self.__audio_basic_featurize__(audio_filename)
+        return self.__normalize__(features)
 
+    def __text_featurize__(self, text):
+        """
+        Preprocess text data, including tokenizing and token indexing, etc.
+        """
+        return self.__convert_text_to_char_index__(
+            text=text, vocabulary=self.__vocab_dict__)
 
-def get_vocabulary():
-    """
-    Get vocabulary.
-    """
-    return vocabulary_from_file(ENGLISH_CHAR_VOCAB_FILEPATH)
+    def __audio_basic_featurize__(self, audio_filename):
+        """
+        Compute basic (without normalization etc.) features for audio data.
+        """
+        return self.__spectrogram_from_file__(
+            filename=audio_filename,
+            stride_ms=self.__stride_ms__,
+            window_ms=self.__window_ms__,
+            max_freq=self.__max_frequency__)
 
+    def __collect_normalizer_statistics__(self, manifest_path, num_samples=100):
+        """
+        Compute feature normalization statistics, i.e. mean and stddev.
+        """
+        # read manifest
+        manifest = self.__read_manifest__(
+            manifest_path=manifest_path,
+            max_duration=self.__max_duration__,
+            min_duration=self.__min_duration__)
+        # sample for statistics
+        sampled_manifest = self.__random__.sample(manifest, num_samples)
+        # extract spectrogram feature
+        features = []
+        for instance in sampled_manifest:
+            spectrogram = self.__audio_basic_featurize__(
+                instance["audio_filepath"])
+            features.append(spectrogram)
+        features = np.hstack(features)
+        mean = np.mean(features, axis=1).reshape([-1, 1])
+        std = np.std(features, axis=1).reshape([-1, 1])
+        return mean, std
 
-def parse_transcript(text, vocabulary):
-    """
-    Convert the transcript text string to list of token index integers.
-    """
-    return [vocabulary[w] for w in text]
+    def __normalize__(self, features, eps=1e-14):
+        """
+        Normalize features to be of zero mean and unit stddev.
+        """
+        return (features - self.__mean__) / (self.__std__ + eps)
 
+    def __spectrogram_from_file__(self,
+                                  filename,
+                                  stride_ms=10.0,
+                                  window_ms=20.0,
+                                  max_freq=None,
+                                  eps=1e-14):
+        """
+        Load audio data and calculate the log of spectrogram by FFT.
+        Refer to utils.py in https://github.com/baidu-research/ba-dls-deepspeech
+        """
+        audio, sample_rate = soundfile.read(filename)
+        if audio.ndim >= 2:
+            audio = np.mean(audio, 1)
+        if max_freq is None:
+            max_freq = sample_rate / 2
+        if max_freq > sample_rate / 2:
+            raise ValueError("max_freq must not be greater than half of "
+                             "sample rate.")
+        if stride_ms > window_ms:
+            raise ValueError("Stride size must not be greater than "
+                             "window size.")
+        stride_size = int(0.001 * sample_rate * stride_ms)
+        window_size = int(0.001 * sample_rate * window_ms)
+        spectrogram, freqs = self.__extract_spectrogram__(
+            audio,
+            window_size=window_size,
+            stride_size=stride_size,
+            sample_rate=sample_rate)
+        ind = np.where(freqs <= max_freq)[0][-1] + 1
+        return np.log(spectrogram[:ind, :] + eps)
 
-def reader_creator(manifest_path,
-                   sort_by_duration=True,
-                   shuffle=False,
-                   max_duration=10.0,
-                   min_duration=0.0):
-    """
-    Audio data reader creator.
-
-    Instance: a tuple of a numpy ndarray of audio spectrogram and a list of
-    tokenized transcription text.
-
-    :param manifest_path: Filepath for Manifest of audio clip files.
-    :type manifest_path: basestring
-    :param sort_by_duration: Sort the audio clips by duration if set True.
-                             For SortaGrad.
-    :type sort_by_duration: bool
-    :param shuffle: Shuffle the audio clips if set True.
-    :type shuffle: bool
-    :param max_duration: Audio clips with duration (in seconds) greater than
-                         this will be discarded.
-    :type max_duration: float
-    :param min_duration: Audio clips with duration (in seconds) smaller than
-                         this will be discarded.
-    :type min_duration: float
-    :return: Data reader function.
-    :rtype: callable
-    """
-    if sort_by_duration and shuffle:
-        sort_by_duration = False
-        logger.warn("When shuffle set to true, "
-                    "sort_by_duration is forced to set False.")
-    vocab_dict, _ = vocabulary_from_file(ENGLISH_CHAR_VOCAB_FILEPATH)
+    def __extract_spectrogram__(self, samples, window_size, stride_size,
+                                sample_rate):
+        """
+        Compute the spectrogram by FFT for a discrete real signal.
+        Refer to utils.py in https://github.com/baidu-research/ba-dls-deepspeech
+        """
+        # extract strided windows
+        truncate_size = (len(samples) - window_size) % stride_size
+        samples = samples[:len(samples) - truncate_size]
+        nshape = (window_size, (len(samples) - window_size) // stride_size + 1)
+        nstrides = (samples.strides[0], samples.strides[0] * stride_size)
+        windows = np.lib.stride_tricks.as_strided(
+            samples, shape=nshape, strides=nstrides)
+        assert np.all(
+            windows[:, 1] == samples[stride_size:(stride_size + window_size)])
+        # window weighting, squared Fast Fourier Transform (fft), scaling
+        weighting = np.hanning(window_size)[:, None]
+        fft = np.fft.rfft(windows * weighting, axis=0)
+        fft = np.absolute(fft)**2
+        scale = np.sum(weighting**2) * sample_rate
+        fft[1:-1, :] *= (2.0 / scale)
+        fft[(0, -1), :] /= scale
+        # prepare fft frequency list
+        freqs = float(sample_rate) / window_size * np.arange(fft.shape[0])
+        return fft, freqs
 
-    def reader():
-        # read manifest
-        manifest_data = []
+    def __load_vocabulary_from_file__(self, vocabulary_path):
+        """
+        Load vocabulary from file.
+        """
+        if not os.path.exists(vocabulary_path):
+            raise ValueError("Vocabulary file %s not found.", vocabulary_path)
+        vocab_lines = []
+        with open(vocabulary_path, 'r') as file:
+            vocab_lines.extend(file.readlines())
+        vocab_list = [line[:-1] for line in vocab_lines]
+        vocab_dict = dict(
+            [(token, id) for (id, token) in enumerate(vocab_list)])
+        return vocab_dict, vocab_list
+
+    def __convert_text_to_char_index__(self, text, vocabulary):
+        """
+        Convert text string to a list of character index integers.
+        """
+        return [vocabulary[w] for w in text]
+
+    def __read_manifest__(self, manifest_path, max_duration, min_duration):
+        """
+        Load and parse manifest file.
+        """
+        manifest = []
         for json_line in open(manifest_path):
             try:
                 json_data = json.loads(json_line)
@@ -153,63 +212,172 @@ def reader_creator(manifest_path,
                 raise ValueError("Error reading manifest: %s" % str(e))
             if (json_data["duration"] <= max_duration and
                     json_data["duration"] >= min_duration):
-                manifest_data.append(json_data)
-        # sort (by duration) or shuffle manifest
-        if sort_by_duration:
-            manifest_data.sort(key=lambda x: x["duration"])
-        if shuffle:
-            random.shuffle(manifest_data)
-        # extract spectrogram feature
-        for instance in manifest_data:
-            spectrogram = spectrogram_from_file(instance["audio_filepath"])
-            text = parse_transcript(instance["text"], vocab_dict)
-            yield (spectrogram, text)
+                manifest.append(json_data)
+        return manifest
 
-    return reader
+    def __padding_batch__(self, batch, padding_to=-1, flatten=False):
+        """
+        Padding audio part of features (only in the time axis -- column axis)
+        with zeros, to make each instance in the batch share the same
+        audio feature shape.
+
+        If `padding_to` is set -1, the maximum number of columns in the batch
+        will be used as the target size. Otherwise, `padding_to` will be the
+        target size. Default is -1.
 
-def padding_batch_reader(batch_reader, padding=[-1, -1], flatten=True):
-    """
-    Padding for batches. Return a batch reader.
-
-    Each instance in a batch will be padded to be of a same target shape.
-    The target shape is the largest shape among all the batch instances and
-    'padding' argument.
-    Therefore, if padding is set [-1, -1], instance will be
-    padded to have the same shape just within each batch and the shape will
-    be different across batches; if padding is set
-    [VERY_LARGE_NUM, VERY_LARGE_NUM], instances in all batches will be padded to
-    have the same shape of [VERY_LARGE_NUM, VERY_LARGE_NUM].
-
-    :param batch_reader: Input batch reader.
-    :type batch_reader: callable
-    :param padding: Padding pattern. Details please refer to the above.
-    :type padding: list
-    :param flatten: Flatten the tensor to be one dimension.
-    :type flatten: bool
-    :return: Batch reader function.
-    :rtype: callable
-    """
-
-    def padding_batch(batch):
+        If `flatten` is set True, audio data will be flattened into a 1-dim
+        ndarray. Default is False.
+        """
         new_batch = []
-        # get target shape within batch
-        nshape_list = [padding]
-        for audio, text in batch:
-            nshape_list.append(audio.shape)
-        target_shape = np.array(nshape_list).max(axis=0)
+        # get target shape
+        max_length = max([audio.shape[1] for audio, text in batch])
+        if padding_to != -1:
+            if padding_to < max_length:
+                raise ValueError("If padding_to is not -1, it should be greater"
+                                 " than or equal to the longest instance length"
+                                 " in the batch.")
+            max_length = padding_to
         # padding
         for audio, text in batch:
-            pad_shape = target_shape - audio.shape
-            assert np.all(pad_shape >= 0)
-            padded_audio = np.pad(
-                audio, [(0, pad_shape[0]), (0, pad_shape[1])], mode="constant")
+            padded_audio = np.zeros([audio.shape[0], max_length])
+            padded_audio[:, :audio.shape[1]] = audio
             if flatten:
                 padded_audio = padded_audio.flatten()
             new_batch.append((padded_audio, text))
         return new_batch
 
-    def new_batch_reader():
-        for batch in batch_reader():
-            yield padding_batch(batch)
+    def instance_reader_creator(self,
+                                manifest_path,
+                                sort_by_duration=True,
+                                shuffle=False):
+        """
+        Instance reader creator for audio data. Create a callable function to
+        produce instances of data.
+
+        Instance: a tuple of a numpy ndarray of audio spectrogram and a list of
+        tokenized and indexed transcription text.
+
+        :param manifest_path: Filepath of manifest for audio clip files.
+        :type manifest_path: basestring
+        :param sort_by_duration: Sort the audio clips by duration if set True
+                                 (for SortaGrad).
+        :type sort_by_duration: bool
+        :param shuffle: Shuffle the audio clips if set True.
+        :type shuffle: bool
+        :return: Data reader function.
+        :rtype: callable
+        """
+        if sort_by_duration and shuffle:
+            sort_by_duration = False
+            logger.warn("When shuffle is set to true, "
+                        "sort_by_duration is forced to False.")
+
+        def reader():
+            # read manifest
+            manifest = self.__read_manifest__(
+                manifest_path=manifest_path,
+                max_duration=self.__max_duration__,
+                min_duration=self.__min_duration__)
+            # sort (by duration) or shuffle manifest
+            if sort_by_duration:
+                manifest.sort(key=lambda x: x["duration"])
+            if shuffle:
+                self.__random__.shuffle(manifest)
+            # extract spectrogram feature
+            for instance in manifest:
+                spectrogram = self.__audio_featurize__(
+                    instance["audio_filepath"])
+                transcript = self.__text_featurize__(instance["text"])
+                yield (spectrogram, transcript)
+
+        return reader
+
+    def batch_reader_creator(self,
+                             manifest_path,
+                             batch_size,
+                             padding_to=-1,
+                             flatten=False,
+                             sort_by_duration=True,
+                             shuffle=False):
+        """
+        Batch data reader creator for audio data. Create a callable function to
+        produce batches of data.
+
+        Audio features will be padded with zeros to make each instance in the
+        batch share the same audio feature shape.
+
+        :param manifest_path: Filepath of manifest for audio clip files.
+        :type manifest_path: basestring
+        :param batch_size: Instance number in a batch.
+        :type batch_size: int
+        :param padding_to: If set -1, the maximum number of columns in the
+                           batch will be used as the target size for padding.
+                           Otherwise, `padding_to` will be the target size.
+                           Default is -1.
+        :type padding_to: int
+        :param flatten: If set True, audio data will be flattened into a 1-dim
+                        ndarray. Otherwise, a 2-dim ndarray. Default is False.
+        :type flatten: bool
+        :param sort_by_duration: Sort the audio clips by duration if set True
+                                 (for SortaGrad).
+        :type sort_by_duration: bool
+        :param shuffle: Shuffle the audio clips if set True.
+        :type shuffle: bool
+        :return: Batch reader function, producing batches of data when called.
+        :rtype: callable
+        """
+
+        def batch_reader():
+            instance_reader = self.instance_reader_creator(
+                manifest_path=manifest_path,
+                sort_by_duration=sort_by_duration,
+                shuffle=shuffle)
+            batch = []
+            for instance in instance_reader():
+                batch.append(instance)
+                if len(batch) == batch_size:
+                    yield self.__padding_batch__(batch, padding_to, flatten)
+                    batch = []
+            if len(batch) > 0:
+                yield self.__padding_batch__(batch, padding_to, flatten)
+
+        return batch_reader
+
+    def vocabulary_size(self):
+        """
+        Get vocabulary size.
+
+        :return: Vocabulary size.
+        :rtype: int
+        """
+        return len(self.__vocab_list__)
+
+    def vocabulary_dict(self):
+        """
+        Get vocabulary in dict.
+
+        :return: Vocabulary in dict.
+        :rtype: dict
+        """
+        return self.__vocab_dict__
+
+    def vocabulary_list(self):
+        """
+        Get vocabulary in list.
+
+        :return: Vocabulary in list.
+        :rtype: list
+        """
+        return self.__vocab_list__
+
+    def data_name_feeding(self):
+        """
+        Get feeding (data field name and corresponding field id).
 
-    return new_batch_reader
+        :return: Feeding dict.
+        :rtype: dict
+        """
+        feeding = {
+            "audio_spectrogram": 0,
+            "transcript_text": 1,
+        }
+        return feeding
diff --git a/deep_speech_2/train.py b/deep_speech_2/train.py
index 0d7dd81645c1f2dbe97a035026f94c53a4a0c080..89dcf35c9080ab98b3ba9ba6a2b5621005a2f4a2 100644
--- a/deep_speech_2/train.py
+++ b/deep_speech_2/train.py
@@ -5,16 +5,18 @@
 import paddle.v2 as paddle
 import argparse
 import gzip
+import time
 import sys
 from model import deep_speech2
-import audio_data_utils
+from audio_data_utils import DataGenerator
+import numpy as np
 
 #TODO: add WER metric
 
 parser = argparse.ArgumentParser(
     description='Simplified version of DeepSpeech2 trainer.')
 parser.add_argument(
-    "--batch_size", default=512, type=int, help="Minibatch size.")
+    "--batch_size", default=32, type=int, help="Minibatch size.")
 parser.add_argument("--trainer", default=1, type=int, help="Trainer number.")
 parser.add_argument(
     "--num_passes", default=20, type=int, help="Training pass number.")
@@ -23,7 +25,7 @@ parser.add_argument(
 parser.add_argument(
     "--num_rnn_layers", default=5, type=int, help="RNN layer number.")
 parser.add_argument(
-    "--rnn_layer_size", default=256, type=int, help="RNN layer cell number.")
+    "--rnn_layer_size", default=512, type=int, help="RNN layer cell number.")
 parser.add_argument(
     "--use_gpu", default=True, type=bool, help="Use gpu or not.")
 parser.add_argument(
@@ -37,13 +39,45 @@ def train():
     """
     DeepSpeech2 training.
     """
+    # create data readers
+    data_generator = DataGenerator(
+        vocab_filepath='eng_vocab.txt',
+        normalizer_manifest_path='./libri.manifest.train',
+        normalizer_num_samples=200,
+        max_duration=20.0,
+        min_duration=0.0,
+        stride_ms=10,
+        window_ms=20)
+    train_batch_reader_sortagrad = data_generator.batch_reader_creator(
+        manifest_path='./libri.manifest.dev.small',
+        batch_size=args.batch_size // args.trainer,
+        padding_to=2000,
+        flatten=True,
+        sort_by_duration=True,
+        shuffle=False)
+    train_batch_reader_nosortagrad = data_generator.batch_reader_creator(
+        manifest_path='./libri.manifest.dev.small',
+        batch_size=args.batch_size // args.trainer,
+        padding_to=2000,
+        flatten=True,
+        sort_by_duration=False,
+        shuffle=True)
+    test_batch_reader = data_generator.batch_reader_creator(
+        manifest_path='./libri.manifest.test',
+        batch_size=args.batch_size // args.trainer,
+        padding_to=2000,
+        flatten=True,
+        sort_by_duration=False,
+        shuffle=False)
+    feeding = data_generator.data_name_feeding()
+
     # create network config
-    dict_size = audio_data_utils.get_vocabulary_size()
+    dict_size = data_generator.vocabulary_size()
     audio_data = paddle.layer.data(
         name="audio_spectrogram",
         height=161,
-        width=1000,
-        type=paddle.data_type.dense_vector(161000))
+        width=2000,
+        type=paddle.data_type.dense_vector(322000))
     text_data = paddle.layer.data(
         name="transcript_text",
         type=paddle.data_type.integer_value_sequence(dict_size))
@@ -58,47 +92,26 @@ def train():
     # create parameters and optimizer
     parameters = paddle.parameters.create(cost)
     optimizer = paddle.optimizer.Adam(
-        learning_rate=5e-4, gradient_clipping_threshold=400)
+        learning_rate=5e-5, gradient_clipping_threshold=400)
     trainer = paddle.trainer.SGD(
         cost=cost, parameters=parameters, update_equation=optimizer)
-    # create data readers
-    feeding = {
-        "audio_spectrogram": 0,
-        "transcript_text": 1,
-    }
-    train_batch_reader_with_sortagrad = audio_data_utils.padding_batch_reader(
-        paddle.batch(
-            audio_data_utils.reader_creator(
-                manifest_path="./libri.manifest.train", sort_by_duration=True),
-            batch_size=args.batch_size // args.trainer),
-        padding=[-1, 1000])
-    train_batch_reader_without_sortagrad = audio_data_utils.padding_batch_reader(
-        paddle.batch(
-            audio_data_utils.reader_creator(
-                manifest_path="./libri.manifest.train",
-                sort_by_duration=False,
-                shuffle=True),
-            batch_size=args.batch_size // args.trainer),
-        padding=[-1, 1000])
-    test_batch_reader = audio_data_utils.padding_batch_reader(
-        paddle.batch(
-            audio_data_utils.reader_creator(
-                manifest_path="./libri.manifest.dev", sort_by_duration=False),
-            batch_size=args.batch_size // args.trainer),
-        padding=[-1, 1000])
 
     # create event handler
     def event_handler(event):
+        global start_time
         if isinstance(event, paddle.event.EndIteration):
            if event.batch_id % 10 == 0:
-                print "/nPass: %d, Batch: %d, TrainCost: %f" % (
+                print "\nPass: %d, Batch: %d, TrainCost: %f" % (
                     event.pass_id, event.batch_id, event.cost)
            else:
                 sys.stdout.write('.')
                 sys.stdout.flush()
+        if isinstance(event, paddle.event.BeginPass):
+            start_time = time.time()
         if isinstance(event, paddle.event.EndPass):
             result = trainer.test(reader=test_batch_reader, feeding=feeding)
-            print "Pass: %d, TestCost: %s" % (event.pass_id, result.cost)
+            print "\n------- Time: %d, Pass: %d, TestCost: %s" % (
+                time.time() - start_time, event.pass_id, result.cost)
             with gzip.open("params.tar.gz", 'w') as f:
                 parameters.to_tar(f)
 
@@ -106,14 +119,14 @@ def train():
     # first pass with sortagrad
     if args.use_sortagrad:
         trainer.train(
-            reader=train_batch_reader_with_sortagrad,
+            reader=train_batch_reader_sortagrad,
             event_handler=event_handler,
             num_passes=1,
             feeding=feeding)
         args.num_passes -= 1
     # other passes without sortagrad
     trainer.train(
-        reader=train_batch_reader_without_sortagrad,
+        reader=train_batch_reader_nosortagrad,
         event_handler=event_handler,
         num_passes=args.num_passes,
         feeding=feeding)
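
Note (not part of the patch): a minimal usage sketch of the DataGenerator interface added above. The vocabulary and manifest paths are the ones already referenced in the patch; the batch size and padding width below are purely illustrative.

from audio_data_utils import DataGenerator

# Build the generator; feature-normalization statistics (mean/std) are
# collected once from samples drawn out of the normalizer manifest.
data_generator = DataGenerator(
    vocab_filepath='eng_vocab.txt',
    normalizer_manifest_path='./libri.manifest.train',
    normalizer_num_samples=200)

# Batch reader: every spectrogram in a batch is zero-padded to the same
# width (2000 frames here) and flattened for a dense_vector data layer.
batch_reader = data_generator.batch_reader_creator(
    manifest_path='./libri.manifest.train',
    batch_size=4,  # illustrative; train.py uses args.batch_size // args.trainer
    padding_to=2000,
    flatten=True,
    sort_by_duration=True,
    shuffle=False)

for batch in batch_reader():
    # each instance is a (flattened spectrogram, token-id list) tuple
    audio, transcript = batch[0]
    print(audio.shape)
    print(len(transcript))
    break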