提交 0a307ac6 编写于 作者: X Xinghai Sun 提交者: GitHub

Merge pull request #101 from xinghai-sun/ds2_improve

Add shuffle type of instance_shuffle and batch_shuffle_clipped.
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
......@@ -80,7 +80,7 @@ class DataGenerator(object):
padding_to=-1,
flatten=False,
sortagrad=False,
batch_shuffle=False):
shuffle_method="batch_shuffle"):
"""
Batch data reader creator for audio data. Return a callable generator
function to produce batches of data.
......@@ -104,12 +104,22 @@ class DataGenerator(object):
:param sortagrad: If set True, sort the instances by audio duration
in the first epoch for speed up training.
:type sortagrad: bool
:param batch_shuffle: If set True, instances are batch-wise shuffled.
For more details, please see
``_batch_shuffle.__doc__``.
If sortagrad is True, batch_shuffle is disabled
:param shuffle_method: Shuffle method. Options:
'' or None: no shuffle.
'instance_shuffle': instance-wise shuffle.
'batch_shuffle': similarly-sized instances are
put into batches, and then
batch-wise shuffle the batches.
For more details, please see
``_batch_shuffle.__doc__``.
'batch_shuffle_clipped': 'batch_shuffle' with
head shift and tail
clipping. For more
details, please see
``_batch_shuffle``.
If sortagrad is True, shuffle is disabled
for the first epoch.
:type batch_shuffle: bool
:type shuffle_method: None|str
:return: Batch reader function, producing batches of data when called.
:rtype: callable
"""
......@@ -123,8 +133,20 @@ class DataGenerator(object):
# sort (by duration) or batch-wise shuffle the manifest
if self._epoch == 0 and sortagrad:
manifest.sort(key=lambda x: x["duration"])
elif batch_shuffle:
manifest = self._batch_shuffle(manifest, batch_size)
else:
if shuffle_method == "batch_shuffle":
manifest = self._batch_shuffle(
manifest, batch_size, clipped=False)
elif shuffle_method == "batch_shuffle_clipped":
manifest = self._batch_shuffle(
manifest, batch_size, clipped=True)
elif shuffle_method == "instance_shuffle":
self._rng.shuffle(manifest)
elif not shuffle_method:
pass
else:
raise ValueError("Unknown shuffle method %s." %
shuffle_method)
# prepare batches
instance_reader = self._instance_reader_creator(manifest)
batch = []
......@@ -218,7 +240,7 @@ class DataGenerator(object):
new_batch.append((padded_audio, text))
return new_batch
def _batch_shuffle(self, manifest, batch_size):
def _batch_shuffle(self, manifest, batch_size, clipped=False):
"""Put similarly-sized instances into minibatches for better efficiency
and make a batch-wise shuffle.
......@@ -233,6 +255,9 @@ class DataGenerator(object):
:param batch_size: Batch size. This size is also used for generate
a random number for batch shuffle.
:type batch_size: int
:param clipped: Whether to clip the heading (small shift) and trailing
(incomplete batch) instances.
:type clipped: bool
:return: Batch shuffled mainifest.
:rtype: list
"""
......@@ -241,7 +266,8 @@ class DataGenerator(object):
batch_manifest = zip(*[iter(manifest[shift_len:])] * batch_size)
self._rng.shuffle(batch_manifest)
batch_manifest = list(sum(batch_manifest, ()))
res_len = len(manifest) - shift_len - len(batch_manifest)
batch_manifest.extend(manifest[-res_len:])
batch_manifest.extend(manifest[0:shift_len])
if not clipped:
res_len = len(manifest) - shift_len - len(batch_manifest)
batch_manifest.extend(manifest[-res_len:])
batch_manifest.extend(manifest[0:shift_len])
return batch_manifest
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
文件模式从 100755 更改为 100644
......@@ -37,8 +37,7 @@ MD5_TRAIN_CLEAN_100 = "2a93770f6d5c6c964bc36631d331a522"
MD5_TRAIN_CLEAN_360 = "c0e676e450a7ff2f54aeade5171606fa"
MD5_TRAIN_OTHER_500 = "d1a0fd59409feb2c614ce4d30c387708"
parser = argparse.ArgumentParser(
description='Downloads and prepare LibriSpeech dataset.')
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--target_dir",
default=DATA_HOME + "/Libri",
......
文件模式从 100755 更改为 100644
......@@ -8,8 +8,7 @@ from itertools import groupby
def ctc_best_path_decode(probs_seq, vocabulary):
"""
Best path decoding, also called argmax decoding or greedy decoding.
"""Best path decoding, also called argmax decoding or greedy decoding.
Path consisting of the most probable tokens are further post-processed to
remove consecutive repetitions and all blanks.
......@@ -38,8 +37,7 @@ def ctc_best_path_decode(probs_seq, vocabulary):
def ctc_decode(probs_seq, vocabulary, method):
"""
CTC-like sequence decoding from a sequence of likelihood probablilites.
"""CTC-like sequence decoding from a sequence of likelihood probablilites.
:param probs_seq: 2-D list of probabilities over the vocabulary for each
character. Each element is a list of float probabilities
......
......@@ -10,9 +10,9 @@ import paddle.v2 as paddle
from data_utils.data import DataGenerator
from model import deep_speech2
from decoder import ctc_decode
import utils
parser = argparse.ArgumentParser(
description='Simplified version of DeepSpeech2 inference.')
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--num_samples",
default=10,
......@@ -62,9 +62,7 @@ args = parser.parse_args()
def infer():
"""
Max-ctc-decoding for DeepSpeech2.
"""
"""Max-ctc-decoding for DeepSpeech2."""
# initialize data generator
data_generator = DataGenerator(
vocab_filepath=args.vocab_filepath,
......@@ -98,7 +96,7 @@ def infer():
manifest_path=args.decode_manifest_path,
batch_size=args.num_samples,
sortagrad=False,
batch_shuffle=False)
shuffle_method=None)
infer_data = batch_reader().next()
# run inference
......@@ -123,6 +121,7 @@ def infer():
def main():
utils.print_arguments(args)
paddle.init(use_gpu=args.use_gpu, trainer_count=1)
infer()
......
......@@ -12,6 +12,7 @@ import distutils.util
import paddle.v2 as paddle
from model import deep_speech2
from data_utils.data import DataGenerator
import utils
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
......@@ -51,6 +52,12 @@ parser.add_argument(
default=True,
type=distutils.util.strtobool,
help="Use sortagrad or not. (default: %(default)s)")
parser.add_argument(
"--shuffle_method",
default='instance_shuffle',
type=str,
help="Shuffle method: 'instance_shuffle', 'batch_shuffle', "
"'batch_shuffle_batch'. (default: %(default)s)")
parser.add_argument(
"--trainer_count",
default=4,
......@@ -93,9 +100,7 @@ args = parser.parse_args()
def train():
"""
DeepSpeech2 training.
"""
"""DeepSpeech2 training."""
# initialize data generator
def data_generator():
......@@ -143,13 +148,15 @@ def train():
train_batch_reader = train_generator.batch_reader_creator(
manifest_path=args.train_manifest_path,
batch_size=args.batch_size,
min_batch_size=args.trainer_count,
sortagrad=args.use_sortagrad if args.init_model_path is None else False,
batch_shuffle=True)
shuffle_method=args.shuffle_method)
test_batch_reader = test_generator.batch_reader_creator(
manifest_path=args.dev_manifest_path,
batch_size=args.batch_size,
min_batch_size=1, # must be 1, but will have errors.
sortagrad=False,
batch_shuffle=False)
shuffle_method=None)
# create event handler
def event_handler(event):
......@@ -157,11 +164,11 @@ def train():
if isinstance(event, paddle.event.EndIteration):
cost_sum += event.cost
cost_counter += 1
if event.batch_id % 50 == 0:
print("\nPass: %d, Batch: %d, TrainCost: %f" %
(event.pass_id, event.batch_id, cost_sum / cost_counter))
if (event.batch_id + 1) % 100 == 0:
print("\nPass: %d, Batch: %d, TrainCost: %f" % (
event.pass_id, event.batch_id + 1, cost_sum / cost_counter))
cost_sum, cost_counter = 0.0, 0
with gzip.open("params_tmp.tar.gz", 'w') as f:
with gzip.open("params.tar.gz", 'w') as f:
parameters.to_tar(f)
else:
sys.stdout.write('.')
......@@ -184,6 +191,7 @@ def train():
def main():
utils.print_arguments(args)
paddle.init(use_gpu=args.use_gpu, trainer_count=args.trainer_count)
train()
......
"""Contains common utility functions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
def print_arguments(args):
"""Print argparse's arguments.
Usage:
.. code-block:: python
parser = argparse.ArgumentParser()
parser.add_argument("name", default="Jonh", type=str, help="User name.")
args = parser.parse_args()
print_arguments(args)
:param args: Input argparse.Namespace for printing.
:type args: argparse.Namespace
"""
print("----- Configuration Arguments -----")
for arg, value in vars(args).iteritems():
print("%s: %s" % (arg, value))
print("------------------------------------")
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册