diff --git a/deep_speech_2/data_utils/data.py b/deep_speech_2/data_utils/data.py index 48e03fe85d70b61686d189110154c42df1374f91..424343a48ffa579a8ab465794987f957de36abdb 100644 --- a/deep_speech_2/data_utils/data.py +++ b/deep_speech_2/data_utils/data.py @@ -80,7 +80,7 @@ class DataGenerator(object): padding_to=-1, flatten=False, sortagrad=False, - batch_shuffle=False): + shuffle_method="batch_shuffle"): """ Batch data reader creator for audio data. Return a callable generator function to produce batches of data. @@ -104,12 +104,22 @@ class DataGenerator(object): :param sortagrad: If set True, sort the instances by audio duration in the first epoch for speed up training. :type sortagrad: bool - :param batch_shuffle: If set True, instances are batch-wise shuffled. - For more details, please see - ``_batch_shuffle.__doc__``. - If sortagrad is True, batch_shuffle is disabled + :param shuffle_method: Shuffle method. Options: + '' or None: no shuffle. + 'instance_shuffle': instance-wise shuffle. + 'batch_shuffle': similarly-sized instances are + put into batches, and then + batch-wise shuffle the batches. + For more details, please see + ``_batch_shuffle.__doc__``. + 'batch_shuffle_clipped': 'batch_shuffle' with + head shift and tail + clipping. For more + details, please see + ``_batch_shuffle``. + If sortagrad is True, shuffle is disabled for the first epoch. - :type batch_shuffle: bool + :type shuffle_method: None|str :return: Batch reader function, producing batches of data when called. 
:rtype: callable """ @@ -123,8 +133,20 @@ class DataGenerator(object): # sort (by duration) or batch-wise shuffle the manifest if self._epoch == 0 and sortagrad: manifest.sort(key=lambda x: x["duration"]) - elif batch_shuffle: - manifest = self._batch_shuffle(manifest, batch_size) + else: + if shuffle_method == "batch_shuffle": + manifest = self._batch_shuffle( + manifest, batch_size, clipped=False) + elif shuffle_method == "batch_shuffle_clipped": + manifest = self._batch_shuffle( + manifest, batch_size, clipped=True) + elif shuffle_method == "instance_shuffle": + self._rng.shuffle(manifest) + elif not shuffle_method: + pass + else: + raise ValueError("Unknown shuffle method %s." % + shuffle_method) # prepare batches instance_reader = self._instance_reader_creator(manifest) batch = [] @@ -218,7 +240,7 @@ class DataGenerator(object): new_batch.append((padded_audio, text)) return new_batch - def _batch_shuffle(self, manifest, batch_size): + def _batch_shuffle(self, manifest, batch_size, clipped=False): """Put similarly-sized instances into minibatches for better efficiency and make a batch-wise shuffle. @@ -233,6 +255,9 @@ class DataGenerator(object): :param batch_size: Batch size. This size is also used for generate a random number for batch shuffle. :type batch_size: int + :param clipped: Whether to clip the heading (small shift) and trailing + (incomplete batch) instances. + :type clipped: bool :return: Batch shuffled mainifest. 
:rtype: list """ @@ -241,7 +266,8 @@ class DataGenerator(object): batch_manifest = zip(*[iter(manifest[shift_len:])] * batch_size) self._rng.shuffle(batch_manifest) batch_manifest = list(sum(batch_manifest, ())) - res_len = len(manifest) - shift_len - len(batch_manifest) - batch_manifest.extend(manifest[-res_len:]) - batch_manifest.extend(manifest[0:shift_len]) + if not clipped: + res_len = len(manifest) - shift_len - len(batch_manifest) + batch_manifest.extend(manifest[-res_len:]) + batch_manifest.extend(manifest[0:shift_len]) return batch_manifest diff --git a/deep_speech_2/datasets/librispeech/librispeech.py b/deep_speech_2/datasets/librispeech/librispeech.py index faf038cc1919e3659e39d2f06b58816f3b72ba12..87e52ae4aa286503d79f1326065831acfe6bf985 100644 --- a/deep_speech_2/datasets/librispeech/librispeech.py +++ b/deep_speech_2/datasets/librispeech/librispeech.py @@ -37,8 +37,7 @@ MD5_TRAIN_CLEAN_100 = "2a93770f6d5c6c964bc36631d331a522" MD5_TRAIN_CLEAN_360 = "c0e676e450a7ff2f54aeade5171606fa" MD5_TRAIN_OTHER_500 = "d1a0fd59409feb2c614ce4d30c387708" -parser = argparse.ArgumentParser( - description='Downloads and prepare LibriSpeech dataset.') +parser = argparse.ArgumentParser(description=__doc__) parser.add_argument( "--target_dir", default=DATA_HOME + "/Libri", diff --git a/deep_speech_2/decoder.py b/deep_speech_2/decoder.py index 8314885ce609f4e3da6814cc5831f2e1dd2029ff..77d950b8db072d539788fd1b2bc7ac0525ffa0f9 100644 --- a/deep_speech_2/decoder.py +++ b/deep_speech_2/decoder.py @@ -8,8 +8,7 @@ from itertools import groupby def ctc_best_path_decode(probs_seq, vocabulary): - """ - Best path decoding, also called argmax decoding or greedy decoding. + """Best path decoding, also called argmax decoding or greedy decoding. Path consisting of the most probable tokens are further post-processed to remove consecutive repetitions and all blanks. 
@@ -38,8 +37,7 @@ def ctc_best_path_decode(probs_seq, vocabulary): def ctc_decode(probs_seq, vocabulary, method): - """ - CTC-like sequence decoding from a sequence of likelihood probablilites. + """CTC-like sequence decoding from a sequence of likelihood probabilities. :param probs_seq: 2-D list of probabilities over the vocabulary for each character. Each element is a list of float probabilities diff --git a/deep_speech_2/infer.py b/deep_speech_2/infer.py index f7c99df117985058eede89301b6339bbaf4f46c2..06449ab05c7960ec78acc9ce5bb664cf1058a845 100644 --- a/deep_speech_2/infer.py +++ b/deep_speech_2/infer.py @@ -10,9 +10,9 @@ import paddle.v2 as paddle from data_utils.data import DataGenerator from model import deep_speech2 from decoder import ctc_decode +import utils -parser = argparse.ArgumentParser( - description='Simplified version of DeepSpeech2 inference.') +parser = argparse.ArgumentParser(description=__doc__) parser.add_argument( "--num_samples", default=10, @@ -62,9 +62,7 @@ args = parser.parse_args() def infer(): - """ - Max-ctc-decoding for DeepSpeech2. 
- """ + """Max-ctc-decoding for DeepSpeech2.""" # initialize data generator data_generator = DataGenerator( vocab_filepath=args.vocab_filepath, @@ -98,7 +96,7 @@ def infer(): manifest_path=args.decode_manifest_path, batch_size=args.num_samples, sortagrad=False, - batch_shuffle=False) + shuffle_method=None) infer_data = batch_reader().next() # run inference @@ -123,6 +121,7 @@ def infer(): def main(): + utils.print_arguments(args) paddle.init(use_gpu=args.use_gpu, trainer_count=1) infer() diff --git a/deep_speech_2/train.py b/deep_speech_2/train.py index 6074aa358d84b63b7b5191aab775b5d94c04dd52..c60a039b69d91a89eb20e83ec1e090c8600d47a3 100644 --- a/deep_speech_2/train.py +++ b/deep_speech_2/train.py @@ -12,6 +12,7 @@ import distutils.util import paddle.v2 as paddle from model import deep_speech2 from data_utils.data import DataGenerator +import utils parser = argparse.ArgumentParser(description=__doc__) parser.add_argument( @@ -51,6 +52,12 @@ parser.add_argument( default=True, type=distutils.util.strtobool, help="Use sortagrad or not. (default: %(default)s)") +parser.add_argument( + "--shuffle_method", + default='instance_shuffle', + type=str, + help="Shuffle method: 'instance_shuffle', 'batch_shuffle', " + "'batch_shuffle_clipped'. (default: %(default)s)") parser.add_argument( "--trainer_count", default=4, @@ -93,9 +100,7 @@ args = parser.parse_args() def train(): - """ - DeepSpeech2 training. - """ + """DeepSpeech2 training.""" # initialize data generator def data_generator(): @@ -145,13 +150,13 @@ def train(): batch_size=args.batch_size, min_batch_size=args.trainer_count, sortagrad=args.use_sortagrad if args.init_model_path is None else False, - batch_shuffle=True) + shuffle_method=args.shuffle_method) test_batch_reader = test_generator.batch_reader_creator( manifest_path=args.dev_manifest_path, batch_size=args.batch_size, min_batch_size=1, # must be 1, but will have errors. 
sortagrad=False, - batch_shuffle=False) + shuffle_method=None) # create event handler def event_handler(event): @@ -186,6 +191,7 @@ def train(): def main(): + utils.print_arguments(args) paddle.init(use_gpu=args.use_gpu, trainer_count=args.trainer_count) train() diff --git a/deep_speech_2/utils.py b/deep_speech_2/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9ca363c8f59c2b1cd2885db4b04605c0025998bf --- /dev/null +++ b/deep_speech_2/utils.py @@ -0,0 +1,25 @@ +"""Contains common utility functions.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +def print_arguments(args): + """Print argparse's arguments. + + Usage: + + .. code-block:: python + + parser = argparse.ArgumentParser() + parser.add_argument("name", default="John", type=str, help="User name.") + args = parser.parse_args() + print_arguments(args) + + :param args: Input argparse.Namespace for printing. + :type args: argparse.Namespace + """ + print("----- Configuration Arguments -----") + for arg, value in vars(args).iteritems(): + print("%s: %s" % (arg, value)) + print("------------------------------------")