提交 397b2fc2 编写于 作者: Y yangyaming

Merge branch 'develop' of https://github.com/PaddlePaddle/models into fix-81

......@@ -18,9 +18,14 @@ For some machines, we also need to install libsndfile1. Details to be added.
```
cd data
python librispeech.py
cat manifest.libri.train-* > manifest.libri.train-all
cd ..
```
After running librispeech.py, we have several "manifest" json files named with a prefix `manifest.libri.`. A manifest file summarizes a speech data set, with each line containing the meta data (i.e. audio filepath, transcription text, audio duration) of each audio file within the data set, in json format.
By `cat manifest.libri.train-* > manifest.libri.train-all`, we simply merge the three seperate sample sets of LibriSpeech (train-clean-100, train-clean-360, train-other-500) into one training set. This is a simple way for merging different data sets.
More help for arguments:
```
......@@ -32,13 +37,13 @@ python librispeech.py --help
For GPU Training:
```
CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --trainer_count 4
CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --trainer_count 4 --train_manifest_path ./data/manifest.libri.train-all
```
For CPU Training:
```
python train.py --trainer_count 8 --use_gpu False
python train.py --trainer_count 8 --use_gpu False -- train_manifest_path ./data/manifest.libri.train-all
```
More help for arguments:
......
......@@ -8,6 +8,7 @@ import json
import random
import soundfile
import numpy as np
import itertools
import os
RANDOM_SEED = 0
......@@ -62,6 +63,7 @@ class DataGenerator(object):
self.__stride_ms__ = stride_ms
self.__window_ms__ = window_ms
self.__max_frequency__ = max_frequency
self.__epoc__ = 0
self.__random__ = random.Random(RANDOM_SEED)
# load vocabulary (dictionary)
self.__vocab_dict__, self.__vocab_list__ = \
......@@ -245,10 +247,42 @@ class DataGenerator(object):
new_batch.append((padded_audio, text))
return new_batch
def instance_reader_creator(self,
manifest_path,
sort_by_duration=True,
shuffle=False):
def __batch_shuffle__(self, manifest, batch_size):
"""
The instances have different lengths and they cannot be
combined into a single matrix multiplication. It usually
sorts the training examples by length and combines only
similarly-sized instances into minibatches, pads with
silence when necessary so that all instances in a batch
have the same length. This batch shuffle fuction is used
to make similarly-sized instances into minibatches and
make a batch-wise shuffle.
1. Sort the audio clips by duration.
2. Generate a random number `k`, k in [0, batch_size).
3. Randomly remove `k` instances in order to make different mini-batches,
then make minibatches and each minibatch size is batch_size.
4. Shuffle the minibatches.
:param manifest: manifest file.
:type manifest: list
:param batch_size: Batch size. This size is also used for generate
a random number for batch shuffle.
:type batch_size: int
:return: batch shuffled mainifest.
:rtype: list
"""
manifest.sort(key=lambda x: x["duration"])
shift_len = self.__random__.randint(0, batch_size - 1)
batch_manifest = zip(*[iter(manifest[shift_len:])] * batch_size)
self.__random__.shuffle(batch_manifest)
batch_manifest = list(sum(batch_manifest, ()))
res_len = len(manifest) - shift_len - len(batch_manifest)
batch_manifest.extend(manifest[-res_len:])
batch_manifest.extend(manifest[0:shift_len])
return batch_manifest
def instance_reader_creator(self, manifest):
"""
Instance reader creator for audio data. Creat a callable function to
produce instances of data.
......@@ -256,32 +290,13 @@ class DataGenerator(object):
Instance: a tuple of a numpy ndarray of audio spectrogram and a list of
tokenized and indexed transcription text.
:param manifest_path: Filepath of manifest for audio clip files.
:type manifest_path: basestring
:param sort_by_duration: Sort the audio clips by duration if set True
(for SortaGrad).
:type sort_by_duration: bool
:param shuffle: Shuffle the audio clips if set True.
:type shuffle: bool
:param manifest: Filepath of manifest for audio clip files.
:type manifest: basestring
:return: Data reader function.
:rtype: callable
"""
if sort_by_duration and shuffle:
sort_by_duration = False
logger.warn("When shuffle set to true, "
"sort_by_duration is forced to set False.")
def reader():
# read manifest
manifest = self.__read_manifest__(
manifest_path=manifest_path,
max_duration=self.__max_duration__,
min_duration=self.__min_duration__)
# sort (by duration) or shuffle manifest
if sort_by_duration:
manifest.sort(key=lambda x: x["duration"])
if shuffle:
self.__random__.shuffle(manifest)
# extract spectrogram feature
for instance in manifest:
spectrogram = self.__audio_featurize__(
......@@ -296,8 +311,8 @@ class DataGenerator(object):
batch_size,
padding_to=-1,
flatten=False,
sort_by_duration=True,
shuffle=False):
sortagrad=False,
batch_shuffle=False):
"""
Batch data reader creator for audio data. Creat a callable function to
produce batches of data.
......@@ -317,20 +332,32 @@ class DataGenerator(object):
:param flatten: If set True, audio data will be flatten to be a 1-dim
ndarray. Otherwise, 2-dim ndarray. Default is False.
:type flatten: bool
:param sort_by_duration: Sort the audio clips by duration if set True
(for SortaGrad).
:type sort_by_duration: bool
:param shuffle: Shuffle the audio clips if set True.
:type shuffle: bool
:param sortagrad: Sort the audio clips by duration in the first epoc
if set True.
:type sortagrad: bool
:param batch_shuffle: Shuffle the audio clips if set True. It is
not a thorough instance-wise shuffle, but a
specific batch-wise shuffle. For more details,
please see `__batch_shuffle__` function.
:type batch_shuffle: bool
:return: Batch reader function, producing batches of data when called.
:rtype: callable
"""
def batch_reader():
instance_reader = self.instance_reader_creator(
# read manifest
manifest = self.__read_manifest__(
manifest_path=manifest_path,
sort_by_duration=sort_by_duration,
shuffle=shuffle)
max_duration=self.__max_duration__,
min_duration=self.__min_duration__)
# sort (by duration) or shuffle manifest
if self.__epoc__ == 0 and sortagrad:
manifest.sort(key=lambda x: x["duration"])
elif batch_shuffle:
manifest = self.__batch_shuffle__(manifest, batch_size)
instance_reader = self.instance_reader_creator(manifest)
batch = []
for instance in instance_reader():
batch.append(instance)
......@@ -339,6 +366,7 @@ class DataGenerator(object):
batch = []
if len(batch) > 0:
yield self.__padding_batch__(batch, padding_to, flatten)
self.__epoc__ += 1
return batch_reader
......
"""
Download, unpack and create manifest for Librespeech dataset.
Download, unpack and create manifest json files for the Librespeech dataset.
Manifest is a json file with each line containing one audio clip filepath,
its transcription text string, and its duration. It servers as a unified
interfance to organize different data sets.
A manifest is a json file summarizing filelist in a data set, with each line
containing the meta data (i.e. audio filepath, transcription text, audio
duration) of each audio file in the data set.
"""
import paddle.v2 as paddle
from paddle.v2.dataset.common import md5file
import distutils.util
import os
import wget
import tarfile
......@@ -27,7 +28,9 @@ URL_TRAIN_CLEAN_360 = URL_ROOT + "/train-clean-360.tar.gz"
URL_TRAIN_OTHER_500 = URL_ROOT + "/train-other-500.tar.gz"
MD5_TEST_CLEAN = "32fa31d27d2e1cad72775fee3f4849a9"
MD5_TEST_OTHER = "fb5a50374b501bb3bac4815ee91d3135"
MD5_DEV_CLEAN = "42e2234ba48799c1f50f24a7926300a1"
MD5_DEV_OTHER = "c8d0bcc9cca99d4f8b62fcc847357931"
MD5_TRAIN_CLEAN_100 = "2a93770f6d5c6c964bc36631d331a522"
MD5_TRAIN_CLEAN_360 = "c0e676e450a7ff2f54aeade5171606fa"
MD5_TRAIN_OTHER_500 = "d1a0fd59409feb2c614ce4d30c387708"
......@@ -44,6 +47,13 @@ parser.add_argument(
default="manifest.libri",
type=str,
help="Filepath prefix for output manifests. (default: %(default)s)")
parser.add_argument(
"--full_download",
default="True",
type=distutils.util.strtobool,
help="Download all datasets for Librispeech."
" If False, only download a minimal requirement (test-clean, dev-clean"
" train-clean-100). (default: %(default)s)")
args = parser.parse_args()
......@@ -57,7 +67,10 @@ def download(url, md5sum, target_dir):
print("Downloading %s ..." % url)
wget.download(url, target_dir)
print("\nMD5 Chesksum %s ..." % filepath)
assert md5file(filepath) == md5sum, "MD5 checksum failed."
if not md5file(filepath) == md5sum:
raise RuntimeError("MD5 checksum failed.")
else:
print("File exists, skip downloading. (%s)" % filepath)
return filepath
......@@ -69,21 +82,17 @@ def unpack(filepath, target_dir):
tar = tarfile.open(filepath)
tar.extractall(target_dir)
tar.close()
return target_dir
def create_manifest(data_dir, manifest_path):
"""
Create a manifest file summarizing the dataset (list of filepath and meta
data).
Each line of the manifest contains one audio clip filepath, its
transcription text string, and its duration. Manifest file servers as a
unified interfance to organize data sets.
Create a manifest json file summarizing the data set, with each line
containing the meta data (i.e. audio filepath, transcription text, audio
duration) of each audio file within the data set.
"""
print("Creating manifest %s ..." % manifest_path)
json_lines = []
for subfolder, _, filelist in os.walk(data_dir):
for subfolder, _, filelist in sorted(os.walk(data_dir)):
text_filelist = [
filename for filename in filelist if filename.endswith('trans.txt')
]
......@@ -111,9 +120,16 @@ def prepare_dataset(url, md5sum, target_dir, manifest_path):
"""
Download, unpack and create summmary manifest file.
"""
filepath = download(url, md5sum, target_dir)
unpacked_dir = unpack(filepath, target_dir)
create_manifest(unpacked_dir, manifest_path)
if not os.path.exists(os.path.join(target_dir, "LibriSpeech")):
# download
filepath = download(url, md5sum, target_dir)
# unpack
unpack(filepath, target_dir)
else:
print("Skip downloading and unpacking. Data already exists in %s." %
target_dir)
# create manifest json file
create_manifest(target_dir, manifest_path)
def main():
......@@ -132,6 +148,27 @@ def main():
md5sum=MD5_TRAIN_CLEAN_100,
target_dir=os.path.join(args.target_dir, "train-clean-100"),
manifest_path=args.manifest_prefix + ".train-clean-100")
if args.full_download:
prepare_dataset(
url=URL_TEST_OTHER,
md5sum=MD5_TEST_OTHER,
target_dir=os.path.join(args.target_dir, "test-other"),
manifest_path=args.manifest_prefix + ".test-other")
prepare_dataset(
url=URL_DEV_OTHER,
md5sum=MD5_DEV_OTHER,
target_dir=os.path.join(args.target_dir, "dev-other"),
manifest_path=args.manifest_prefix + ".dev-other")
prepare_dataset(
url=URL_TRAIN_CLEAN_360,
md5sum=MD5_TRAIN_CLEAN_360,
target_dir=os.path.join(args.target_dir, "train-clean-360"),
manifest_path=args.manifest_prefix + ".train-clean-360")
prepare_dataset(
url=URL_TRAIN_OTHER_500,
md5sum=MD5_TRAIN_OTHER_500,
target_dir=os.path.join(args.target_dir, "train-other-500"),
manifest_path=args.manifest_prefix + ".train-other-500")
if __name__ == '__main__':
......
......@@ -11,6 +11,7 @@ import sys
from model import deep_speech2
from audio_data_utils import DataGenerator
import numpy as np
import os
#TODO: add WER metric
......@@ -78,6 +79,13 @@ parser.add_argument(
default='data/eng_vocab.txt',
type=str,
help="Vocabulary filepath. (default: %(default)s)")
parser.add_argument(
"--init_model_path",
default=None,
type=str,
help="If set None, the training will start from scratch. "
"Otherwise, the training will resume from "
"the existing model of this path. (default: %(default)s)")
args = parser.parse_args()
......@@ -85,23 +93,27 @@ def train():
"""
DeepSpeech2 training.
"""
# initialize data generator
data_generator = DataGenerator(
vocab_filepath=args.vocab_filepath,
normalizer_manifest_path=args.normalizer_manifest_path,
normalizer_num_samples=200,
max_duration=20.0,
min_duration=0.0,
stride_ms=10,
window_ms=20)
def data_generator():
return DataGenerator(
vocab_filepath=args.vocab_filepath,
normalizer_manifest_path=args.normalizer_manifest_path,
normalizer_num_samples=200,
max_duration=20.0,
min_duration=0.0,
stride_ms=10,
window_ms=20)
train_generator = data_generator()
test_generator = data_generator()
# create network config
dict_size = data_generator.vocabulary_size()
dict_size = train_generator.vocabulary_size()
# paddle.data_type.dense_array is used for variable batch input.
# the size 161 * 161 is only an placeholder value and the real shape
# of input batch data will be set at each batch.
audio_data = paddle.layer.data(
name="audio_spectrogram",
height=161,
width=2000,
type=paddle.data_type.dense_vector(322000))
name="audio_spectrogram", type=paddle.data_type.dense_array(161 * 161))
text_data = paddle.layer.data(
name="transcript_text",
type=paddle.data_type.integer_value_sequence(dict_size))
......@@ -114,36 +126,30 @@ def train():
rnn_size=args.rnn_layer_size,
is_inference=False)
# create parameters and optimizer
parameters = paddle.parameters.create(cost)
# create/load parameters and optimizer
if args.init_model_path is None:
parameters = paddle.parameters.create(cost)
else:
if not os.path.isfile(args.init_model_path):
raise IOError("Invalid model!")
parameters = paddle.parameters.Parameters.from_tar(
gzip.open(args.init_model_path))
optimizer = paddle.optimizer.Adam(
learning_rate=args.adam_learning_rate, gradient_clipping_threshold=400)
trainer = paddle.trainer.SGD(
cost=cost, parameters=parameters, update_equation=optimizer)
# prepare data reader
train_batch_reader_sortagrad = data_generator.batch_reader_creator(
manifest_path=args.train_manifest_path,
batch_size=args.batch_size,
padding_to=2000,
flatten=True,
sort_by_duration=True,
shuffle=False)
train_batch_reader_nosortagrad = data_generator.batch_reader_creator(
train_batch_reader = train_generator.batch_reader_creator(
manifest_path=args.train_manifest_path,
batch_size=args.batch_size,
padding_to=2000,
flatten=True,
sort_by_duration=False,
shuffle=True)
test_batch_reader = data_generator.batch_reader_creator(
sortagrad=True if args.init_model_path is None else False,
batch_shuffle=True)
test_batch_reader = test_generator.batch_reader_creator(
manifest_path=args.dev_manifest_path,
batch_size=args.batch_size,
padding_to=2000,
flatten=True,
sort_by_duration=False,
shuffle=False)
feeding = data_generator.data_name_feeding()
batch_shuffle=False)
feeding = train_generator.data_name_feeding()
# create event handler
def event_handler(event):
......@@ -169,17 +175,8 @@ def train():
time.time() - start_time, event.pass_id, result.cost)
# run train
# first pass with sortagrad
if args.use_sortagrad:
trainer.train(
reader=train_batch_reader_sortagrad,
event_handler=event_handler,
num_passes=1,
feeding=feeding)
args.num_passes -= 1
# other passes without sortagrad
trainer.train(
reader=train_batch_reader_nosortagrad,
reader=train_batch_reader,
event_handler=event_handler,
num_passes=args.num_passes,
feeding=feeding)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册