提交 f545367c 编写于 作者: X Xinghai Sun

Add shuffle type of instance_shuffle and batch_shuffle_clipped.

上级 b72aec53
...@@ -80,7 +80,7 @@ class DataGenerator(object): ...@@ -80,7 +80,7 @@ class DataGenerator(object):
padding_to=-1, padding_to=-1,
flatten=False, flatten=False,
sortagrad=False, sortagrad=False,
batch_shuffle=False): shuffle_method="batch_shuffle"):
""" """
Batch data reader creator for audio data. Return a callable generator Batch data reader creator for audio data. Return a callable generator
function to produce batches of data. function to produce batches of data.
...@@ -104,12 +104,22 @@ class DataGenerator(object): ...@@ -104,12 +104,22 @@ class DataGenerator(object):
:param sortagrad: If set True, sort the instances by audio duration :param sortagrad: If set True, sort the instances by audio duration
in the first epoch for speed up training. in the first epoch for speed up training.
:type sortagrad: bool :type sortagrad: bool
:param batch_shuffle: If set True, instances are batch-wise shuffled. :param shuffle_method: Shuffle method. Options:
For more details, please see '' or None: no shuffle.
``_batch_shuffle.__doc__``. 'instance_shuffle': instance-wise shuffle.
If sortagrad is True, batch_shuffle is disabled 'batch_shuffle': similarly-sized instances are
put into batches, and then
batch-wise shuffle the batches.
For more details, please see
``_batch_shuffle.__doc__``.
'batch_shuffle_clipped': 'batch_shuffle' with
head shift and tail
clipping. For more
details, please see
``_batch_shuffle``.
If sortagrad is True, shuffle is disabled
for the first epoch. for the first epoch.
:type batch_shuffle: bool :type shuffle_method: None|str
:return: Batch reader function, producing batches of data when called. :return: Batch reader function, producing batches of data when called.
:rtype: callable :rtype: callable
""" """
...@@ -123,8 +133,20 @@ class DataGenerator(object): ...@@ -123,8 +133,20 @@ class DataGenerator(object):
# sort (by duration) or batch-wise shuffle the manifest # sort (by duration) or batch-wise shuffle the manifest
if self._epoch == 0 and sortagrad: if self._epoch == 0 and sortagrad:
manifest.sort(key=lambda x: x["duration"]) manifest.sort(key=lambda x: x["duration"])
elif batch_shuffle: else:
manifest = self._batch_shuffle(manifest, batch_size) if shuffle_method == "batch_shuffle":
manifest = self._batch_shuffle(
manifest, batch_size, clipped=False)
elif shuffle_method == "batch_shuffle_clipped":
manifest = self._batch_shuffle(
manifest, batch_size, clipped=True)
elif shuffle_method == "instance_shuffle":
self._rng.shuffle(manifest)
elif not shuffle_method:
pass
else:
raise ValueError("Unknown shuffle method %s." %
shuffle_method)
# prepare batches # prepare batches
instance_reader = self._instance_reader_creator(manifest) instance_reader = self._instance_reader_creator(manifest)
batch = [] batch = []
...@@ -218,7 +240,7 @@ class DataGenerator(object): ...@@ -218,7 +240,7 @@ class DataGenerator(object):
new_batch.append((padded_audio, text)) new_batch.append((padded_audio, text))
return new_batch return new_batch
def _batch_shuffle(self, manifest, batch_size): def _batch_shuffle(self, manifest, batch_size, clipped=False):
"""Put similarly-sized instances into minibatches for better efficiency """Put similarly-sized instances into minibatches for better efficiency
and make a batch-wise shuffle. and make a batch-wise shuffle.
...@@ -233,6 +255,9 @@ class DataGenerator(object): ...@@ -233,6 +255,9 @@ class DataGenerator(object):
:param batch_size: Batch size. This size is also used for generate :param batch_size: Batch size. This size is also used for generate
a random number for batch shuffle. a random number for batch shuffle.
:type batch_size: int :type batch_size: int
:param clipped: Whether to clip the heading (small shift) and trailing
(incomplete batch) instances.
:type clipped: bool
:return: Batch shuffled mainifest. :return: Batch shuffled mainifest.
:rtype: list :rtype: list
""" """
...@@ -241,7 +266,8 @@ class DataGenerator(object): ...@@ -241,7 +266,8 @@ class DataGenerator(object):
batch_manifest = zip(*[iter(manifest[shift_len:])] * batch_size) batch_manifest = zip(*[iter(manifest[shift_len:])] * batch_size)
self._rng.shuffle(batch_manifest) self._rng.shuffle(batch_manifest)
batch_manifest = list(sum(batch_manifest, ())) batch_manifest = list(sum(batch_manifest, ()))
res_len = len(manifest) - shift_len - len(batch_manifest) if not clipped:
batch_manifest.extend(manifest[-res_len:]) res_len = len(manifest) - shift_len - len(batch_manifest)
batch_manifest.extend(manifest[0:shift_len]) batch_manifest.extend(manifest[-res_len:])
batch_manifest.extend(manifest[0:shift_len])
return batch_manifest return batch_manifest
...@@ -37,8 +37,7 @@ MD5_TRAIN_CLEAN_100 = "2a93770f6d5c6c964bc36631d331a522" ...@@ -37,8 +37,7 @@ MD5_TRAIN_CLEAN_100 = "2a93770f6d5c6c964bc36631d331a522"
MD5_TRAIN_CLEAN_360 = "c0e676e450a7ff2f54aeade5171606fa" MD5_TRAIN_CLEAN_360 = "c0e676e450a7ff2f54aeade5171606fa"
MD5_TRAIN_OTHER_500 = "d1a0fd59409feb2c614ce4d30c387708" MD5_TRAIN_OTHER_500 = "d1a0fd59409feb2c614ce4d30c387708"
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(description=__doc__)
description='Downloads and prepare LibriSpeech dataset.')
parser.add_argument( parser.add_argument(
"--target_dir", "--target_dir",
default=DATA_HOME + "/Libri", default=DATA_HOME + "/Libri",
......
...@@ -8,8 +8,7 @@ from itertools import groupby ...@@ -8,8 +8,7 @@ from itertools import groupby
def ctc_best_path_decode(probs_seq, vocabulary): def ctc_best_path_decode(probs_seq, vocabulary):
""" """Best path decoding, also called argmax decoding or greedy decoding.
Best path decoding, also called argmax decoding or greedy decoding.
Path consisting of the most probable tokens are further post-processed to Path consisting of the most probable tokens are further post-processed to
remove consecutive repetitions and all blanks. remove consecutive repetitions and all blanks.
...@@ -38,8 +37,7 @@ def ctc_best_path_decode(probs_seq, vocabulary): ...@@ -38,8 +37,7 @@ def ctc_best_path_decode(probs_seq, vocabulary):
def ctc_decode(probs_seq, vocabulary, method): def ctc_decode(probs_seq, vocabulary, method):
""" """CTC-like sequence decoding from a sequence of likelihood probablilites.
CTC-like sequence decoding from a sequence of likelihood probablilites.
:param probs_seq: 2-D list of probabilities over the vocabulary for each :param probs_seq: 2-D list of probabilities over the vocabulary for each
character. Each element is a list of float probabilities character. Each element is a list of float probabilities
......
...@@ -10,9 +10,9 @@ import paddle.v2 as paddle ...@@ -10,9 +10,9 @@ import paddle.v2 as paddle
from data_utils.data import DataGenerator from data_utils.data import DataGenerator
from model import deep_speech2 from model import deep_speech2
from decoder import ctc_decode from decoder import ctc_decode
import utils
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(description=__doc__)
description='Simplified version of DeepSpeech2 inference.')
parser.add_argument( parser.add_argument(
"--num_samples", "--num_samples",
default=10, default=10,
...@@ -62,9 +62,7 @@ args = parser.parse_args() ...@@ -62,9 +62,7 @@ args = parser.parse_args()
def infer(): def infer():
""" """Max-ctc-decoding for DeepSpeech2."""
Max-ctc-decoding for DeepSpeech2.
"""
# initialize data generator # initialize data generator
data_generator = DataGenerator( data_generator = DataGenerator(
vocab_filepath=args.vocab_filepath, vocab_filepath=args.vocab_filepath,
...@@ -98,7 +96,7 @@ def infer(): ...@@ -98,7 +96,7 @@ def infer():
manifest_path=args.decode_manifest_path, manifest_path=args.decode_manifest_path,
batch_size=args.num_samples, batch_size=args.num_samples,
sortagrad=False, sortagrad=False,
batch_shuffle=False) shuffle_method=None)
infer_data = batch_reader().next() infer_data = batch_reader().next()
# run inference # run inference
...@@ -123,6 +121,7 @@ def infer(): ...@@ -123,6 +121,7 @@ def infer():
def main(): def main():
utils.print_arguments(args)
paddle.init(use_gpu=args.use_gpu, trainer_count=1) paddle.init(use_gpu=args.use_gpu, trainer_count=1)
infer() infer()
......
...@@ -12,6 +12,7 @@ import distutils.util ...@@ -12,6 +12,7 @@ import distutils.util
import paddle.v2 as paddle import paddle.v2 as paddle
from model import deep_speech2 from model import deep_speech2
from data_utils.data import DataGenerator from data_utils.data import DataGenerator
import utils
parser = argparse.ArgumentParser(description=__doc__) parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument( parser.add_argument(
...@@ -51,6 +52,12 @@ parser.add_argument( ...@@ -51,6 +52,12 @@ parser.add_argument(
default=True, default=True,
type=distutils.util.strtobool, type=distutils.util.strtobool,
help="Use sortagrad or not. (default: %(default)s)") help="Use sortagrad or not. (default: %(default)s)")
# Shuffle strategy passed through to DataGenerator.batch_reader_creator();
# help text must list the actual accepted values ('batch_shuffle_clipped',
# not the nonexistent 'batch_shuffle_batch').
parser.add_argument(
    "--shuffle_method",
    default='instance_shuffle',
    type=str,
    help="Shuffle method: 'instance_shuffle', 'batch_shuffle', "
    "'batch_shuffle_clipped'. (default: %(default)s)")
parser.add_argument( parser.add_argument(
"--trainer_count", "--trainer_count",
default=4, default=4,
...@@ -93,9 +100,7 @@ args = parser.parse_args() ...@@ -93,9 +100,7 @@ args = parser.parse_args()
def train(): def train():
""" """DeepSpeech2 training."""
DeepSpeech2 training.
"""
# initialize data generator # initialize data generator
def data_generator(): def data_generator():
...@@ -145,13 +150,13 @@ def train(): ...@@ -145,13 +150,13 @@ def train():
batch_size=args.batch_size, batch_size=args.batch_size,
min_batch_size=args.trainer_count, min_batch_size=args.trainer_count,
sortagrad=args.use_sortagrad if args.init_model_path is None else False, sortagrad=args.use_sortagrad if args.init_model_path is None else False,
batch_shuffle=True) shuffle_method=args.shuffle_method)
test_batch_reader = test_generator.batch_reader_creator( test_batch_reader = test_generator.batch_reader_creator(
manifest_path=args.dev_manifest_path, manifest_path=args.dev_manifest_path,
batch_size=args.batch_size, batch_size=args.batch_size,
min_batch_size=1, # must be 1, but will have errors. min_batch_size=1, # must be 1, but will have errors.
sortagrad=False, sortagrad=False,
batch_shuffle=False) shuffle_method=None)
# create event handler # create event handler
def event_handler(event): def event_handler(event):
...@@ -186,6 +191,7 @@ def train(): ...@@ -186,6 +191,7 @@ def train():
def main(): def main():
utils.print_arguments(args)
paddle.init(use_gpu=args.use_gpu, trainer_count=args.trainer_count) paddle.init(use_gpu=args.use_gpu, trainer_count=args.trainer_count)
train() train()
......
"""Contains common utility functions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
def print_arguments(args):
    """Print all entries of an argparse namespace, one ``name: value`` per
    line, framed by separator rules.

    Usage:

    .. code-block:: python

        parser = argparse.ArgumentParser()
        parser.add_argument("name", default="Jonh", type=str, help="User name.")
        args = parser.parse_args()
        print_arguments(args)

    :param args: Input argparse.Namespace for printing.
    :type args: argparse.Namespace
    """
    print("----- Configuration Arguments -----")
    # .items() instead of the Python-2-only .iteritems(): behaves the same on
    # Python 2 and keeps the helper working on Python 3.
    for arg, value in vars(args).items():
        print("%s: %s" % (arg, value))
    print("------------------------------------")
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册