提交 c25c62b8 编写于 作者: D dangqingqing

refine audio_data_utils.py

上级 7c85e0fd
...@@ -247,22 +247,25 @@ class DataGenerator(object): ...@@ -247,22 +247,25 @@ class DataGenerator(object):
new_batch.append((padded_audio, text)) new_batch.append((padded_audio, text))
return new_batch return new_batch
def __batch_shuffle__(self, manifest, batch_size): def __batch_shuffle__(self, manifest, batch_shuffle_size):
""" """
1. Sort the audio clips by duration. 1. Sort the audio clips by duration.
2. Generate a random number `k`, k in [0, batch_size). 2. Generate a random number `k`, k in [0, batch_shuffle_size).
3. Randomly remove `k` instances in order to make different mini-batches, 3. Randomly remove `k` instances in order to make different mini-batches,
then make minibatches and each minibatch size is batch_size. then make minibatches and each minibatch size is batch_shuffle_size.
4. Shuffle the minibatches. 4. Shuffle the minibatches.
:param manifest: manifest file. :param manifest: manifest file.
:type manifest: list :type manifest: list
:param batch_size: batch size. :param batch_shuffle_size: This size is used to generate a random number,
:type batch_size: int it usually equals the batch size.
:type batch_shuffle_size: int
:return: batch shuffled manifest.
:rtype: list
""" """
manifest.sort(key=lambda x: x["duration"]) manifest.sort(key=lambda x: x["duration"])
shift_len = self.__random__.randint(0, batch_size - 1) shift_len = self.__random__.randint(0, batch_shuffle_size - 1)
batch_manifest = zip(*[iter(manifest[shift_len:])] * batch_size) batch_manifest = zip(*[iter(manifest[shift_len:])] * batch_shuffle_size)
self.__random__.shuffle(batch_manifest) self.__random__.shuffle(batch_manifest)
batch_manifest = list(sum(batch_manifest, ())) batch_manifest = list(sum(batch_manifest, ()))
res_len = len(manifest) - shift_len - len(batch_manifest) res_len = len(manifest) - shift_len - len(batch_manifest)
...@@ -270,11 +273,7 @@ class DataGenerator(object): ...@@ -270,11 +273,7 @@ class DataGenerator(object):
batch_manifest.extend(manifest[0:shift_len]) batch_manifest.extend(manifest[0:shift_len])
return batch_manifest return batch_manifest
def instance_reader_creator(self, def instance_reader_creator(self, manifest):
manifest_path,
batch_size,
sortagrad=True,
shuffle=False):
""" """
Instance reader creator for audio data. Create a callable function to Instance reader creator for audio data. Create a callable function to
produce instances of data. produce instances of data.
...@@ -282,35 +281,19 @@ class DataGenerator(object): ...@@ -282,35 +281,19 @@ class DataGenerator(object):
Instance: a tuple of a numpy ndarray of audio spectrogram and a list of Instance: a tuple of a numpy ndarray of audio spectrogram and a list of
tokenized and indexed transcription text. tokenized and indexed transcription text.
:param manifest_path: Filepath of manifest for audio clip files. :param manifest: List of manifest entries for audio clip files.
:type manifest_path: basestring :type manifest: list
:param sortagrad: Sort the audio clips by duration in the first epoch
if set True.
:type sortagrad: bool
:param shuffle: Shuffle the audio clips if set True.
:type shuffle: bool
:return: Data reader function. :return: Data reader function.
:rtype: callable :rtype: callable
""" """
def reader(): def reader():
# read manifest
manifest = self.__read_manifest__(
manifest_path=manifest_path,
max_duration=self.__max_duration__,
min_duration=self.__min_duration__)
# sort (by duration) or shuffle manifest
if self.__epoc__ == 0 and sortagrad:
manifest.sort(key=lambda x: x["duration"])
elif shuffle:
manifest = self.__batch_shuffle__(manifest, batch_size)
# extract spectrogram feature # extract spectrogram feature
for instance in manifest: for instance in manifest:
spectrogram = self.__audio_featurize__( spectrogram = self.__audio_featurize__(
instance["audio_filepath"]) instance["audio_filepath"])
transcript = self.__text_featurize__(instance["text"]) transcript = self.__text_featurize__(instance["text"])
yield (spectrogram, transcript) yield (spectrogram, transcript)
self.__epoc__ += 1
return reader return reader
...@@ -320,7 +303,7 @@ class DataGenerator(object): ...@@ -320,7 +303,7 @@ class DataGenerator(object):
padding_to=-1, padding_to=-1,
flatten=False, flatten=False,
sortagrad=False, sortagrad=False,
shuffle=False): batch_shuffle=False):
""" """
Batch data reader creator for audio data. Create a callable function to Batch data reader creator for audio data. Create a callable function to
produce batches of data. produce batches of data.
...@@ -343,18 +326,28 @@ class DataGenerator(object): ...@@ -343,18 +326,28 @@ class DataGenerator(object):
:param sortagrad: Sort the audio clips by duration in the first epoch :param sortagrad: Sort the audio clips by duration in the first epoch
if set True. if set True.
:type sortagrad: bool :type sortagrad: bool
:param shuffle: Shuffle the audio clips if set True. :param batch_shuffle: Shuffle the audio clips if set True. It is
:type shuffle: bool not a thorough instance-wise shuffle,
but a specific batch-wise shuffle.
:type batch_shuffle: bool
:return: Batch reader function, producing batches of data when called. :return: Batch reader function, producing batches of data when called.
:rtype: callable :rtype: callable
""" """
def batch_reader(): def batch_reader():
instance_reader = self.instance_reader_creator( # read manifest
manifest = self.__read_manifest__(
manifest_path=manifest_path, manifest_path=manifest_path,
batch_size=batch_size, max_duration=self.__max_duration__,
sortagrad=sortagrad, min_duration=self.__min_duration__)
shuffle=shuffle)
# sort (by duration) or shuffle manifest
if self.__epoc__ == 0 and sortagrad:
manifest.sort(key=lambda x: x["duration"])
elif batch_shuffle:
manifest = self.__batch_shuffle__(manifest, batch_size)
instance_reader = self.instance_reader_creator(manifest)
batch = [] batch = []
for instance in instance_reader(): for instance in instance_reader():
batch.append(instance) batch.append(instance)
...@@ -363,6 +356,7 @@ class DataGenerator(object): ...@@ -363,6 +356,7 @@ class DataGenerator(object):
batch = [] batch = []
if len(batch) > 0: if len(batch) > 0:
yield self.__padding_batch__(batch, padding_to, flatten) yield self.__padding_batch__(batch, padding_to, flatten)
self.__epoc__ += 1
return batch_reader return batch_reader
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册