From 7c85e0fdb5ffac76df6f3d99519e344be7c9b5dd Mon Sep 17 00:00:00 2001 From: dangqingqing Date: Wed, 7 Jun 2017 16:37:13 +0800 Subject: [PATCH] Support variable input batch and sortagrad. --- audio_data_utils.py | 56 +++++++++++++++++++++++++++++------------ train.py | 61 ++++++++++++++++----------------------------- 2 files changed, 62 insertions(+), 55 deletions(-) diff --git a/audio_data_utils.py b/audio_data_utils.py index c717bcf1..abb7f1e9 100644 --- a/audio_data_utils.py +++ b/audio_data_utils.py @@ -8,6 +8,7 @@ import json import random import soundfile import numpy as np +import itertools import os RANDOM_SEED = 0 @@ -62,6 +63,7 @@ class DataGenerator(object): self.__stride_ms__ = stride_ms self.__window_ms__ = window_ms self.__max_frequency__ = max_frequency + self.__epoc__ = 0 self.__random__ = random.Random(RANDOM_SEED) # load vocabulary (dictionary) self.__vocab_dict__, self.__vocab_list__ = \ @@ -245,9 +247,33 @@ class DataGenerator(object): new_batch.append((padded_audio, text)) return new_batch + def __batch_shuffle__(self, manifest, batch_size): + """ + 1. Sort the audio clips by duration. + 2. Generate a random number `k`, k in [0, batch_size). + 3. Randomly remove `k` instances in order to make different mini-batches, + then make minibatches and each minibatch size is batch_size. + 4. Shuffle the minibatches. + + :param manifest: manifest file. + :type manifest: list + :param batch_size: batch size. + :type batch_size: int + """ + manifest.sort(key=lambda x: x["duration"]) + shift_len = self.__random__.randint(0, batch_size - 1) + batch_manifest = zip(*[iter(manifest[shift_len:])] * batch_size) + self.__random__.shuffle(batch_manifest) + batch_manifest = list(sum(batch_manifest, ())) + res_len = len(manifest) - shift_len - len(batch_manifest) + batch_manifest.extend(manifest[-res_len:]) + batch_manifest.extend(manifest[0:shift_len]) + return batch_manifest + def instance_reader_creator(self, manifest_path, - sort_by_duration=True, + batch_size, + sortagrad=True, shuffle=False): """ Instance reader creator for audio data. Creat a callable function to @@ -258,18 +284,14 @@ class DataGenerator(object): :param manifest_path: Filepath of manifest for audio clip files. :type manifest_path: basestring - :param sort_by_duration: Sort the audio clips by duration if set True - (for SortaGrad). - :type sort_by_duration: bool + :param sortagrad: Sort the audio clips by duration in the first epoc + if set True. + :type sortagrad: bool :param shuffle: Shuffle the audio clips if set True. :type shuffle: bool :return: Data reader function. :rtype: callable """ - if sort_by_duration and shuffle: - sort_by_duration = False - logger.warn("When shuffle set to true, " - "sort_by_duration is forced to set False.") def reader(): # read manifest @@ -278,16 +300,17 @@ class DataGenerator(object): max_duration=self.__max_duration__, min_duration=self.__min_duration__) # sort (by duration) or shuffle manifest - if sort_by_duration: + if self.__epoc__ == 0 and sortagrad: manifest.sort(key=lambda x: x["duration"]) - if shuffle: - self.__random__.shuffle(manifest) + elif shuffle: + manifest = self.__batch_shuffle__(manifest, batch_size) # extract spectrogram feature for instance in manifest: spectrogram = self.__audio_featurize__( instance["audio_filepath"]) transcript = self.__text_featurize__(instance["text"]) yield (spectrogram, transcript) + self.__epoc__ += 1 return reader @@ -296,7 +319,7 @@ class DataGenerator(object): batch_size, padding_to=-1, flatten=False, - sort_by_duration=True, + sortagrad=False, shuffle=False): """ Batch data reader creator for audio data. Creat a callable function to @@ -317,9 +340,9 @@ class DataGenerator(object): :param flatten: If set True, audio data will be flatten to be a 1-dim ndarray. Otherwise, 2-dim ndarray. Default is False. :type flatten: bool - :param sort_by_duration: Sort the audio clips by duration if set True - (for SortaGrad). - :type sort_by_duration: bool + :param sortagrad: Sort the audio clips by duration in the first epoc + if set True. + :type sortagrad: bool :param shuffle: Shuffle the audio clips if set True. :type shuffle: bool :return: Batch reader function, producing batches of data when called. @@ -329,7 +352,8 @@ class DataGenerator(object): def batch_reader(): instance_reader = self.instance_reader_creator( manifest_path=manifest_path, - sort_by_duration=sort_by_duration, + batch_size=batch_size, + sortagrad=sortagrad, shuffle=shuffle) batch = [] for instance in instance_reader(): diff --git a/train.py b/train.py index e6a7d076..55577b0d 100644 --- a/train.py +++ b/train.py @@ -85,23 +85,27 @@ def train(): """ DeepSpeech2 training. """ + # initialize data generator - data_generator = DataGenerator( - vocab_filepath=args.vocab_filepath, - normalizer_manifest_path=args.normalizer_manifest_path, - normalizer_num_samples=200, - max_duration=20.0, - min_duration=0.0, - stride_ms=10, - window_ms=20) + def data_generator(): + return DataGenerator( + vocab_filepath=args.vocab_filepath, + normalizer_manifest_path=args.normalizer_manifest_path, + normalizer_num_samples=200, + max_duration=20.0, + min_duration=0.0, + stride_ms=10, + window_ms=20) + train_generator = data_generator() + test_generator = data_generator() # create network config - dict_size = data_generator.vocabulary_size() + dict_size = train_generator.vocabulary_size() + # paddle.data_type.dense_array is used for variable batch input. + # the size 161 * 161 is only an placeholder value and the real shape + # of input batch data will be set at each batch. audio_data = paddle.layer.data( - name="audio_spectrogram", - height=161, - width=2000, - type=paddle.data_type.dense_vector(322000)) + name="audio_spectrogram", type=paddle.data_type.dense_array(161 * 161)) text_data = paddle.layer.data( name="transcript_text", type=paddle.data_type.integer_value_sequence(dict_size)) @@ -122,28 +126,16 @@ def train(): cost=cost, parameters=parameters, update_equation=optimizer) # prepare data reader - train_batch_reader_sortagrad = data_generator.batch_reader_creator( - manifest_path=args.train_manifest_path, - batch_size=args.batch_size, - padding_to=2000, - flatten=True, - sort_by_duration=True, - shuffle=False) - train_batch_reader_nosortagrad = data_generator.batch_reader_creator( + train_batch_reader = train_generator.batch_reader_creator( manifest_path=args.train_manifest_path, batch_size=args.batch_size, - padding_to=2000, - flatten=True, - sort_by_duration=False, + sortagrad=True, shuffle=True) - test_batch_reader = data_generator.batch_reader_creator( + test_batch_reader = test_generator.batch_reader_creator( manifest_path=args.dev_manifest_path, batch_size=args.batch_size, - padding_to=2000, - flatten=True, - sort_by_duration=False, shuffle=False) - feeding = data_generator.data_name_feeding() + feeding = train_generator.data_name_feeding() # create event handler def event_handler(event): @@ -169,17 +161,8 @@ def train(): time.time() - start_time, event.pass_id, result.cost) # run train - # first pass with sortagrad - if args.use_sortagrad: - trainer.train( - reader=train_batch_reader_sortagrad, - event_handler=event_handler, - num_passes=1, - feeding=feeding) - args.num_passes -= 1 - # other passes without sortagrad trainer.train( - reader=train_batch_reader_nosortagrad, + reader=train_batch_reader, event_handler=event_handler, num_passes=args.num_passes, feeding=feeding) -- GitLab