From 70eb40019d64e9cf201a6fb024f28358a5cbe88a Mon Sep 17 00:00:00 2001 From: Xinghai Sun Date: Mon, 12 Jun 2017 23:19:40 +0800 Subject: [PATCH] Refactor whole data preprocessor for DS2 (re-design classes, re-organize dir, add augmentaion interfaces etc.). 1. Refactor data preprocessor with new added class AudioSegment, SpeechSegment, TextFeaturizer, AudioFeaturizer, SpeechFeaturizer. 2. Add data augmentation interfaces and class AugmentorBase, AugmentationPipeline, VolumnPerturbAugmentor etc.. 3. Seperate normalizer's mean and std computing from training, by adding FeatureNormalizer and a seperate tool compute_mean_std.py. 4. Re-organize directory. --- deep_speech_2/audio_data_utils.py | 411 ------------------ deep_speech_2/compute_mean_std.py | 56 +++ deep_speech_2/data_utils/__init__.py | 0 deep_speech_2/data_utils/audio.py | 68 +++ .../data_utils/augmentor/__init__.py | 0 .../data_utils/augmentor/augmentation.py | 38 ++ deep_speech_2/data_utils/augmentor/base.py | 17 + .../data_utils/augmentor/volumn_perturb.py | 17 + deep_speech_2/data_utils/data.py | 247 +++++++++++ .../data_utils/featurizer/__init__.py | 0 .../data_utils/featurizer/audio_featurizer.py | 86 ++++ .../featurizer/speech_featurizer.py | 32 ++ .../data_utils/featurizer/text_featurizer.py | 39 ++ deep_speech_2/data_utils/normalizer.py | 49 +++ deep_speech_2/data_utils/utils.py | 19 + .../librispeech}/librispeech.py | 2 +- deep_speech_2/datasets/run_all.sh | 13 + .../{data => datasets/vocab}/eng_vocab.txt | 0 deep_speech_2/infer.py | 61 ++- deep_speech_2/train.py | 74 ++-- 20 files changed, 750 insertions(+), 479 deletions(-) delete mode 100644 deep_speech_2/audio_data_utils.py create mode 100755 deep_speech_2/compute_mean_std.py create mode 100755 deep_speech_2/data_utils/__init__.py create mode 100755 deep_speech_2/data_utils/audio.py create mode 100755 deep_speech_2/data_utils/augmentor/__init__.py create mode 100755 deep_speech_2/data_utils/augmentor/augmentation.py create mode 100755 deep_speech_2/data_utils/augmentor/base.py create mode 100755 deep_speech_2/data_utils/augmentor/volumn_perturb.py create mode 100644 deep_speech_2/data_utils/data.py create mode 100755 deep_speech_2/data_utils/featurizer/__init__.py create mode 100755 deep_speech_2/data_utils/featurizer/audio_featurizer.py create mode 100755 deep_speech_2/data_utils/featurizer/speech_featurizer.py create mode 100755 deep_speech_2/data_utils/featurizer/text_featurizer.py create mode 100755 deep_speech_2/data_utils/normalizer.py create mode 100755 deep_speech_2/data_utils/utils.py rename deep_speech_2/{data => datasets/librispeech}/librispeech.py (99%) create mode 100755 deep_speech_2/datasets/run_all.sh rename deep_speech_2/{data => datasets/vocab}/eng_vocab.txt (100%) diff --git a/deep_speech_2/audio_data_utils.py b/deep_speech_2/audio_data_utils.py deleted file mode 100644 index 1cd29be1..00000000 --- a/deep_speech_2/audio_data_utils.py +++ /dev/null @@ -1,411 +0,0 @@ -""" - Providing basic audio data preprocessing pipeline, and offering - both instance-level and batch-level data reader interfaces. -""" -import paddle.v2 as paddle -import logging -import json -import random -import soundfile -import numpy as np -import itertools -import os - -RANDOM_SEED = 0 -logger = logging.getLogger(__name__) - - -class DataGenerator(object): - """ - DataGenerator provides basic audio data preprocessing pipeline, and offers - both instance-level and batch-level data reader interfaces. - Normalized FFT are used as audio features here. - - :param vocab_filepath: Vocabulary file path for indexing tokenized - transcriptions. - :type vocab_filepath: basestring - :param normalizer_manifest_path: Manifest filepath for collecting feature - normalization statistics, e.g. mean, std. - :type normalizer_manifest_path: basestring - :param normalizer_num_samples: Number of instances sampled for collecting - feature normalization statistics. - Default is 100. - :type normalizer_num_samples: int - :param max_duration: Audio clips with duration (in seconds) greater than - this will be discarded. Default is 20.0. - :type max_duration: float - :param min_duration: Audio clips with duration (in seconds) smaller than - this will be discarded. Default is 0.0. - :type min_duration: float - :param stride_ms: Striding size (in milliseconds) for generating frames. - Default is 10.0. - :type stride_ms: float - :param window_ms: Window size (in milliseconds) for frames. Default is 20.0. - :type window_ms: float - :param max_frequency: Maximun frequency for FFT features. FFT features of - frequency larger than this will be discarded. - If set None, all features will be kept. - Default is None. - :type max_frequency: float - """ - - def __init__(self, - vocab_filepath, - normalizer_manifest_path, - normalizer_num_samples=100, - max_duration=20.0, - min_duration=0.0, - stride_ms=10.0, - window_ms=20.0, - max_frequency=None): - self.__max_duration__ = max_duration - self.__min_duration__ = min_duration - self.__stride_ms__ = stride_ms - self.__window_ms__ = window_ms - self.__max_frequency__ = max_frequency - self.__epoc__ = 0 - self.__random__ = random.Random(RANDOM_SEED) - # load vocabulary (dictionary) - self.__vocab_dict__, self.__vocab_list__ = \ - self.__load_vocabulary_from_file__(vocab_filepath) - # collect normalizer statistics - self.__mean__, self.__std__ = self.__collect_normalizer_statistics__( - manifest_path=normalizer_manifest_path, - num_samples=normalizer_num_samples) - - def __audio_featurize__(self, audio_filename): - """ - Preprocess audio data, including feature extraction, normalization etc.. - """ - features = self.__audio_basic_featurize__(audio_filename) - return self.__normalize__(features) - - def __text_featurize__(self, text): - """ - Preprocess text data, including tokenizing and token indexing etc.. - """ - return self.__convert_text_to_char_index__( - text=text, vocabulary=self.__vocab_dict__) - - def __audio_basic_featurize__(self, audio_filename): - """ - Compute basic (without normalization etc.) features for audio data. - """ - return self.__spectrogram_from_file__( - filename=audio_filename, - stride_ms=self.__stride_ms__, - window_ms=self.__window_ms__, - max_freq=self.__max_frequency__) - - def __collect_normalizer_statistics__(self, manifest_path, num_samples=100): - """ - Compute feature normalization statistics, i.e. mean and stddev. - """ - # read manifest - manifest = self.__read_manifest__( - manifest_path=manifest_path, - max_duration=self.__max_duration__, - min_duration=self.__min_duration__) - # sample for statistics - sampled_manifest = self.__random__.sample(manifest, num_samples) - # extract spectrogram feature - features = [] - for instance in sampled_manifest: - spectrogram = self.__audio_basic_featurize__( - instance["audio_filepath"]) - features.append(spectrogram) - features = np.hstack(features) - mean = np.mean(features, axis=1).reshape([-1, 1]) - std = np.std(features, axis=1).reshape([-1, 1]) - return mean, std - - def __normalize__(self, features, eps=1e-14): - """ - Normalize features to be of zero mean and unit stddev. - """ - return (features - self.__mean__) / (self.__std__ + eps) - - def __spectrogram_from_file__(self, - filename, - stride_ms=10.0, - window_ms=20.0, - max_freq=None, - eps=1e-14): - """ - Laod audio data and calculate the log of spectrogram by FFT. - Refer to utils.py in https://github.com/baidu-research/ba-dls-deepspeech - """ - audio, sample_rate = soundfile.read(filename) - if audio.ndim >= 2: - audio = np.mean(audio, 1) - if max_freq is None: - max_freq = sample_rate / 2 - if max_freq > sample_rate / 2: - raise ValueError("max_freq must be greater than half of " - "sample rate.") - if stride_ms > window_ms: - raise ValueError("Stride size must not be greater than " - "window size.") - stride_size = int(0.001 * sample_rate * stride_ms) - window_size = int(0.001 * sample_rate * window_ms) - spectrogram, freqs = self.__extract_spectrogram__( - audio, - window_size=window_size, - stride_size=stride_size, - sample_rate=sample_rate) - ind = np.where(freqs <= max_freq)[0][-1] + 1 - return np.log(spectrogram[:ind, :] + eps) - - def __extract_spectrogram__(self, samples, window_size, stride_size, - sample_rate): - """ - Compute the spectrogram by FFT for a discrete real signal. - Refer to utils.py in https://github.com/baidu-research/ba-dls-deepspeech - """ - # extract strided windows - truncate_size = (len(samples) - window_size) % stride_size - samples = samples[:len(samples) - truncate_size] - nshape = (window_size, (len(samples) - window_size) // stride_size + 1) - nstrides = (samples.strides[0], samples.strides[0] * stride_size) - windows = np.lib.stride_tricks.as_strided( - samples, shape=nshape, strides=nstrides) - assert np.all( - windows[:, 1] == samples[stride_size:(stride_size + window_size)]) - # window weighting, squared Fast Fourier Transform (fft), scaling - weighting = np.hanning(window_size)[:, None] - fft = np.fft.rfft(windows * weighting, axis=0) - fft = np.absolute(fft)**2 - scale = np.sum(weighting**2) * sample_rate - fft[1:-1, :] *= (2.0 / scale) - fft[(0, -1), :] /= scale - # prepare fft frequency list - freqs = float(sample_rate) / window_size * np.arange(fft.shape[0]) - return fft, freqs - - def __load_vocabulary_from_file__(self, vocabulary_path): - """ - Load vocabulary from file. - """ - if not os.path.exists(vocabulary_path): - raise ValueError("Vocabulary file %s not found.", vocabulary_path) - vocab_lines = [] - with open(vocabulary_path, 'r') as file: - vocab_lines.extend(file.readlines()) - vocab_list = [line[:-1] for line in vocab_lines] - vocab_dict = dict( - [(token, id) for (id, token) in enumerate(vocab_list)]) - return vocab_dict, vocab_list - - def __convert_text_to_char_index__(self, text, vocabulary): - """ - Convert text string to a list of character index integers. - """ - return [vocabulary[w] for w in text] - - def __read_manifest__(self, manifest_path, max_duration, min_duration): - """ - Load and parse manifest file. - """ - manifest = [] - for json_line in open(manifest_path): - try: - json_data = json.loads(json_line) - except Exception as e: - raise ValueError("Error reading manifest: %s" % str(e)) - if (json_data["duration"] <= max_duration and - json_data["duration"] >= min_duration): - manifest.append(json_data) - return manifest - - def __padding_batch__(self, batch, padding_to=-1, flatten=False): - """ - Padding audio part of features (only in the time axis -- column axis) - with zeros, to make each instance in the batch share the same - audio feature shape. - - If `padding_to` is set -1, the maximun column numbers in the batch will - be used as the target size. Otherwise, `padding_to` will be the target - size. Default is -1. - - If `flatten` is set True, audio data will be flatten to be a 1-dim - ndarray. Default is False. - """ - new_batch = [] - # get target shape - max_length = max([audio.shape[1] for audio, text in batch]) - if padding_to != -1: - if padding_to < max_length: - raise ValueError("If padding_to is not -1, it should be greater" - " or equal to the original instance length.") - max_length = padding_to - # padding - for audio, text in batch: - padded_audio = np.zeros([audio.shape[0], max_length]) - padded_audio[:, :audio.shape[1]] = audio - if flatten: - padded_audio = padded_audio.flatten() - new_batch.append((padded_audio, text)) - return new_batch - - def __batch_shuffle__(self, manifest, batch_size): - """ - The instances have different lengths and they cannot be - combined into a single matrix multiplication. It usually - sorts the training examples by length and combines only - similarly-sized instances into minibatches, pads with - silence when necessary so that all instances in a batch - have the same length. This batch shuffle fuction is used - to make similarly-sized instances into minibatches and - make a batch-wise shuffle. - - 1. Sort the audio clips by duration. - 2. Generate a random number `k`, k in [0, batch_size). - 3. Randomly remove `k` instances in order to make different mini-batches, - then make minibatches and each minibatch size is batch_size. - 4. Shuffle the minibatches. - - :param manifest: manifest file. - :type manifest: list - :param batch_size: Batch size. This size is also used for generate - a random number for batch shuffle. - :type batch_size: int - :return: batch shuffled mainifest. - :rtype: list - """ - manifest.sort(key=lambda x: x["duration"]) - shift_len = self.__random__.randint(0, batch_size - 1) - batch_manifest = zip(*[iter(manifest[shift_len:])] * batch_size) - self.__random__.shuffle(batch_manifest) - batch_manifest = list(sum(batch_manifest, ())) - res_len = len(manifest) - shift_len - len(batch_manifest) - batch_manifest.extend(manifest[-res_len:]) - batch_manifest.extend(manifest[0:shift_len]) - return batch_manifest - - def instance_reader_creator(self, manifest): - """ - Instance reader creator for audio data. Creat a callable function to - produce instances of data. - - Instance: a tuple of a numpy ndarray of audio spectrogram and a list of - tokenized and indexed transcription text. - - :param manifest: Filepath of manifest for audio clip files. - :type manifest: basestring - :return: Data reader function. - :rtype: callable - """ - - def reader(): - # extract spectrogram feature - for instance in manifest: - spectrogram = self.__audio_featurize__( - instance["audio_filepath"]) - transcript = self.__text_featurize__(instance["text"]) - yield (spectrogram, transcript) - - return reader - - def batch_reader_creator(self, - manifest_path, - batch_size, - padding_to=-1, - flatten=False, - sortagrad=False, - batch_shuffle=False): - """ - Batch data reader creator for audio data. Creat a callable function to - produce batches of data. - - Audio features will be padded with zeros to make each instance in the - batch to share the same audio feature shape. - - :param manifest_path: Filepath of manifest for audio clip files. - :type manifest_path: basestring - :param batch_size: Instance number in a batch. - :type batch_size: int - :param padding_to: If set -1, the maximun column numbers in the batch - will be used as the target size for padding. - Otherwise, `padding_to` will be the target size. - Default is -1. - :type padding_to: int - :param flatten: If set True, audio data will be flatten to be a 1-dim - ndarray. Otherwise, 2-dim ndarray. Default is False. - :type flatten: bool - :param sortagrad: Sort the audio clips by duration in the first epoc - if set True. - :type sortagrad: bool - :param batch_shuffle: Shuffle the audio clips if set True. It is - not a thorough instance-wise shuffle, but a - specific batch-wise shuffle. For more details, - please see `__batch_shuffle__` function. - :type batch_shuffle: bool - :return: Batch reader function, producing batches of data when called. - :rtype: callable - """ - - def batch_reader(): - # read manifest - manifest = self.__read_manifest__( - manifest_path=manifest_path, - max_duration=self.__max_duration__, - min_duration=self.__min_duration__) - - # sort (by duration) or shuffle manifest - if self.__epoc__ == 0 and sortagrad: - manifest.sort(key=lambda x: x["duration"]) - elif batch_shuffle: - manifest = self.__batch_shuffle__(manifest, batch_size) - - instance_reader = self.instance_reader_creator(manifest) - batch = [] - for instance in instance_reader(): - batch.append(instance) - if len(batch) == batch_size: - yield self.__padding_batch__(batch, padding_to, flatten) - batch = [] - if len(batch) > 0: - yield self.__padding_batch__(batch, padding_to, flatten) - self.__epoc__ += 1 - - return batch_reader - - def vocabulary_size(self): - """ - Get vocabulary size. - - :return: Vocabulary size. - :rtype: int - """ - return len(self.__vocab_list__) - - def vocabulary_dict(self): - """ - Get vocabulary in dict. - - :return: Vocabulary in dict. - :rtype: dict - """ - return self.__vocab_dict__ - - def vocabulary_list(self): - """ - Get vocabulary in list. - - :return: Vocabulary in list - :rtype: list - """ - return self.__vocab_list__ - - def data_name_feeding(self): - """ - Get feeddings (data field name and corresponding field id). - - :return: Feeding dict. - :rtype: dict - """ - feeding = { - "audio_spectrogram": 0, - "transcript_text": 1, - } - return feeding diff --git a/deep_speech_2/compute_mean_std.py b/deep_speech_2/compute_mean_std.py new file mode 100755 index 00000000..b3015df7 --- /dev/null +++ b/deep_speech_2/compute_mean_std.py @@ -0,0 +1,56 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +from data_utils.normalizer import FeatureNormalizer +from data_utils.augmentor.augmentation import AugmentationPipeline +from data_utils.featurizer.audio_featurizer import AudioFeaturizer + +parser = argparse.ArgumentParser( + description='Computing mean and stddev for feature normalizer.') +parser.add_argument( + "--manifest_path", + default='datasets/manifest.train', + type=str, + help="Manifest path for computing normalizer's mean and stddev." + "(default: %(default)s)") +parser.add_argument( + "--num_samples", + default=500, + type=int, + help="Number of samples for computing mean and stddev. " + "(default: %(default)s)") +parser.add_argument( + "--augmentation_config", + default='{}', + type=str, + help="Augmentation configuration in json-format. " + "(default: %(default)s)") +parser.add_argument( + "--output_file", + default='mean_std.npz', + type=str, + help="Filepath to write mean and std to (.npz)." + "(default: %(default)s)") +args = parser.parse_args() + + +def main(): + augmentation_pipeline = AugmentationPipeline(args.augmentation_config) + audio_featurizer = AudioFeaturizer() + + def augment_and_featurize(audio_segment): + augmentation_pipeline.transform_audio(audio_segment) + return audio_featurizer.featurize(audio_segment) + + normalizer = FeatureNormalizer( + mean_std_filepath=None, + manifest_path=args.manifest_path, + featurize_func=augment_and_featurize, + num_samples=args.num_samples) + normalizer.write_to_file(args.output_file) + + +if __name__ == '__main__': + main() diff --git a/deep_speech_2/data_utils/__init__.py b/deep_speech_2/data_utils/__init__.py new file mode 100755 index 00000000..e69de29b diff --git a/deep_speech_2/data_utils/audio.py b/deep_speech_2/data_utils/audio.py new file mode 100755 index 00000000..46b24120 --- /dev/null +++ b/deep_speech_2/data_utils/audio.py @@ -0,0 +1,68 @@ +import numpy as np +import io +import soundfile + + +class AudioSegment(object): + """Monaural audio segment abstraction. + """ + + def __init__(self, samples, sample_rate): + if not samples.dtype == np.float32: + raise ValueError("Sample data type of [%s] is not supported.") + self._samples = samples + self._sample_rate = sample_rate + if self._samples.ndim >= 2: + self._samples = np.mean(self._samples, 1) + + @classmethod + def from_file(cls, filepath): + samples, sample_rate = soundfile.read(filepath, dtype='float32') + return cls(samples, sample_rate) + + @classmethod + def from_bytes(cls, bytes): + samples, sample_rate = soundfile.read( + io.BytesIO(bytes), dtype='float32') + return cls(samples, sample_rate) + + def apply_gain(self, gain): + self.samples *= 10.**(gain / 20.) + + def resample(self, target_sample_rate): + raise NotImplementedError() + + def change_speed(self, rate): + raise NotImplementedError() + + @property + def samples(self): + return self._samples.copy() + + @property + def sample_rate(self): + return self._sample_rate + + @property + def duration(self): + return self._samples.shape[0] / float(self._sample_rate) + + +class SpeechSegment(AudioSegment): + def __init__(self, samples, sample_rate, transcript): + AudioSegment.__init__(self, samples, sample_rate) + self._transcript = transcript + + @classmethod + def from_file(cls, filepath, transcript): + audio = AudioSegment.from_file(filepath) + return cls(audio.samples, audio.sample_rate, transcript) + + @classmethod + def from_bytes(cls, bytes, transcript): + audio = AudioSegment.from_bytes(bytes) + return cls(audio.samples, audio.sample_rate, transcript) + + @property + def transcript(self): + return self._transcript diff --git a/deep_speech_2/data_utils/augmentor/__init__.py b/deep_speech_2/data_utils/augmentor/__init__.py new file mode 100755 index 00000000..e69de29b diff --git a/deep_speech_2/data_utils/augmentor/augmentation.py b/deep_speech_2/data_utils/augmentor/augmentation.py new file mode 100755 index 00000000..3a1426a1 --- /dev/null +++ b/deep_speech_2/data_utils/augmentor/augmentation.py @@ -0,0 +1,38 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import json +import random +from data_utils.augmentor.volumn_perturb import VolumnPerturbAugmentor + + +class AugmentationPipeline(object): + def __init__(self, augmentation_config, random_seed=0): + self._rng = random.Random(random_seed) + self._augmentors, self._rates = self._parse_pipeline_from( + augmentation_config) + + def transform_audio(self, audio_segment): + for augmentor, rate in zip(self._augmentors, self._rates): + if self._rng.uniform(0., 1.) <= rate: + augmentor.transform_audio(audio_segment) + + def _parse_pipeline_from(self, config_json): + try: + configs = json.loads(config_json) + except Exception as e: + raise ValueError("Augmentation config json format error: " + "%s" % str(e)) + augmentors = [ + self._get_augmentor(config["type"], config["params"]) + for config in configs + ] + rates = [config["rate"] for config in configs] + return augmentors, rates + + def _get_augmentor(self, augmentor_type, params): + if augmentor_type == "volumn": + return VolumnPerturbAugmentor(self._rng, **params) + else: + raise ValueError("Unknown augmentor type [%s]." % augmentor_type) diff --git a/deep_speech_2/data_utils/augmentor/base.py b/deep_speech_2/data_utils/augmentor/base.py new file mode 100755 index 00000000..e801b9b1 --- /dev/null +++ b/deep_speech_2/data_utils/augmentor/base.py @@ -0,0 +1,17 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from abc import ABCMeta, abstractmethod + + +class AugmentorBase(object): + __metaclass__ = ABCMeta + + @abstractmethod + def __init__(self): + pass + + @abstractmethod + def transform_audio(self, audio_segment): + pass diff --git a/deep_speech_2/data_utils/augmentor/volumn_perturb.py b/deep_speech_2/data_utils/augmentor/volumn_perturb.py new file mode 100755 index 00000000..dd1ba53a --- /dev/null +++ b/deep_speech_2/data_utils/augmentor/volumn_perturb.py @@ -0,0 +1,17 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import random +from data_utils.augmentor.base import AugmentorBase + + +class VolumnPerturbAugmentor(AugmentorBase): + def __init__(self, rng, min_gain_dBFS, max_gain_dBFS): + self._min_gain_dBFS = min_gain_dBFS + self._max_gain_dBFS = max_gain_dBFS + self._rng = rng + + def transform_audio(self, audio_segment): + gain = self._rng.uniform(min_gain_dBFS, max_gain_dBFS) + audio_segment.apply_gain(gain) diff --git a/deep_speech_2/data_utils/data.py b/deep_speech_2/data_utils/data.py new file mode 100644 index 00000000..63000793 --- /dev/null +++ b/deep_speech_2/data_utils/data.py @@ -0,0 +1,247 @@ +""" + Providing basic audio data preprocessing pipeline, and offering + both instance-level and batch-level data reader interfaces. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import random +import numpy as np +import paddle.v2 as paddle +from data_utils import utils +from data_utils.augmentor.augmentation import AugmentationPipeline +from data_utils.featurizer.speech_featurizer import SpeechFeaturizer +from data_utils.audio import SpeechSegment +from data_utils.normalizer import FeatureNormalizer + + +class DataGenerator(object): + """ + DataGenerator provides basic audio data preprocessing pipeline, and offers + both instance-level and batch-level data reader interfaces. + Normalized FFT are used as audio features here. + + :param vocab_filepath: Vocabulary file path for indexing tokenized + transcriptions. + :type vocab_filepath: basestring + :param normalizer_manifest_path: Manifest filepath for collecting feature + normalization statistics, e.g. mean, std. + :type normalizer_manifest_path: basestring + :param normalizer_num_samples: Number of instances sampled for collecting + feature normalization statistics. + Default is 100. + :type normalizer_num_samples: int + :param max_duration: Audio clips with duration (in seconds) greater than + this will be discarded. Default is 20.0. + :type max_duration: float + :param min_duration: Audio clips with duration (in seconds) smaller than + this will be discarded. Default is 0.0. + :type min_duration: float + :param stride_ms: Striding size (in milliseconds) for generating frames. + Default is 10.0. + :type stride_ms: float + :param window_ms: Window size (in milliseconds) for frames. Default is 20.0. + :type window_ms: float + :param max_frequency: Maximun frequency for FFT features. FFT features of + frequency larger than this will be discarded. + If set None, all features will be kept. + Default is None. + :type max_frequency: float + """ + + def __init__(self, + vocab_filepath, + mean_std_filepath, + augmentation_config='{}', + max_duration=float('inf'), + min_duration=0.0, + stride_ms=10.0, + window_ms=20.0, + max_freq=None, + random_seed=0): + self._max_duration = max_duration + self._min_duration = min_duration + self._normalizer = FeatureNormalizer(mean_std_filepath) + self._augmentation_pipeline = AugmentationPipeline( + augmentation_config=augmentation_config, random_seed=random_seed) + self._speech_featurizer = SpeechFeaturizer( + vocab_filepath=vocab_filepath, + stride_ms=stride_ms, + window_ms=window_ms, + max_freq=max_freq, + random_seed=random_seed) + self._rng = random.Random(random_seed) + self._epoch = 0 + + def batch_reader_creator(self, + manifest_path, + batch_size, + padding_to=-1, + flatten=False, + sortagrad=False, + batch_shuffle=False): + """ + Batch data reader creator for audio data. Creat a callable function to + produce batches of data. + + Audio features will be padded with zeros to make each instance in the + batch to share the same audio feature shape. + + :param manifest_path: Filepath of manifest for audio clip files. + :type manifest_path: basestring + :param batch_size: Instance number in a batch. + :type batch_size: int + :param padding_to: If set -1, the maximun column numbers in the batch + will be used as the target size for padding. + Otherwise, `padding_to` will be the target size. + Default is -1. + :type padding_to: int + :param flatten: If set True, audio data will be flatten to be a 1-dim + ndarray. Otherwise, 2-dim ndarray. Default is False. + :type flatten: bool + :param sortagrad: Sort the audio clips by duration in the first epoc + if set True. + :type sortagrad: bool + :param batch_shuffle: Shuffle the audio clips if set True. It is + not a thorough instance-wise shuffle, but a + specific batch-wise shuffle. For more details, + please see `_batch_shuffle` function. + :type batch_shuffle: bool + :return: Batch reader function, producing batches of data when called. + :rtype: callable + """ + + def batch_reader(): + # read manifest + manifest = utils.read_manifest( + manifest_path=manifest_path, + max_duration=self._max_duration, + min_duration=self._min_duration) + # sort (by duration) or batch-wise shuffle the manifest + if self._epoch == 0 and sortagrad: + manifest.sort(key=lambda x: x["duration"]) + elif batch_shuffle: + manifest = self._batch_shuffle(manifest, batch_size) + # prepare batches + instance_reader = self._instance_reader_creator(manifest) + batch = [] + for instance in instance_reader(): + batch.append(instance) + if len(batch) == batch_size: + yield self._padding_batch(batch, padding_to, flatten) + batch = [] + if len(batch) > 0: + yield self._padding_batch(batch, padding_to, flatten) + self._epoch += 1 + + return batch_reader + + @property + def feeding(self): + """Returns data_reader's feeding dict.""" + return {"audio_spectrogram": 0, "transcript_text": 1} + + @property + def vocab_size(self): + """Returns vocabulary size.""" + return self._speech_featurizer.vocab_size + + @property + def vocab_list(self): + """Returns vocabulary list.""" + return self._speech_featurizer.vocab_list + + def _process_utterance(self, filename, transcript): + speech_segment = SpeechSegment.from_file(filename, transcript) + self._augmentation_pipeline.transform_audio(speech_segment) + specgram, text_ids = self._speech_featurizer.featurize(speech_segment) + specgram = self._normalizer.apply(specgram) + return specgram, text_ids + + def _instance_reader_creator(self, manifest): + """ + Instance reader creator for audio data. Creat a callable function to + produce instances of data. + + Instance: a tuple of a numpy ndarray of audio spectrogram and a list of + tokenized and indexed transcription text. + + :param manifest: Filepath of manifest for audio clip files. + :type manifest: basestring + :return: Data reader function. + :rtype: callable + """ + + def reader(): + for instance in manifest: + yield self._process_utterance(instance["audio_filepath"], + instance["text"]) + + return reader + + def _padding_batch(self, batch, padding_to=-1, flatten=False): + """ + Padding audio part of features (only in the time axis -- column axis) + with zeros, to make each instance in the batch share the same + audio feature shape. + + If `padding_to` is set -1, the maximun column numbers in the batch will + be used as the target size. Otherwise, `padding_to` will be the target + size. Default is -1. + + If `flatten` is set True, audio data will be flatten to be a 1-dim + ndarray. Default is False. + """ + new_batch = [] + # get target shape + max_length = max([audio.shape[1] for audio, text in batch]) + if padding_to != -1: + if padding_to < max_length: + raise ValueError("If padding_to is not -1, it should be greater" + " or equal to the original instance length.") + max_length = padding_to + # padding + for audio, text in batch: + padded_audio = np.zeros([audio.shape[0], max_length]) + padded_audio[:, :audio.shape[1]] = audio + if flatten: + padded_audio = padded_audio.flatten() + new_batch.append((padded_audio, text)) + return new_batch + + def _batch_shuffle(self, manifest, batch_size): + """ + The instances have different lengths and they cannot be + combined into a single matrix multiplication. It usually + sorts the training examples by length and combines only + similarly-sized instances into minibatches, pads with + silence when necessary so that all instances in a batch + have the same length. This batch shuffle fuction is used + to make similarly-sized instances into minibatches and + make a batch-wise shuffle. + + 1. Sort the audio clips by duration. + 2. Generate a random number `k`, k in [0, batch_size). + 3. Randomly remove `k` instances in order to make different mini-batches, + then make minibatches and each minibatch size is batch_size. + 4. Shuffle the minibatches. + + :param manifest: manifest file. + :type manifest: list + :param batch_size: Batch size. This size is also used for generate + a random number for batch shuffle. + :type batch_size: int + :return: batch shuffled mainifest. + :rtype: list + """ + manifest.sort(key=lambda x: x["duration"]) + shift_len = self._rng.randint(0, batch_size - 1) + batch_manifest = zip(*[iter(manifest[shift_len:])] * batch_size) + self._rng.shuffle(batch_manifest) + batch_manifest = list(sum(batch_manifest, ())) + res_len = len(manifest) - shift_len - len(batch_manifest) + batch_manifest.extend(manifest[-res_len:]) + batch_manifest.extend(manifest[0:shift_len]) + return batch_manifest diff --git a/deep_speech_2/data_utils/featurizer/__init__.py b/deep_speech_2/data_utils/featurizer/__init__.py new file mode 100755 index 00000000..e69de29b diff --git a/deep_speech_2/data_utils/featurizer/audio_featurizer.py b/deep_speech_2/data_utils/featurizer/audio_featurizer.py new file mode 100755 index 00000000..5d9c6883 --- /dev/null +++ b/deep_speech_2/data_utils/featurizer/audio_featurizer.py @@ -0,0 +1,86 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import random +from data_utils import utils +from data_utils.audio import AudioSegment + + +class AudioFeaturizer(object): + def __init__(self, + specgram_type='linear', + stride_ms=10.0, + window_ms=20.0, + max_freq=None, + random_seed=0): + self._specgram_type = specgram_type + self._stride_ms = stride_ms + self._window_ms = window_ms + self._max_freq = max_freq + + def featurize(self, audio_segment): + return self._compute_specgram(audio_segment.samples, + audio_segment.sample_rate) + + def _compute_specgram(self, samples, sample_rate): + if self._specgram_type == 'linear': + return self._compute_linear_specgram( + samples, sample_rate, self._stride_ms, self._window_ms, + self._max_freq) + else: + raise ValueError("Unknown specgram_type %s. " + "Supported values: linear." % self._specgram_type) + + def _compute_linear_specgram(self, + samples, + sample_rate, + stride_ms=10.0, + window_ms=20.0, + max_freq=None, + eps=1e-14): + """Laod audio data and calculate the log of spectrogram by FFT. + Refer to utils.py in https://github.com/baidu-research/ba-dls-deepspeech + """ + if max_freq is None: + max_freq = sample_rate / 2 + if max_freq > sample_rate / 2: + raise ValueError("max_freq must be greater than half of " + "sample rate.") + if stride_ms > window_ms: + raise ValueError("Stride size must not be greater than " + "window size.") + stride_size = int(0.001 * sample_rate * stride_ms) + window_size = int(0.001 * sample_rate * window_ms) + specgram, freqs = self._specgram_real( + samples, + window_size=window_size, + stride_size=stride_size, + sample_rate=sample_rate) + ind = np.where(freqs <= max_freq)[0][-1] + 1 + return np.log(specgram[:ind, :] + eps) + + def _specgram_real(self, samples, window_size, stride_size, sample_rate): + """Compute the spectrogram by FFT for a discrete real signal. + Refer to utils.py in https://github.com/baidu-research/ba-dls-deepspeech + """ + # extract strided windows + truncate_size = (len(samples) - window_size) % stride_size + samples = samples[:len(samples) - truncate_size] + nshape = (window_size, (len(samples) - window_size) // stride_size + 1) + nstrides = (samples.strides[0], samples.strides[0] * stride_size) + windows = np.lib.stride_tricks.as_strided( + samples, shape=nshape, strides=nstrides) + assert np.all( + windows[:, 1] == samples[stride_size:(stride_size + window_size)]) + # window weighting, squared Fast Fourier Transform (fft), scaling + weighting = np.hanning(window_size)[:, None] + fft = np.fft.rfft(windows * weighting, axis=0) + fft = np.absolute(fft)**2 + scale = np.sum(weighting**2) * sample_rate + fft[1:-1, :] *= (2.0 / scale) + fft[(0, -1), :] /= scale + # prepare fft frequency list + freqs = float(sample_rate) / window_size * np.arange(fft.shape[0]) + return fft, freqs diff --git a/deep_speech_2/data_utils/featurizer/speech_featurizer.py b/deep_speech_2/data_utils/featurizer/speech_featurizer.py new file mode 100755 index 00000000..06af7a02 --- /dev/null +++ b/deep_speech_2/data_utils/featurizer/speech_featurizer.py @@ -0,0 +1,32 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from data_utils.featurizer.audio_featurizer import AudioFeaturizer +from data_utils.featurizer.text_featurizer import TextFeaturizer + + +class SpeechFeaturizer(object): + def __init__(self, + vocab_filepath, + specgram_type='linear', + stride_ms=10.0, + window_ms=20.0, + max_freq=None, + random_seed=0): + self._audio_featurizer = AudioFeaturizer( + specgram_type, stride_ms, window_ms, max_freq, random_seed) + self._text_featurizer = TextFeaturizer(vocab_filepath) + + def featurize(self, speech_segment): + audio_feature = self._audio_featurizer.featurize(speech_segment) + text_ids = self._text_featurizer.text2ids(speech_segment.transcript) + return audio_feature, text_ids + + @property + def vocab_size(self): + return self._text_featurizer.vocab_size + + @property + def vocab_list(self): + return self._text_featurizer.vocab_list diff --git a/deep_speech_2/data_utils/featurizer/text_featurizer.py b/deep_speech_2/data_utils/featurizer/text_featurizer.py new file mode 100755 index 00000000..7e4b69d7 --- /dev/null +++ b/deep_speech_2/data_utils/featurizer/text_featurizer.py @@ -0,0 +1,39 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + + +class TextFeaturizer(object): + def __init__(self, vocab_filepath): + self._vocab_dict, self._vocab_list = self._load_vocabulary_from_file( + vocab_filepath) + + def text2ids(self, text): + tokens = self._char_tokenize(text) + return [self._vocab_dict[token] for token in tokens] + + def ids2text(self, ids): + return ''.join([self._vocab_list[id] for id in ids]) + + @property + def vocab_size(self): + return len(self._vocab_list) + + @property + def vocab_list(self): + return self._vocab_list + + def _char_tokenize(self, text): + return list(text.strip()) + + def _load_vocabulary_from_file(self, vocab_filepath): + """Load vocabulary from file.""" + vocab_lines = [] + with open(vocab_filepath, 'r') as file: + vocab_lines.extend(file.readlines()) + vocab_list = [line[:-1] for line in vocab_lines] + vocab_dict = dict( + [(token, id) for (id, token) in enumerate(vocab_list)]) + return vocab_dict, vocab_list diff --git a/deep_speech_2/data_utils/normalizer.py b/deep_speech_2/data_utils/normalizer.py new file mode 100755 index 00000000..364600af --- /dev/null +++ b/deep_speech_2/data_utils/normalizer.py @@ -0,0 +1,49 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import random +import data_utils.utils as utils +from data_utils.audio import AudioSegment + + +class FeatureNormalizer(object): + def __init__(self, + mean_std_filepath, + manifest_path=None, + featurize_func=None, + num_samples=500, + random_seed=0): + if not mean_std_filepath: + if not (manifest_path and featurize_func): + raise ValueError("If mean_std_filepath is None, meanifest_path " + "and featurize_func should not be None.") + self._rng = random.Random(random_seed) + self._compute_mean_std(manifest_path, featurize_func, num_samples) + else: + self._read_mean_std_from_file(mean_std_filepath) + + def apply(self, features, eps=1e-14): + """Normalize features to be of zero mean and unit stddev.""" + return (features - self._mean) / (self._std + eps) + + def write_to_file(self, filepath): + np.savez(filepath, mean=self._mean, std=self._std) + + def _read_mean_std_from_file(self, filepath): + npzfile = np.load(filepath) + self._mean = npzfile["mean"] + self._std = npzfile["std"] + + def _compute_mean_std(self, manifest_path, featurize_func, num_samples): + manifest = utils.read_manifest(manifest_path) + sampled_manifest = self._rng.sample(manifest, num_samples) + features = [] + for instance in sampled_manifest: + features.append( + featurize_func( + AudioSegment.from_file(instance["audio_filepath"]))) + features = np.hstack(features) + self._mean = np.mean(features, axis=1).reshape([-1, 1]) + self._std = np.std(features, axis=1).reshape([-1, 1]) diff --git a/deep_speech_2/data_utils/utils.py b/deep_speech_2/data_utils/utils.py new file mode 100755 index 00000000..2a916b54 --- /dev/null +++ b/deep_speech_2/data_utils/utils.py @@ -0,0 +1,19 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import json + + +def read_manifest(manifest_path, max_duration=float('inf'), min_duration=0.0): + """Load and parse manifest file.""" + manifest = [] + for json_line in open(manifest_path): + try: + json_data = json.loads(json_line) + except Exception as e: + raise IOError("Error reading manifest: %s" % str(e)) + if (json_data["duration"] <= max_duration and + json_data["duration"] >= min_duration): + manifest.append(json_data) + return manifest diff --git a/deep_speech_2/data/librispeech.py b/deep_speech_2/datasets/librispeech/librispeech.py similarity index 99% rename from deep_speech_2/data/librispeech.py rename to deep_speech_2/datasets/librispeech/librispeech.py index 653caa92..1ba2a442 100644 --- a/deep_speech_2/data/librispeech.py +++ b/deep_speech_2/datasets/librispeech/librispeech.py @@ -44,7 +44,7 @@ parser.add_argument( help="Directory to save the dataset. (default: %(default)s)") parser.add_argument( "--manifest_prefix", - default="manifest.libri", + default="manifest", type=str, help="Filepath prefix for output manifests. (default: %(default)s)") parser.add_argument( diff --git a/deep_speech_2/datasets/run_all.sh b/deep_speech_2/datasets/run_all.sh new file mode 100755 index 00000000..ef2b721f --- /dev/null +++ b/deep_speech_2/datasets/run_all.sh @@ -0,0 +1,13 @@ +cd librispeech +python librispeech.py +if [ $? -ne 0 ]; then + echo "Prepare LibriSpeech failed. Terminated." + exit 1 +fi +cd - + +cat librispeech/manifest.train* | shuf > manifest.train +cat librispeech/manifest.dev-clean > manifest.dev +cat librispeech/manifest.test-clean > manifest.test + +echo "All done." diff --git a/deep_speech_2/data/eng_vocab.txt b/deep_speech_2/datasets/vocab/eng_vocab.txt similarity index 100% rename from deep_speech_2/data/eng_vocab.txt rename to deep_speech_2/datasets/vocab/eng_vocab.txt diff --git a/deep_speech_2/infer.py b/deep_speech_2/infer.py index 598c348b..eb31254c 100644 --- a/deep_speech_2/infer.py +++ b/deep_speech_2/infer.py @@ -2,11 +2,15 @@ Inference for a simplifed version of Baidu DeepSpeech2 model. """ -import paddle.v2 as paddle -import distutils.util +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + import argparse import gzip -from audio_data_utils import DataGenerator +import distutils.util +import paddle.v2 as paddle +from data_utils.data import DataGenerator from model import deep_speech2 from decoder import ctc_decode @@ -38,13 +42,13 @@ parser.add_argument( type=distutils.util.strtobool, help="Use gpu or not. (default: %(default)s)") parser.add_argument( - "--normalizer_manifest_path", - default='data/manifest.libri.train-clean-100', + "--mean_std_filepath", + default='mean_std.npz', type=str, help="Manifest path for normalizer. (default: %(default)s)") parser.add_argument( "--decode_manifest_path", - default='data/manifest.libri.test-clean', + default='datasets/manifest.test', type=str, help="Manifest path for decoding. (default: %(default)s)") parser.add_argument( @@ -54,7 +58,7 @@ parser.add_argument( help="Model filepath. (default: %(default)s)") parser.add_argument( "--vocab_filepath", - default='data/eng_vocab.txt', + default='datasets/vocab/eng_vocab.txt', type=str, help="Vocabulary filepath. (default: %(default)s)") args = parser.parse_args() @@ -67,28 +71,22 @@ def infer(): # initialize data generator data_generator = DataGenerator( vocab_filepath=args.vocab_filepath, - normalizer_manifest_path=args.normalizer_manifest_path, - normalizer_num_samples=200, - max_duration=20.0, - min_duration=0.0, - stride_ms=10, - window_ms=20) + mean_std_filepath=args.mean_std_filepath, + augmentation_config='{}') # create network config - dict_size = data_generator.vocabulary_size() - vocab_list = data_generator.vocabulary_list() + # paddle.data_type.dense_array is used for variable batch input. + # The size 161 * 161 is only an placeholder value and the real shape + # of input batch data will be induced during training. audio_data = paddle.layer.data( - name="audio_spectrogram", - height=161, - width=2000, - type=paddle.data_type.dense_vector(322000)) + name="audio_spectrogram", type=paddle.data_type.dense_array(161 * 161)) text_data = paddle.layer.data( name="transcript_text", - type=paddle.data_type.integer_value_sequence(dict_size)) + type=paddle.data_type.integer_value_sequence(data_generator.vocab_size)) output_probs = deep_speech2( audio_data=audio_data, text_data=text_data, - dict_size=dict_size, + dict_size=data_generator.vocab_size, num_conv_layers=args.num_conv_layers, num_rnn_layers=args.num_rnn_layers, rnn_size=args.rnn_layer_size, @@ -99,31 +97,30 @@ def infer(): gzip.open(args.model_filepath)) # prepare infer data - feeding = data_generator.data_name_feeding() - test_batch_reader = data_generator.batch_reader_creator( + batch_reader = data_generator.batch_reader_creator( manifest_path=args.decode_manifest_path, batch_size=args.num_samples, - padding_to=2000, - flatten=True, - sort_by_duration=False, - shuffle=False) - infer_data = test_batch_reader().next() + sortagrad=False, + batch_shuffle=False) + infer_data = batch_reader().next() # run inference infer_results = paddle.infer( output_layer=output_probs, parameters=parameters, input=infer_data) - num_steps = len(infer_results) / len(infer_data) + num_steps = len(infer_results) // len(infer_data) probs_split = [ infer_results[i * num_steps:(i + 1) * num_steps] - for i in xrange(0, len(infer_data)) + for i in xrange(len(infer_data)) ] # decode and print for i, probs in enumerate(probs_split): output_transcription = ctc_decode( - probs_seq=probs, vocabulary=vocab_list, method="best_path") + probs_seq=probs, + vocabulary=data_generator.vocab_list, + method="best_path") target_transcription = ''.join( - [vocab_list[index] for index in infer_data[i][1]]) + [data_generator.vocab_list[index] for index in infer_data[i][1]]) print("Target Transcription: %s \nOutput Transcription: %s \n" % (target_transcription, output_transcription)) diff --git a/deep_speech_2/train.py b/deep_speech_2/train.py index 957c2426..c6aa9752 100644 --- a/deep_speech_2/train.py +++ b/deep_speech_2/train.py @@ -2,21 +2,21 @@ Trainer for a simplifed version of Baidu DeepSpeech2 model. """ -import paddle.v2 as paddle -import distutils.util +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import sys +import os import argparse import gzip import time -import sys +import distutils.util +import paddle.v2 as paddle from model import deep_speech2 -from audio_data_utils import DataGenerator -import numpy as np -import os +from data_utils.data import DataGenerator -#TODO: add WER metric - -parser = argparse.ArgumentParser( - description='Simplified version of DeepSpeech2 trainer.') +parser = argparse.ArgumentParser(description=__doc__) parser.add_argument( "--batch_size", default=32, type=int, help="Minibatch size.") parser.add_argument( @@ -51,7 +51,7 @@ parser.add_argument( help="Use gpu or not. (default: %(default)s)") parser.add_argument( "--use_sortagrad", - default=False, + default=True, type=distutils.util.strtobool, help="Use sortagrad or not. (default: %(default)s)") parser.add_argument( @@ -60,23 +60,23 @@ parser.add_argument( type=int, help="Trainer number. (default: %(default)s)") parser.add_argument( - "--normalizer_manifest_path", - default='data/manifest.libri.train-clean-100', + "--mean_std_filepath", + default='mean_std.npz', type=str, help="Manifest path for normalizer. (default: %(default)s)") parser.add_argument( "--train_manifest_path", - default='data/manifest.libri.train-clean-100', + default='datasets/manifest.train', type=str, help="Manifest path for training. (default: %(default)s)") parser.add_argument( "--dev_manifest_path", - default='data/manifest.libri.dev-clean', + default='datasets/manifest.dev', type=str, help="Manifest path for validation. (default: %(default)s)") parser.add_argument( "--vocab_filepath", - default='data/eng_vocab.txt', + default='datasets/vocab/eng_vocab.txt', type=str, help="Vocabulary filepath. (default: %(default)s)") parser.add_argument( @@ -86,6 +86,12 @@ parser.add_argument( help="If set None, the training will start from scratch. " "Otherwise, the training will resume from " "the existing model of this path. (default: %(default)s)") +parser.add_argument( + "--augmentation_config", + default='{}', + type=str, + help="Augmentation configuration in json-format. " + "(default: %(default)s)") args = parser.parse_args() @@ -98,29 +104,26 @@ def train(): def data_generator(): return DataGenerator( vocab_filepath=args.vocab_filepath, - normalizer_manifest_path=args.normalizer_manifest_path, - normalizer_num_samples=200, - max_duration=20.0, - min_duration=0.0, - stride_ms=10, - window_ms=20) + mean_std_filepath=args.mean_std_filepath, + augmentation_config=args.augmentation_config) train_generator = data_generator() test_generator = data_generator() + # create network config - dict_size = train_generator.vocabulary_size() # paddle.data_type.dense_array is used for variable batch input. - # the size 161 * 161 is only an placeholder value and the real shape - # of input batch data will be set at each batch. + # The size 161 * 161 is only an placeholder value and the real shape + # of input batch data will be induced during training. audio_data = paddle.layer.data( name="audio_spectrogram", type=paddle.data_type.dense_array(161 * 161)) text_data = paddle.layer.data( name="transcript_text", - type=paddle.data_type.integer_value_sequence(dict_size)) + type=paddle.data_type.integer_value_sequence( + train_generator.vocab_size)) cost = deep_speech2( audio_data=audio_data, text_data=text_data, - dict_size=dict_size, + dict_size=train_generator.vocab_size, num_conv_layers=args.num_conv_layers, num_rnn_layers=args.num_rnn_layers, rnn_size=args.rnn_layer_size, @@ -143,13 +146,13 @@ def train(): train_batch_reader = train_generator.batch_reader_creator( manifest_path=args.train_manifest_path, batch_size=args.batch_size, - sortagrad=True if args.init_model_path is None else False, + sortagrad=args.use_sortagrad if args.init_model_path is None else False, batch_shuffle=True) test_batch_reader = test_generator.batch_reader_creator( manifest_path=args.dev_manifest_path, batch_size=args.batch_size, + sortagrad=False, batch_shuffle=False) - feeding = train_generator.data_name_feeding() # create event handler def event_handler(event): @@ -158,8 +161,8 @@ def train(): cost_sum += event.cost cost_counter += 1 if event.batch_id % 50 == 0: - print "\nPass: %d, Batch: %d, TrainCost: %f" % ( - event.pass_id, event.batch_id, cost_sum / cost_counter) + print("\nPass: %d, Batch: %d, TrainCost: %f" % + (event.pass_id, event.batch_id, cost_sum / cost_counter)) cost_sum, cost_counter = 0.0, 0 with gzip.open("params.tar.gz", 'w') as f: parameters.to_tar(f) @@ -170,16 +173,17 @@ def train(): start_time = time.time() cost_sum, cost_counter = 0.0, 0 if isinstance(event, paddle.event.EndPass): - result = trainer.test(reader=test_batch_reader, feeding=feeding) - print "\n------- Time: %d sec, Pass: %d, ValidationCost: %s" % ( - time.time() - start_time, event.pass_id, result.cost) + result = trainer.test( + reader=test_batch_reader, feeding=test_generator.feeding) + print("\n------- Time: %d sec, Pass: %d, ValidationCost: %s" % + (time.time() - start_time, event.pass_id, result.cost)) # run train trainer.train( reader=train_batch_reader, event_handler=event_handler, num_passes=args.num_passes, - feeding=feeding) + feeding=train_generator.feeding) def main(): -- GitLab