From 9c27b1d14e601ff64df6e5dacc95d77933e2b39a Mon Sep 17 00:00:00 2001 From: dangqingqing Date: Mon, 12 Jun 2017 19:53:41 +0800 Subject: [PATCH] add more comments and update train.py --- audio_data_utils.py | 30 ++++++++++++++++++++---------- train.py | 6 +++--- 2 files changed, 23 insertions(+), 13 deletions(-) diff --git a/audio_data_utils.py b/audio_data_utils.py index 692a4280..1cd29be1 100644 --- a/audio_data_utils.py +++ b/audio_data_utils.py @@ -247,25 +247,34 @@ class DataGenerator(object): new_batch.append((padded_audio, text)) return new_batch - def __batch_shuffle__(self, manifest, batch_shuffle_size): + def __batch_shuffle__(self, manifest, batch_size): """ + The instances have different lengths and they cannot be + combined into a single matrix multiplication. It usually + sorts the training examples by length and combines only + similarly-sized instances into minibatches, pads with + silence when necessary so that all instances in a batch + have the same length. This batch shuffle fuction is used + to make similarly-sized instances into minibatches and + make a batch-wise shuffle. + 1. Sort the audio clips by duration. - 2. Generate a random number `k`, k in [0, batch_shuffle_size). + 2. Generate a random number `k`, k in [0, batch_size). 3. Randomly remove `k` instances in order to make different mini-batches, - then make minibatches and each minibatch size is batch_shuffle_size. + then make minibatches and each minibatch size is batch_size. 4. Shuffle the minibatches. :param manifest: manifest file. :type manifest: list - :param batch_shuffle_size: This size is uesed to generate a random number, - it usually equals to batch size. - :type batch_shuffle_size: int + :param batch_size: Batch size. This size is also used for generate + a random number for batch shuffle. + :type batch_size: int :return: batch shuffled mainifest. :rtype: list """ manifest.sort(key=lambda x: x["duration"]) - shift_len = self.__random__.randint(0, batch_shuffle_size - 1) - batch_manifest = zip(*[iter(manifest[shift_len:])] * batch_shuffle_size) + shift_len = self.__random__.randint(0, batch_size - 1) + batch_manifest = zip(*[iter(manifest[shift_len:])] * batch_size) self.__random__.shuffle(batch_manifest) batch_manifest = list(sum(batch_manifest, ())) res_len = len(manifest) - shift_len - len(batch_manifest) @@ -327,8 +336,9 @@ class DataGenerator(object): if set True. :type sortagrad: bool :param batch_shuffle: Shuffle the audio clips if set True. It is - not a thorough instance-wise shuffle, - but a specific batch-wise shuffle. + not a thorough instance-wise shuffle, but a + specific batch-wise shuffle. For more details, + please see `__batch_shuffle__` function. :type batch_shuffle: bool :return: Batch reader function, producing batches of data when called. :rtype: callable diff --git a/train.py b/train.py index eb9b56de..957c2426 100644 --- a/train.py +++ b/train.py @@ -143,12 +143,12 @@ def train(): train_batch_reader = train_generator.batch_reader_creator( manifest_path=args.train_manifest_path, batch_size=args.batch_size, - sortagrad=True, - shuffle=True) + sortagrad=True if args.init_model_path is None else False, + batch_shuffle=True) test_batch_reader = test_generator.batch_reader_creator( manifest_path=args.dev_manifest_path, batch_size=args.batch_size, - shuffle=False) + batch_shuffle=False) feeding = train_generator.data_name_feeding() # create event handler -- GitLab