提交 1a319fbf 编写于 作者: D dangqingqing

Support variable input batch and sortagrad.

上级 72874de9
...@@ -8,6 +8,7 @@ import json ...@@ -8,6 +8,7 @@ import json
import random import random
import soundfile import soundfile
import numpy as np import numpy as np
import itertools
import os import os
RANDOM_SEED = 0 RANDOM_SEED = 0
...@@ -62,6 +63,7 @@ class DataGenerator(object): ...@@ -62,6 +63,7 @@ class DataGenerator(object):
self.__stride_ms__ = stride_ms self.__stride_ms__ = stride_ms
self.__window_ms__ = window_ms self.__window_ms__ = window_ms
self.__max_frequency__ = max_frequency self.__max_frequency__ = max_frequency
self.__epoc__ = 0
self.__random__ = random.Random(RANDOM_SEED) self.__random__ = random.Random(RANDOM_SEED)
# load vocabulary (dictionary) # load vocabulary (dictionary)
self.__vocab_dict__, self.__vocab_list__ = \ self.__vocab_dict__, self.__vocab_list__ = \
...@@ -245,9 +247,33 @@ class DataGenerator(object): ...@@ -245,9 +247,33 @@ class DataGenerator(object):
new_batch.append((padded_audio, text)) new_batch.append((padded_audio, text))
return new_batch return new_batch
def __batch_shuffle__(self, manifest, batch_size):
"""
1. Sort the audio clips by duration.
2. Generate a random number `k`, k in [0, batch_size).
3. Randomly remove `k` instances in order to make different mini-batches,
then make minibatches and each minibatch size is batch_size.
4. Shuffle the minibatches.
:param manifest: manifest file.
:type manifest: list
:param batch_size: batch size.
:type batch_size: int
"""
manifest.sort(key=lambda x: x["duration"])
shift_len = self.__random__.randint(0, batch_size - 1)
batch_manifest = zip(*[iter(manifest[shift_len:])] * batch_size)
self.__random__.shuffle(batch_manifest)
batch_manifest = list(sum(batch_manifest, ()))
res_len = len(manifest) - shift_len - len(batch_manifest)
batch_manifest.extend(manifest[-res_len:])
batch_manifest.extend(manifest[0:shift_len])
return batch_manifest
def instance_reader_creator(self, def instance_reader_creator(self,
manifest_path, manifest_path,
sort_by_duration=True, batch_size,
sortagrad=True,
shuffle=False): shuffle=False):
""" """
Instance reader creator for audio data. Creat a callable function to Instance reader creator for audio data. Creat a callable function to
...@@ -258,18 +284,14 @@ class DataGenerator(object): ...@@ -258,18 +284,14 @@ class DataGenerator(object):
:param manifest_path: Filepath of manifest for audio clip files. :param manifest_path: Filepath of manifest for audio clip files.
:type manifest_path: basestring :type manifest_path: basestring
:param sort_by_duration: Sort the audio clips by duration if set True :param sortagrad: Sort the audio clips by duration in the first epoc
(for SortaGrad). if set True.
:type sort_by_duration: bool :type sortagrad: bool
:param shuffle: Shuffle the audio clips if set True. :param shuffle: Shuffle the audio clips if set True.
:type shuffle: bool :type shuffle: bool
:return: Data reader function. :return: Data reader function.
:rtype: callable :rtype: callable
""" """
if sort_by_duration and shuffle:
sort_by_duration = False
logger.warn("When shuffle set to true, "
"sort_by_duration is forced to set False.")
def reader(): def reader():
# read manifest # read manifest
...@@ -278,16 +300,17 @@ class DataGenerator(object): ...@@ -278,16 +300,17 @@ class DataGenerator(object):
max_duration=self.__max_duration__, max_duration=self.__max_duration__,
min_duration=self.__min_duration__) min_duration=self.__min_duration__)
# sort (by duration) or shuffle manifest # sort (by duration) or shuffle manifest
if sort_by_duration: if self.__epoc__ == 0 and sortagrad:
manifest.sort(key=lambda x: x["duration"]) manifest.sort(key=lambda x: x["duration"])
if shuffle: elif shuffle:
self.__random__.shuffle(manifest) manifest = self.__batch_shuffle__(manifest, batch_size)
# extract spectrogram feature # extract spectrogram feature
for instance in manifest: for instance in manifest:
spectrogram = self.__audio_featurize__( spectrogram = self.__audio_featurize__(
instance["audio_filepath"]) instance["audio_filepath"])
transcript = self.__text_featurize__(instance["text"]) transcript = self.__text_featurize__(instance["text"])
yield (spectrogram, transcript) yield (spectrogram, transcript)
self.__epoc__ += 1
return reader return reader
...@@ -296,7 +319,7 @@ class DataGenerator(object): ...@@ -296,7 +319,7 @@ class DataGenerator(object):
batch_size, batch_size,
padding_to=-1, padding_to=-1,
flatten=False, flatten=False,
sort_by_duration=True, sortagrad=False,
shuffle=False): shuffle=False):
""" """
Batch data reader creator for audio data. Creat a callable function to Batch data reader creator for audio data. Creat a callable function to
...@@ -317,9 +340,9 @@ class DataGenerator(object): ...@@ -317,9 +340,9 @@ class DataGenerator(object):
:param flatten: If set True, audio data will be flatten to be a 1-dim :param flatten: If set True, audio data will be flatten to be a 1-dim
ndarray. Otherwise, 2-dim ndarray. Default is False. ndarray. Otherwise, 2-dim ndarray. Default is False.
:type flatten: bool :type flatten: bool
:param sort_by_duration: Sort the audio clips by duration if set True :param sortagrad: Sort the audio clips by duration in the first epoc
(for SortaGrad). if set True.
:type sort_by_duration: bool :type sortagrad: bool
:param shuffle: Shuffle the audio clips if set True. :param shuffle: Shuffle the audio clips if set True.
:type shuffle: bool :type shuffle: bool
:return: Batch reader function, producing batches of data when called. :return: Batch reader function, producing batches of data when called.
...@@ -329,7 +352,8 @@ class DataGenerator(object): ...@@ -329,7 +352,8 @@ class DataGenerator(object):
def batch_reader(): def batch_reader():
instance_reader = self.instance_reader_creator( instance_reader = self.instance_reader_creator(
manifest_path=manifest_path, manifest_path=manifest_path,
sort_by_duration=sort_by_duration, batch_size=batch_size,
sortagrad=sortagrad,
shuffle=shuffle) shuffle=shuffle)
batch = [] batch = []
for instance in instance_reader(): for instance in instance_reader():
......
...@@ -85,8 +85,10 @@ def train(): ...@@ -85,8 +85,10 @@ def train():
""" """
DeepSpeech2 training. DeepSpeech2 training.
""" """
# initialize data generator # initialize data generator
data_generator = DataGenerator( def data_generator():
return DataGenerator(
vocab_filepath=args.vocab_filepath, vocab_filepath=args.vocab_filepath,
normalizer_manifest_path=args.normalizer_manifest_path, normalizer_manifest_path=args.normalizer_manifest_path,
normalizer_num_samples=200, normalizer_num_samples=200,
...@@ -95,13 +97,15 @@ def train(): ...@@ -95,13 +97,15 @@ def train():
stride_ms=10, stride_ms=10,
window_ms=20) window_ms=20)
train_generator = data_generator()
test_generator = data_generator()
# create network config # create network config
dict_size = data_generator.vocabulary_size() dict_size = train_generator.vocabulary_size()
# paddle.data_type.dense_array is used for variable batch input.
# the size 161 * 161 is only an placeholder value and the real shape
# of input batch data will be set at each batch.
audio_data = paddle.layer.data( audio_data = paddle.layer.data(
name="audio_spectrogram", name="audio_spectrogram", type=paddle.data_type.dense_array(161 * 161))
height=161,
width=2000,
type=paddle.data_type.dense_vector(322000))
text_data = paddle.layer.data( text_data = paddle.layer.data(
name="transcript_text", name="transcript_text",
type=paddle.data_type.integer_value_sequence(dict_size)) type=paddle.data_type.integer_value_sequence(dict_size))
...@@ -122,28 +126,16 @@ def train(): ...@@ -122,28 +126,16 @@ def train():
cost=cost, parameters=parameters, update_equation=optimizer) cost=cost, parameters=parameters, update_equation=optimizer)
# prepare data reader # prepare data reader
train_batch_reader_sortagrad = data_generator.batch_reader_creator( train_batch_reader = train_generator.batch_reader_creator(
manifest_path=args.train_manifest_path,
batch_size=args.batch_size,
padding_to=2000,
flatten=True,
sort_by_duration=True,
shuffle=False)
train_batch_reader_nosortagrad = data_generator.batch_reader_creator(
manifest_path=args.train_manifest_path, manifest_path=args.train_manifest_path,
batch_size=args.batch_size, batch_size=args.batch_size,
padding_to=2000, sortagrad=True,
flatten=True,
sort_by_duration=False,
shuffle=True) shuffle=True)
test_batch_reader = data_generator.batch_reader_creator( test_batch_reader = test_generator.batch_reader_creator(
manifest_path=args.dev_manifest_path, manifest_path=args.dev_manifest_path,
batch_size=args.batch_size, batch_size=args.batch_size,
padding_to=2000,
flatten=True,
sort_by_duration=False,
shuffle=False) shuffle=False)
feeding = data_generator.data_name_feeding() feeding = train_generator.data_name_feeding()
# create event handler # create event handler
def event_handler(event): def event_handler(event):
...@@ -169,17 +161,8 @@ def train(): ...@@ -169,17 +161,8 @@ def train():
time.time() - start_time, event.pass_id, result.cost) time.time() - start_time, event.pass_id, result.cost)
# run train # run train
# first pass with sortagrad
if args.use_sortagrad:
trainer.train(
reader=train_batch_reader_sortagrad,
event_handler=event_handler,
num_passes=1,
feeding=feeding)
args.num_passes -= 1
# other passes without sortagrad
trainer.train( trainer.train(
reader=train_batch_reader_nosortagrad, reader=train_batch_reader,
event_handler=event_handler, event_handler=event_handler,
num_passes=args.num_passes, num_passes=args.num_passes,
feeding=feeding) feeding=feeding)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册