提交 9c27b1d1 编写于 作者: D dangqingqing

add more comments and update train.py

上级 bf735400
......@@ -247,25 +247,34 @@ class DataGenerator(object):
new_batch.append((padded_audio, text))
return new_batch
def __batch_shuffle__(self, manifest, batch_shuffle_size):
def __batch_shuffle__(self, manifest, batch_size):
"""
The instances have different lengths and they cannot be
combined into a single matrix multiplication. It usually
sorts the training examples by length and combines only
similarly-sized instances into minibatches, pads with
silence when necessary so that all instances in a batch
have the same length. This batch shuffle fuction is used
to make similarly-sized instances into minibatches and
make a batch-wise shuffle.
1. Sort the audio clips by duration.
2. Generate a random number `k`, k in [0, batch_shuffle_size).
2. Generate a random number `k`, k in [0, batch_size).
3. Randomly remove `k` instances in order to make different mini-batches,
then make minibatches and each minibatch size is batch_shuffle_size.
then make minibatches and each minibatch size is batch_size.
4. Shuffle the minibatches.
:param manifest: manifest file.
:type manifest: list
:param batch_shuffle_size: This size is uesed to generate a random number,
it usually equals to batch size.
:type batch_shuffle_size: int
:param batch_size: Batch size. This size is also used for generate
a random number for batch shuffle.
:type batch_size: int
:return: batch shuffled mainifest.
:rtype: list
"""
manifest.sort(key=lambda x: x["duration"])
shift_len = self.__random__.randint(0, batch_shuffle_size - 1)
batch_manifest = zip(*[iter(manifest[shift_len:])] * batch_shuffle_size)
shift_len = self.__random__.randint(0, batch_size - 1)
batch_manifest = zip(*[iter(manifest[shift_len:])] * batch_size)
self.__random__.shuffle(batch_manifest)
batch_manifest = list(sum(batch_manifest, ()))
res_len = len(manifest) - shift_len - len(batch_manifest)
......@@ -327,8 +336,9 @@ class DataGenerator(object):
if set True.
:type sortagrad: bool
:param batch_shuffle: Shuffle the audio clips if set True. It is
not a thorough instance-wise shuffle,
but a specific batch-wise shuffle.
not a thorough instance-wise shuffle, but a
specific batch-wise shuffle. For more details,
please see `__batch_shuffle__` function.
:type batch_shuffle: bool
:return: Batch reader function, producing batches of data when called.
:rtype: callable
......
......@@ -143,12 +143,12 @@ def train():
train_batch_reader = train_generator.batch_reader_creator(
manifest_path=args.train_manifest_path,
batch_size=args.batch_size,
sortagrad=True,
shuffle=True)
sortagrad=True if args.init_model_path is None else False,
batch_shuffle=True)
test_batch_reader = test_generator.batch_reader_creator(
manifest_path=args.dev_manifest_path,
batch_size=args.batch_size,
shuffle=False)
batch_shuffle=False)
feeding = train_generator.data_name_feeding()
# create event handler
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册