Commit 70eb4001 authored by Xinghai Sun

Refactor whole data preprocessor for DS2 (re-design classes, re-organize dir, add augmentation interfaces, etc.).

1. Refactor the data preprocessor with the newly added classes AudioSegment, SpeechSegment, TextFeaturizer, AudioFeaturizer, SpeechFeaturizer.
2. Add data augmentation interfaces and the classes AugmentorBase, AugmentationPipeline, VolumnPerturbAugmentor, etc.
3. Separate the normalizer's mean and std computation from training, by adding FeatureNormalizer and a separate tool compute_mean_std.py.
4. Re-organize the directory.

Parent d91dab00
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
from data_utils.normalizer import FeatureNormalizer
from data_utils.augmentor.augmentation import AugmentationPipeline
from data_utils.featurizer.audio_featurizer import AudioFeaturizer
parser = argparse.ArgumentParser(
    description='Computing mean and stddev for feature normalizer.')
parser.add_argument(
    "--manifest_path",
    default='datasets/manifest.train',
    type=str,
    help="Manifest path for computing normalizer's mean and stddev. "
    "(default: %(default)s)")
parser.add_argument(
    "--num_samples",
    default=500,
    type=int,
    help="Number of samples for computing mean and stddev. "
    "(default: %(default)s)")
parser.add_argument(
    "--augmentation_config",
    default='{}',
    type=str,
    help="Augmentation configuration in json-format. "
    "(default: %(default)s)")
parser.add_argument(
    "--output_file",
    default='mean_std.npz',
    type=str,
    help="Filepath to write mean and std to (.npz). "
    "(default: %(default)s)")
args = parser.parse_args()


def main():
    augmentation_pipeline = AugmentationPipeline(args.augmentation_config)
    audio_featurizer = AudioFeaturizer()

    def augment_and_featurize(audio_segment):
        augmentation_pipeline.transform_audio(audio_segment)
        return audio_featurizer.featurize(audio_segment)

    normalizer = FeatureNormalizer(
        mean_std_filepath=None,
        manifest_path=args.manifest_path,
        featurize_func=augment_and_featurize,
        num_samples=args.num_samples)
    normalizer.write_to_file(args.output_file)


if __name__ == '__main__':
    main()
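Once the tool has run, the written .npz can be sanity-checked directly. A minimal sketch, assuming the default 161-bin linear spectrogram features (mean and std are stored with one row per frequency bin):

import numpy as np

npzfile = np.load('mean_std.npz')
print(npzfile["mean"].shape)  # e.g. (161, 1)
print(npzfile["std"].shape)   # e.g. (161, 1)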
import numpy as np
import io
import soundfile
class AudioSegment(object):
    """Monaural audio segment abstraction."""

    def __init__(self, samples, sample_rate):
        if not samples.dtype == np.float32:
            raise ValueError(
                "Sample data type of [%s] is not supported." % samples.dtype)
        self._samples = samples
        self._sample_rate = sample_rate
        if self._samples.ndim >= 2:
            # mix multi-channel audio down to mono
            self._samples = np.mean(self._samples, 1)

    @classmethod
    def from_file(cls, filepath):
        samples, sample_rate = soundfile.read(filepath, dtype='float32')
        return cls(samples, sample_rate)

    @classmethod
    def from_bytes(cls, bytes):
        samples, sample_rate = soundfile.read(
            io.BytesIO(bytes), dtype='float32')
        return cls(samples, sample_rate)

    def apply_gain(self, gain):
        # gain is in dB; operate on the underlying array, since the
        # `samples` property returns a copy
        self._samples *= 10.**(gain / 20.)

    def resample(self, target_sample_rate):
        raise NotImplementedError()

    def change_speed(self, rate):
        raise NotImplementedError()

    @property
    def samples(self):
        return self._samples.copy()

    @property
    def sample_rate(self):
        return self._sample_rate

    @property
    def duration(self):
        return self._samples.shape[0] / float(self._sample_rate)


class SpeechSegment(AudioSegment):
    def __init__(self, samples, sample_rate, transcript):
        AudioSegment.__init__(self, samples, sample_rate)
        self._transcript = transcript

    @classmethod
    def from_file(cls, filepath, transcript):
        audio = AudioSegment.from_file(filepath)
        return cls(audio.samples, audio.sample_rate, transcript)

    @classmethod
    def from_bytes(cls, bytes, transcript):
        audio = AudioSegment.from_bytes(bytes)
        return cls(audio.samples, audio.sample_rate, transcript)

    @property
    def transcript(self):
        return self._transcript
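As a usage sketch, a segment can also be built from an in-memory numpy array; the 440 Hz tone below is purely illustrative:

import numpy as np
from data_utils.audio import AudioSegment

sr = 16000
t = np.linspace(0, 1, sr, endpoint=False)
samples = (0.1 * np.sin(2 * np.pi * 440 * t)).astype(np.float32)
segment = AudioSegment(samples, sr)
print(segment.duration)    # 1.0 (seconds)
segment.apply_gain(-6.0)   # -6 dB, i.e. amplitude scaled by ~0.5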
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
import random
from data_utils.augmentor.volumn_perturb import VolumnPerturbAugmentor
class AugmentationPipeline(object):
    def __init__(self, augmentation_config, random_seed=0):
        self._rng = random.Random(random_seed)
        self._augmentors, self._rates = self._parse_pipeline_from(
            augmentation_config)

    def transform_audio(self, audio_segment):
        for augmentor, rate in zip(self._augmentors, self._rates):
            if self._rng.uniform(0., 1.) <= rate:
                augmentor.transform_audio(audio_segment)

    def _parse_pipeline_from(self, config_json):
        try:
            configs = json.loads(config_json)
        except Exception as e:
            raise ValueError("Augmentation config json format error: "
                             "%s" % str(e))
        augmentors = [
            self._get_augmentor(config["type"], config["params"])
            for config in configs
        ]
        rates = [config["rate"] for config in configs]
        return augmentors, rates

    def _get_augmentor(self, augmentor_type, params):
        if augmentor_type == "volumn":
            return VolumnPerturbAugmentor(self._rng, **params)
        else:
            raise ValueError("Unknown augmentor type [%s]." % augmentor_type)
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from abc import ABCMeta, abstractmethod
class AugmentorBase(object):
    __metaclass__ = ABCMeta

    @abstractmethod
    def __init__(self):
        pass

    @abstractmethod
    def transform_audio(self, audio_segment):
        pass
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import random
from data_utils.augmentor.base import AugmentorBase
class VolumnPerturbAugmentor(AugmentorBase):
    def __init__(self, rng, min_gain_dBFS, max_gain_dBFS):
        self._min_gain_dBFS = min_gain_dBFS
        self._max_gain_dBFS = max_gain_dBFS
        self._rng = rng

    def transform_audio(self, audio_segment):
        gain = self._rng.uniform(self._min_gain_dBFS, self._max_gain_dBFS)
        audio_segment.apply_gain(gain)
@@ -2,17 +2,19 @@
Providing basic audio data preprocessing pipeline, and offering
both instance-level and batch-level data reader interfaces.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import random
import numpy as np
import paddle.v2 as paddle
from data_utils import utils
from data_utils.augmentor.augmentation import AugmentationPipeline
from data_utils.featurizer.speech_featurizer import SpeechFeaturizer
from data_utils.audio import SpeechSegment
from data_utils.normalizer import FeatureNormalizer


class DataGenerator(object):
@@ -51,173 +53,135 @@ class DataGenerator(object):
    def __init__(self,
                 vocab_filepath,
                 mean_std_filepath,
                 augmentation_config='{}',
                 max_duration=float('inf'),
                 min_duration=0.0,
                 stride_ms=10.0,
                 window_ms=20.0,
                 max_freq=None,
                 random_seed=0):
        self._max_duration = max_duration
        self._min_duration = min_duration
        self._normalizer = FeatureNormalizer(mean_std_filepath)
        self._augmentation_pipeline = AugmentationPipeline(
            augmentation_config=augmentation_config, random_seed=random_seed)
        self._speech_featurizer = SpeechFeaturizer(
            vocab_filepath=vocab_filepath,
            stride_ms=stride_ms,
            window_ms=window_ms,
            max_freq=max_freq,
            random_seed=random_seed)
        self._rng = random.Random(random_seed)
        self._epoch = 0

    def batch_reader_creator(self,
                             manifest_path,
                             batch_size,
                             padding_to=-1,
                             flatten=False,
                             sortagrad=False,
                             batch_shuffle=False):
        """
        Batch data reader creator for audio data. Creates a callable function
        to produce batches of data.

        Audio features will be padded with zeros to make each instance in the
        batch share the same audio feature shape.

        :param manifest_path: Filepath of manifest for audio clip files.
        :type manifest_path: basestring
        :param batch_size: Instance number in a batch.
        :type batch_size: int
        :param padding_to: If set -1, the maximum column number in the batch
                           will be used as the target size for padding.
                           Otherwise, `padding_to` will be the target size.
                           Default is -1.
        :type padding_to: int
        :param flatten: If set True, audio data will be flattened into a
                        1-dim ndarray. Otherwise, 2-dim ndarray. Default is
                        False.
        :type flatten: bool
        :param sortagrad: Sort the audio clips by duration in the first epoch
                          if set True.
        :type sortagrad: bool
        :param batch_shuffle: Shuffle the audio clips if set True. It is
                              not a thorough instance-wise shuffle, but a
                              specific batch-wise shuffle. For more details,
                              please see the `_batch_shuffle` function.
        :type batch_shuffle: bool
        :return: Batch reader function, producing batches of data when called.
        :rtype: callable
        """

        def batch_reader():
            # read manifest
            manifest = utils.read_manifest(
                manifest_path=manifest_path,
                max_duration=self._max_duration,
                min_duration=self._min_duration)
            # sort (by duration) or batch-wise shuffle the manifest
            if self._epoch == 0 and sortagrad:
                manifest.sort(key=lambda x: x["duration"])
            elif batch_shuffle:
                manifest = self._batch_shuffle(manifest, batch_size)
            # prepare batches
            instance_reader = self._instance_reader_creator(manifest)
            batch = []
            for instance in instance_reader():
                batch.append(instance)
                if len(batch) == batch_size:
                    yield self._padding_batch(batch, padding_to, flatten)
                    batch = []
            if len(batch) > 0:
                yield self._padding_batch(batch, padding_to, flatten)
            self._epoch += 1

        return batch_reader

    @property
    def feeding(self):
        """Returns data reader's feeding dict."""
        return {"audio_spectrogram": 0, "transcript_text": 1}

    @property
    def vocab_size(self):
        """Returns vocabulary size."""
        return self._speech_featurizer.vocab_size

    @property
    def vocab_list(self):
        """Returns vocabulary list."""
        return self._speech_featurizer.vocab_list

    def _process_utterance(self, filename, transcript):
        speech_segment = SpeechSegment.from_file(filename, transcript)
        self._augmentation_pipeline.transform_audio(speech_segment)
        specgram, text_ids = self._speech_featurizer.featurize(speech_segment)
        specgram = self._normalizer.apply(specgram)
        return specgram, text_ids

    def _instance_reader_creator(self, manifest):
        """
        Instance reader creator for audio data. Creates a callable function
        to produce instances of data.

        Instance: a tuple of a numpy ndarray of audio spectrogram and a list
        of tokenized and indexed transcription text.

        :param manifest: Parsed manifest (list of json data) of audio clip
                         files.
        :type manifest: list
        :return: Data reader function.
        :rtype: callable
        """

        def reader():
            for instance in manifest:
                yield self._process_utterance(instance["audio_filepath"],
                                              instance["text"])

        return reader

    def _padding_batch(self, batch, padding_to=-1, flatten=False):
        """
        Padding audio part of features (only in the time axis -- column axis)
        with zeros, to make each instance in the batch share the same
@@ -247,7 +211,7 @@ class DataGenerator(object):
            new_batch.append((padded_audio, text))
        return new_batch

    def _batch_shuffle(self, manifest, batch_size):
        """
        The instances have different lengths and they cannot be
        combined into a single matrix multiplication. It usually
@@ -273,139 +237,11 @@ class DataGenerator(object):
        :rtype: list
        """
        manifest.sort(key=lambda x: x["duration"])
        shift_len = self._rng.randint(0, batch_size - 1)
        batch_manifest = zip(*[iter(manifest[shift_len:])] * batch_size)
        self._rng.shuffle(batch_manifest)
        batch_manifest = list(sum(batch_manifest, ()))
        res_len = len(manifest) - shift_len - len(batch_manifest)
        batch_manifest.extend(manifest[-res_len:])
        batch_manifest.extend(manifest[0:shift_len])
        return batch_manifest
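Taken together, a hedged usage sketch of the refactored DataGenerator API (all paths are placeholders):

from data_utils.data import DataGenerator

data_generator = DataGenerator(
    vocab_filepath='datasets/vocab/eng_vocab.txt',
    mean_std_filepath='mean_std.npz',
    augmentation_config='{}')
batch_reader = data_generator.batch_reader_creator(
    manifest_path='datasets/manifest.train',
    batch_size=32,
    sortagrad=True,
    batch_shuffle=True)
for batch in batch_reader():
    # each batch is a list of (padded_spectrogram, token_id_list) tuples
    specgram, text_ids = batch[0]
    print(specgram.shape, len(text_ids))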
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import random
from data_utils import utils
from data_utils.audio import AudioSegment
class AudioFeaturizer(object):
    def __init__(self,
                 specgram_type='linear',
                 stride_ms=10.0,
                 window_ms=20.0,
                 max_freq=None,
                 random_seed=0):
        self._specgram_type = specgram_type
        self._stride_ms = stride_ms
        self._window_ms = window_ms
        self._max_freq = max_freq

    def featurize(self, audio_segment):
        return self._compute_specgram(audio_segment.samples,
                                      audio_segment.sample_rate)

    def _compute_specgram(self, samples, sample_rate):
        if self._specgram_type == 'linear':
            return self._compute_linear_specgram(
                samples, sample_rate, self._stride_ms, self._window_ms,
                self._max_freq)
        else:
            raise ValueError("Unknown specgram_type %s. "
                             "Supported values: linear." % self._specgram_type)

    def _compute_linear_specgram(self,
                                 samples,
                                 sample_rate,
                                 stride_ms=10.0,
                                 window_ms=20.0,
                                 max_freq=None,
                                 eps=1e-14):
        """Compute the log of linear spectrogram by FFT.
        Refer to utils.py in https://github.com/baidu-research/ba-dls-deepspeech
        """
        if max_freq is None:
            max_freq = sample_rate / 2
        if max_freq > sample_rate / 2:
            raise ValueError("max_freq must not be greater than half of "
                             "sample rate.")
        if stride_ms > window_ms:
            raise ValueError("Stride size must not be greater than "
                             "window size.")
        stride_size = int(0.001 * sample_rate * stride_ms)
        window_size = int(0.001 * sample_rate * window_ms)
        specgram, freqs = self._specgram_real(
            samples,
            window_size=window_size,
            stride_size=stride_size,
            sample_rate=sample_rate)
        ind = np.where(freqs <= max_freq)[0][-1] + 1
        return np.log(specgram[:ind, :] + eps)

    def _specgram_real(self, samples, window_size, stride_size, sample_rate):
        """Compute the spectrogram by FFT for a discrete real signal.
        Refer to utils.py in https://github.com/baidu-research/ba-dls-deepspeech
        """
        # extract strided windows
        truncate_size = (len(samples) - window_size) % stride_size
        samples = samples[:len(samples) - truncate_size]
        nshape = (window_size, (len(samples) - window_size) // stride_size + 1)
        nstrides = (samples.strides[0], samples.strides[0] * stride_size)
        windows = np.lib.stride_tricks.as_strided(
            samples, shape=nshape, strides=nstrides)
        assert np.all(
            windows[:, 1] == samples[stride_size:(stride_size + window_size)])
        # window weighting, squared Fast Fourier Transform (fft), scaling
        weighting = np.hanning(window_size)[:, None]
        fft = np.fft.rfft(windows * weighting, axis=0)
        fft = np.absolute(fft)**2
        scale = np.sum(weighting**2) * sample_rate
        fft[1:-1, :] *= (2.0 / scale)
        fft[(0, -1), :] /= scale
        # prepare fft frequency list
        freqs = float(sample_rate) / window_size * np.arange(fft.shape[0])
        return fft, freqs
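As a sanity check on shapes: at 16 kHz with the 20 ms window default, window_size is 320 samples, so np.fft.rfft yields 320 // 2 + 1 = 161 frequency bins per frame, which appears to be the source of the 161 placeholder in the trainer's dense_array(161 * 161). A sketch with synthetic input:

import numpy as np
from data_utils.audio import AudioSegment
from data_utils.featurizer.audio_featurizer import AudioFeaturizer

sr = 16000
samples = np.random.uniform(-0.1, 0.1, sr).astype(np.float32)  # 1 s of noise
featurizer = AudioFeaturizer()  # linear specgram, 20 ms window, 10 ms stride
specgram = featurizer.featurize(AudioSegment(samples, sr))
print(specgram.shape)  # (161, 99): 161 freq bins x 99 frames of 10 ms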
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from data_utils.featurizer.audio_featurizer import AudioFeaturizer
from data_utils.featurizer.text_featurizer import TextFeaturizer
class SpeechFeaturizer(object):
    def __init__(self,
                 vocab_filepath,
                 specgram_type='linear',
                 stride_ms=10.0,
                 window_ms=20.0,
                 max_freq=None,
                 random_seed=0):
        self._audio_featurizer = AudioFeaturizer(
            specgram_type, stride_ms, window_ms, max_freq, random_seed)
        self._text_featurizer = TextFeaturizer(vocab_filepath)

    def featurize(self, speech_segment):
        audio_feature = self._audio_featurizer.featurize(speech_segment)
        text_ids = self._text_featurizer.text2ids(speech_segment.transcript)
        return audio_feature, text_ids

    @property
    def vocab_size(self):
        return self._text_featurizer.vocab_size

    @property
    def vocab_list(self):
        return self._text_featurizer.vocab_list
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
class TextFeaturizer(object):
    def __init__(self, vocab_filepath):
        self._vocab_dict, self._vocab_list = self._load_vocabulary_from_file(
            vocab_filepath)

    def text2ids(self, text):
        tokens = self._char_tokenize(text)
        return [self._vocab_dict[token] for token in tokens]

    def ids2text(self, ids):
        return ''.join([self._vocab_list[id] for id in ids])

    @property
    def vocab_size(self):
        return len(self._vocab_list)

    @property
    def vocab_list(self):
        return self._vocab_list

    def _char_tokenize(self, text):
        return list(text.strip())

    def _load_vocabulary_from_file(self, vocab_filepath):
        """Load vocabulary from file."""
        vocab_lines = []
        with open(vocab_filepath, 'r') as file:
            vocab_lines.extend(file.readlines())
        vocab_list = [line[:-1] for line in vocab_lines]
        vocab_dict = dict(
            [(token, id) for (id, token) in enumerate(vocab_list)])
        return vocab_dict, vocab_list
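A quick round-trip sketch, assuming eng_vocab.txt holds one character per line (the printed ids depend on that file's ordering):

from data_utils.featurizer.text_featurizer import TextFeaturizer

featurizer = TextFeaturizer('datasets/vocab/eng_vocab.txt')
ids = featurizer.text2ids("hello world")
print(ids)                       # index of each character in the vocab
print(featurizer.ids2text(ids))  # "hello world"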
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import random
import data_utils.utils as utils
from data_utils.audio import AudioSegment
class FeatureNormalizer(object):
    def __init__(self,
                 mean_std_filepath,
                 manifest_path=None,
                 featurize_func=None,
                 num_samples=500,
                 random_seed=0):
        if not mean_std_filepath:
            if not (manifest_path and featurize_func):
                raise ValueError("If mean_std_filepath is None, manifest_path "
                                 "and featurize_func should not be None.")
            self._rng = random.Random(random_seed)
            self._compute_mean_std(manifest_path, featurize_func, num_samples)
        else:
            self._read_mean_std_from_file(mean_std_filepath)

    def apply(self, features, eps=1e-14):
        """Normalize features to be of zero mean and unit stddev."""
        return (features - self._mean) / (self._std + eps)

    def write_to_file(self, filepath):
        np.savez(filepath, mean=self._mean, std=self._std)

    def _read_mean_std_from_file(self, filepath):
        npzfile = np.load(filepath)
        self._mean = npzfile["mean"]
        self._std = npzfile["std"]

    def _compute_mean_std(self, manifest_path, featurize_func, num_samples):
        manifest = utils.read_manifest(manifest_path)
        sampled_manifest = self._rng.sample(manifest, num_samples)
        features = []
        for instance in sampled_manifest:
            features.append(
                featurize_func(
                    AudioSegment.from_file(instance["audio_filepath"])))
        # stack along the time axis; statistics are per frequency bin
        features = np.hstack(features)
        self._mean = np.mean(features, axis=1).reshape([-1, 1])
        self._std = np.std(features, axis=1).reshape([-1, 1])
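The class thus supports two construction modes, sketched below (paths are placeholders):

from data_utils.normalizer import FeatureNormalizer
from data_utils.featurizer.audio_featurizer import AudioFeaturizer

# mode 1: load precomputed statistics (e.g. written by compute_mean_std.py)
normalizer = FeatureNormalizer(mean_std_filepath='mean_std.npz')

# mode 2: estimate statistics from a random sample of the manifest
normalizer = FeatureNormalizer(
    mean_std_filepath=None,
    manifest_path='datasets/manifest.train',
    featurize_func=AudioFeaturizer().featurize,
    num_samples=500)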
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
def read_manifest(manifest_path, max_duration=float('inf'), min_duration=0.0):
    """Load and parse manifest file."""
    manifest = []
    for json_line in open(manifest_path):
        try:
            json_data = json.loads(json_line)
        except Exception as e:
            raise IOError("Error reading manifest: %s" % str(e))
        if (json_data["duration"] <= max_duration and
                json_data["duration"] >= min_duration):
            manifest.append(json_data)
    return manifest
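For reference, a manifest is a JSON-lines file with one utterance per line; the fields below are the ones consumed by read_manifest and DataGenerator (the path, duration and text values are illustrative):

{"audio_filepath": "/data/LibriSpeech/1088-134315-0000.flac", "duration": 5.86, "text": "illustrative transcription text"}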
@@ -44,7 +44,7 @@ parser.add_argument(
    help="Directory to save the dataset. (default: %(default)s)")
parser.add_argument(
    "--manifest_prefix",
    default="manifest",
    type=str,
    help="Filepath prefix for output manifests. (default: %(default)s)")
parser.add_argument(
......
cd librispeech
python librispeech.py
if [ $? -ne 0 ]; then
echo "Prepare LibriSpeech failed. Terminated."
exit 1
fi
cd -
cat librispeech/manifest.train* | shuf > manifest.train
cat librispeech/manifest.dev-clean > manifest.dev
cat librispeech/manifest.test-clean > manifest.test
echo "All done."
@@ -2,11 +2,15 @@
Inference for a simplified version of Baidu DeepSpeech2 model.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import gzip
import distutils.util
import paddle.v2 as paddle
from data_utils.data import DataGenerator
from model import deep_speech2
from decoder import ctc_decode
@@ -38,13 +42,13 @@ parser.add_argument(
    type=distutils.util.strtobool,
    help="Use gpu or not. (default: %(default)s)")
parser.add_argument(
    "--mean_std_filepath",
    default='mean_std.npz',
    type=str,
    help="Filepath of the normalizer's mean and std. (default: %(default)s)")
parser.add_argument(
    "--decode_manifest_path",
    default='datasets/manifest.test',
    type=str,
    help="Manifest path for decoding. (default: %(default)s)")
parser.add_argument(
@@ -54,7 +58,7 @@ parser.add_argument(
    help="Model filepath. (default: %(default)s)")
parser.add_argument(
    "--vocab_filepath",
    default='datasets/vocab/eng_vocab.txt',
    type=str,
    help="Vocabulary filepath. (default: %(default)s)")
args = parser.parse_args()
@@ -67,28 +71,22 @@ def infer():
    # initialize data generator
    data_generator = DataGenerator(
        vocab_filepath=args.vocab_filepath,
        mean_std_filepath=args.mean_std_filepath,
        augmentation_config='{}')
    # create network config
    # paddle.data_type.dense_array is used for variable batch input.
    # The size 161 * 161 is only a placeholder value and the real shape
    # of the input batch data will be induced during training.
    audio_data = paddle.layer.data(
        name="audio_spectrogram", type=paddle.data_type.dense_array(161 * 161))
    text_data = paddle.layer.data(
        name="transcript_text",
        type=paddle.data_type.integer_value_sequence(data_generator.vocab_size))
    output_probs = deep_speech2(
        audio_data=audio_data,
        text_data=text_data,
        dict_size=data_generator.vocab_size,
        num_conv_layers=args.num_conv_layers,
        num_rnn_layers=args.num_rnn_layers,
        rnn_size=args.rnn_layer_size,
@@ -99,31 +97,30 @@ def infer():
        gzip.open(args.model_filepath))
    # prepare infer data
    batch_reader = data_generator.batch_reader_creator(
        manifest_path=args.decode_manifest_path,
        batch_size=args.num_samples,
        sortagrad=False,
        batch_shuffle=False)
    infer_data = batch_reader().next()
    # run inference
    infer_results = paddle.infer(
        output_layer=output_probs, parameters=parameters, input=infer_data)
    num_steps = len(infer_results) // len(infer_data)
    probs_split = [
        infer_results[i * num_steps:(i + 1) * num_steps]
        for i in xrange(len(infer_data))
    ]
    # decode and print
    for i, probs in enumerate(probs_split):
        output_transcription = ctc_decode(
            probs_seq=probs,
            vocabulary=data_generator.vocab_list,
            method="best_path")
        target_transcription = ''.join(
            [data_generator.vocab_list[index] for index in infer_data[i][1]])
        print("Target Transcription: %s \nOutput Transcription: %s \n" %
              (target_transcription, output_transcription))
......
@@ -2,21 +2,21 @@
Trainer for a simplified version of Baidu DeepSpeech2 model.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import sys
import os
import argparse
import gzip
import time
import distutils.util
import paddle.v2 as paddle
from model import deep_speech2
from data_utils.data import DataGenerator

parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
    "--batch_size", default=32, type=int, help="Minibatch size.")
parser.add_argument(
@@ -51,7 +51,7 @@ parser.add_argument(
    help="Use gpu or not. (default: %(default)s)")
parser.add_argument(
    "--use_sortagrad",
    default=True,
    type=distutils.util.strtobool,
    help="Use sortagrad or not. (default: %(default)s)")
parser.add_argument(
@@ -60,23 +60,23 @@ parser.add_argument(
    type=int,
    help="Trainer number. (default: %(default)s)")
parser.add_argument(
    "--mean_std_filepath",
    default='mean_std.npz',
    type=str,
    help="Filepath of the normalizer's mean and std. (default: %(default)s)")
parser.add_argument(
    "--train_manifest_path",
    default='datasets/manifest.train',
    type=str,
    help="Manifest path for training. (default: %(default)s)")
parser.add_argument(
    "--dev_manifest_path",
    default='datasets/manifest.dev',
    type=str,
    help="Manifest path for validation. (default: %(default)s)")
parser.add_argument(
    "--vocab_filepath",
    default='datasets/vocab/eng_vocab.txt',
    type=str,
    help="Vocabulary filepath. (default: %(default)s)")
parser.add_argument(
@@ -86,6 +86,12 @@ parser.add_argument(
    help="If set None, the training will start from scratch. "
    "Otherwise, the training will resume from "
    "the existing model of this path. (default: %(default)s)")
parser.add_argument(
    "--augmentation_config",
    default='{}',
    type=str,
    help="Augmentation configuration in json-format. "
    "(default: %(default)s)")
args = parser.parse_args()
@@ -98,29 +104,26 @@ def train():
    def data_generator():
        return DataGenerator(
            vocab_filepath=args.vocab_filepath,
            mean_std_filepath=args.mean_std_filepath,
            augmentation_config=args.augmentation_config)

    train_generator = data_generator()
    test_generator = data_generator()
    # create network config
    # paddle.data_type.dense_array is used for variable batch input.
    # The size 161 * 161 is only a placeholder value and the real shape
    # of the input batch data will be induced during training.
    audio_data = paddle.layer.data(
        name="audio_spectrogram", type=paddle.data_type.dense_array(161 * 161))
    text_data = paddle.layer.data(
        name="transcript_text",
        type=paddle.data_type.integer_value_sequence(
            train_generator.vocab_size))
    cost = deep_speech2(
        audio_data=audio_data,
        text_data=text_data,
        dict_size=train_generator.vocab_size,
        num_conv_layers=args.num_conv_layers,
        num_rnn_layers=args.num_rnn_layers,
        rnn_size=args.rnn_layer_size,
@@ -143,13 +146,13 @@ def train():
    train_batch_reader = train_generator.batch_reader_creator(
        manifest_path=args.train_manifest_path,
        batch_size=args.batch_size,
        sortagrad=args.use_sortagrad if args.init_model_path is None else False,
        batch_shuffle=True)
    test_batch_reader = test_generator.batch_reader_creator(
        manifest_path=args.dev_manifest_path,
        batch_size=args.batch_size,
        sortagrad=False,
        batch_shuffle=False)

    # create event handler
    def event_handler(event):
@@ -158,8 +161,8 @@ def train():
            cost_sum += event.cost
            cost_counter += 1
            if event.batch_id % 50 == 0:
                print("\nPass: %d, Batch: %d, TrainCost: %f" %
                      (event.pass_id, event.batch_id, cost_sum / cost_counter))
                cost_sum, cost_counter = 0.0, 0
                with gzip.open("params.tar.gz", 'w') as f:
                    parameters.to_tar(f)
@@ -170,16 +173,17 @@ def train():
            start_time = time.time()
            cost_sum, cost_counter = 0.0, 0
        if isinstance(event, paddle.event.EndPass):
            result = trainer.test(
                reader=test_batch_reader, feeding=test_generator.feeding)
            print("\n------- Time: %d sec, Pass: %d, ValidationCost: %s" %
                  (time.time() - start_time, event.pass_id, result.cost))

    # run train
    trainer.train(
        reader=train_batch_reader,
        event_handler=event_handler,
        num_passes=args.num_passes,
        feeding=train_generator.feeding)


def main():
......