diff --git a/deep_speech_2/audio_data_utils.py b/deep_speech_2/audio_data_utils.py index 692a42809f790845e8d3349ccce786d9fddce8cd..1cd29be114a416636db8c2d7e888d0d8d6c2a8a8 100644 --- a/deep_speech_2/audio_data_utils.py +++ b/deep_speech_2/audio_data_utils.py @@ -247,25 +247,34 @@ class DataGenerator(object): new_batch.append((padded_audio, text)) return new_batch - def __batch_shuffle__(self, manifest, batch_shuffle_size): + def __batch_shuffle__(self, manifest, batch_size): """ + The instances have different lengths and they cannot be + combined into a single matrix multiplication. It usually + sorts the training examples by length and combines only + similarly-sized instances into minibatches, pads with + silence when necessary so that all instances in a batch + have the same length. This batch shuffle function is used + to make similarly-sized instances into minibatches and + make a batch-wise shuffle. + 1. Sort the audio clips by duration. - 2. Generate a random number `k`, k in [0, batch_shuffle_size). + 2. Generate a random number `k`, k in [0, batch_size). 3. Randomly remove `k` instances in order to make different mini-batches, - then make minibatches and each minibatch size is batch_shuffle_size. + then make minibatches and each minibatch size is batch_size. 4. Shuffle the minibatches. :param manifest: manifest file. :type manifest: list - :param batch_shuffle_size: This size is uesed to generate a random number, - it usually equals to batch size. - :type batch_shuffle_size: int + :param batch_size: Batch size. This size is also used to generate + a random number for batch shuffle. + :type batch_size: int :return: batch shuffled mainifest. 
:rtype: list """ manifest.sort(key=lambda x: x["duration"]) - shift_len = self.__random__.randint(0, batch_shuffle_size - 1) - batch_manifest = zip(*[iter(manifest[shift_len:])] * batch_shuffle_size) + shift_len = self.__random__.randint(0, batch_size - 1) + batch_manifest = zip(*[iter(manifest[shift_len:])] * batch_size) self.__random__.shuffle(batch_manifest) batch_manifest = list(sum(batch_manifest, ())) res_len = len(manifest) - shift_len - len(batch_manifest) @@ -327,8 +336,9 @@ class DataGenerator(object): if set True. :type sortagrad: bool :param batch_shuffle: Shuffle the audio clips if set True. It is - not a thorough instance-wise shuffle, - but a specific batch-wise shuffle. + not a thorough instance-wise shuffle, but a + specific batch-wise shuffle. For more details, + please see `__batch_shuffle__` function. :type batch_shuffle: bool :return: Batch reader function, producing batches of data when called. :rtype: callable diff --git a/deep_speech_2/train.py b/deep_speech_2/train.py index eb9b56de7f325a507c00239b38b8bdb1dd985906..957c24267ce24c917ca8437683d03eefec6636d5 100644 --- a/deep_speech_2/train.py +++ b/deep_speech_2/train.py @@ -143,12 +143,12 @@ def train(): train_batch_reader = train_generator.batch_reader_creator( manifest_path=args.train_manifest_path, batch_size=args.batch_size, - sortagrad=True, - shuffle=True) + sortagrad=True if args.init_model_path is None else False, + batch_shuffle=True) test_batch_reader = test_generator.batch_reader_creator( manifest_path=args.dev_manifest_path, batch_size=args.batch_size, - shuffle=False) + batch_shuffle=False) feeding = train_generator.data_name_feeding() # create event handler