diff --git a/audio_data_utils.py b/audio_data_utils.py index abb7f1e9931543349c93d9310d48638b5baf74af..692a42809f790845e8d3349ccce786d9fddce8cd 100644 --- a/audio_data_utils.py +++ b/audio_data_utils.py @@ -247,22 +247,25 @@ class DataGenerator(object): new_batch.append((padded_audio, text)) return new_batch - def __batch_shuffle__(self, manifest, batch_size): + def __batch_shuffle__(self, manifest, batch_shuffle_size): """ 1. Sort the audio clips by duration. - 2. Generate a random number `k`, k in [0, batch_size). + 2. Generate a random number `k`, k in [0, batch_shuffle_size). 3. Randomly remove `k` instances in order to make different mini-batches, - then make minibatches and each minibatch size is batch_size. + then make minibatches and each minibatch size is batch_shuffle_size. 4. Shuffle the minibatches. :param manifest: manifest file. :type manifest: list - :param batch_size: batch size. - :type batch_size: int + :param batch_shuffle_size: This size is uesed to generate a random number, + it usually equals to batch size. + :type batch_shuffle_size: int + :return: batch shuffled mainifest. + :rtype: list """ manifest.sort(key=lambda x: x["duration"]) - shift_len = self.__random__.randint(0, batch_size - 1) - batch_manifest = zip(*[iter(manifest[shift_len:])] * batch_size) + shift_len = self.__random__.randint(0, batch_shuffle_size - 1) + batch_manifest = zip(*[iter(manifest[shift_len:])] * batch_shuffle_size) self.__random__.shuffle(batch_manifest) batch_manifest = list(sum(batch_manifest, ())) res_len = len(manifest) - shift_len - len(batch_manifest) @@ -270,11 +273,7 @@ class DataGenerator(object): batch_manifest.extend(manifest[0:shift_len]) return batch_manifest - def instance_reader_creator(self, - manifest_path, - batch_size, - sortagrad=True, - shuffle=False): + def instance_reader_creator(self, manifest): """ Instance reader creator for audio data. Creat a callable function to produce instances of data. @@ -282,35 +281,19 @@ class DataGenerator(object): Instance: a tuple of a numpy ndarray of audio spectrogram and a list of tokenized and indexed transcription text. - :param manifest_path: Filepath of manifest for audio clip files. - :type manifest_path: basestring - :param sortagrad: Sort the audio clips by duration in the first epoc - if set True. - :type sortagrad: bool - :param shuffle: Shuffle the audio clips if set True. - :type shuffle: bool + :param manifest: Filepath of manifest for audio clip files. + :type manifest: basestring :return: Data reader function. :rtype: callable """ def reader(): - # read manifest - manifest = self.__read_manifest__( - manifest_path=manifest_path, - max_duration=self.__max_duration__, - min_duration=self.__min_duration__) - # sort (by duration) or shuffle manifest - if self.__epoc__ == 0 and sortagrad: - manifest.sort(key=lambda x: x["duration"]) - elif shuffle: - manifest = self.__batch_shuffle__(manifest, batch_size) # extract spectrogram feature for instance in manifest: spectrogram = self.__audio_featurize__( instance["audio_filepath"]) transcript = self.__text_featurize__(instance["text"]) yield (spectrogram, transcript) - self.__epoc__ += 1 return reader @@ -320,7 +303,7 @@ class DataGenerator(object): padding_to=-1, flatten=False, sortagrad=False, - shuffle=False): + batch_shuffle=False): """ Batch data reader creator for audio data. Creat a callable function to produce batches of data. @@ -343,18 +326,28 @@ class DataGenerator(object): :param sortagrad: Sort the audio clips by duration in the first epoc if set True. :type sortagrad: bool - :param shuffle: Shuffle the audio clips if set True. - :type shuffle: bool + :param batch_shuffle: Shuffle the audio clips if set True. It is + not a thorough instance-wise shuffle, + but a specific batch-wise shuffle. + :type batch_shuffle: bool :return: Batch reader function, producing batches of data when called. :rtype: callable """ def batch_reader(): - instance_reader = self.instance_reader_creator( + # read manifest + manifest = self.__read_manifest__( manifest_path=manifest_path, - batch_size=batch_size, - sortagrad=sortagrad, - shuffle=shuffle) + max_duration=self.__max_duration__, + min_duration=self.__min_duration__) + + # sort (by duration) or shuffle manifest + if self.__epoc__ == 0 and sortagrad: + manifest.sort(key=lambda x: x["duration"]) + elif batch_shuffle: + manifest = self.__batch_shuffle__(manifest, batch_size) + + instance_reader = self.instance_reader_creator(manifest) batch = [] for instance in instance_reader(): batch.append(instance) @@ -363,6 +356,7 @@ class DataGenerator(object): batch = [] if len(batch) > 0: yield self.__padding_batch__(batch, padding_to, flatten) + self.__epoc__ += 1 return batch_reader