diff --git a/.travis/unittest.sh b/.travis/unittest.sh index ad223eb4a9c1f57896762ad38d0b3fa5de5c496b..4195a441eac5f091e49b6203dbd2c637fee6ab69 100755 --- a/.travis/unittest.sh +++ b/.travis/unittest.sh @@ -10,6 +10,7 @@ unittest(){ cd $1 > /dev/null if [ -f "setup.sh" ]; then sh setup.sh + export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH fi if [ $? != 0 ]; then exit 1 diff --git a/deep_speech_2/README.md b/deep_speech_2/README.md index 3010c0e536da732f1c4f042c82badaae21179f87..96fbb7d09aa310003d83a036d301deac54f3004d 100644 --- a/deep_speech_2/README.md +++ b/deep_speech_2/README.md @@ -2,14 +2,19 @@ ## Installation -Please replace `$PADDLE_INSTALL_DIR` with your own paddle installation directory. +### Prerequisites + + - **Python = 2.7** only supported; + - **cuDNN >= 6.0** is required to utilize NVIDIA GPU platform in the installation of PaddlePaddle, and the **CUDA toolkit** with proper version suitable for cuDNN. The cuDNN library below 6.0 is found to yield a fatal error in batch normalization when handling utterances with long duration in inference. + +### Setup ``` sh setup.sh export LD_LIBRARY_PATH=$PADDLE_INSTALL_DIR/Paddle/third_party/install/warpctc/lib:$LD_LIBRARY_PATH ``` -For some machines, we also need to install libsndfile1. Details to be added. +Please replace `$PADDLE_INSTALL_DIR` with your own paddle installation directory. ## Usage @@ -138,3 +143,28 @@ python tune.py --help ``` Then reset parameters with the tuning result before inference or evaluating. + +### Playing with the ASR Demo + +A real-time ASR demo is built for users to try out the ASR model with their own voice. Please do the following installation on the machine you'd like to run the demo's client (no need for the machine running the demo's server). + +For example, on MAC OS X: + +``` +brew install portaudio +pip install pyaudio +pip install pynput +``` +After a model and language model is prepared, we can first start the demo's server: + +``` +CUDA_VISIBLE_DEVICES=0 python demo_server.py +``` +And then in another console, start the demo's client: + +``` +python demo_client.py +``` +On the client console, press and hold the "white-space" key on the keyboard to start talking, until you finish your speech and then release the "white-space" key. The decoding results (infered transcription) will be displayed. + +It could be possible to start the server and the client in two seperate machines, e.g. `demo_client.py` is usually started in a machine with a microphone hardware, while `demo_server.py` is usually started in a remote server with powerful GPUs. Please first make sure that these two machines have network access to each other, and then use `--host_ip` and `--host_port` to indicate the server machine's actual IP address (instead of the `localhost` as default) and TCP port, in both `demo_server.py` and `demo_client.py`. diff --git a/deep_speech_2/conf/augmentation.config b/deep_speech_2/conf/augmentation.config new file mode 100644 index 0000000000000000000000000000000000000000..6c24da5497460d4bae9c9c4fecdbe96ab8da7532 --- /dev/null +++ b/deep_speech_2/conf/augmentation.config @@ -0,0 +1,8 @@ +[ + { + "type": "shift", + "params": {"min_shift_ms": -5, + "max_shift_ms": 5}, + "prob": 1.0 + } +] diff --git a/deep_speech_2/conf/augmentation.config.example b/deep_speech_2/conf/augmentation.config.example new file mode 100644 index 0000000000000000000000000000000000000000..21ed6ee10375a749f4c072389509db2020d9e9c9 --- /dev/null +++ b/deep_speech_2/conf/augmentation.config.example @@ -0,0 +1,39 @@ +[ + { + "type": "noise", + "params": {"min_snr_dB": 40, + "max_snr_dB": 50, + "noise_manifest_path": "datasets/manifest.noise"}, + "prob": 0.6 + }, + { + "type": "impulse", + "params": {"impulse_manifest_path": "datasets/manifest.impulse"}, + "prob": 0.5 + }, + { + "type": "speed", + "params": {"min_speed_rate": 0.95, + "max_speed_rate": 1.05}, + "prob": 0.5 + }, + { + "type": "shift", + "params": {"min_shift_ms": -5, + "max_shift_ms": 5}, + "prob": 1.0 + }, + { + "type": "volume", + "params": {"min_gain_dBFS": -10, + "max_gain_dBFS": 10}, + "prob": 0.0 + }, + { + "type": "bayesian_normal", + "params": {"target_db": -20, + "prior_db": -20, + "prior_samples": 100}, + "prob": 0.0 + } +] diff --git a/deep_speech_2/data_utils/audio.py b/deep_speech_2/data_utils/audio.py index 3891f5b923f6d73c6b87dcb90bede0183b0e081c..30e25221cd84aa6849061635749188e3bd13d67b 100644 --- a/deep_speech_2/data_utils/audio.py +++ b/deep_speech_2/data_utils/audio.py @@ -204,7 +204,7 @@ class AudioSegment(object): :raise ValueError: If the sample rates of the two segments are not equal, or if the lengths of segments don't match. """ - if type(self) != type(other): + if isinstance(other, type(self)): raise TypeError("Cannot add segments of different types: %s " "and %s." % (type(self), type(other))) if self._sample_rate != other._sample_rate: @@ -231,7 +231,7 @@ class AudioSegment(object): Note that this is an in-place transformation. :param gain: Gain in decibels to apply to samples. - :type gain: float + :type gain: float|1darray """ self._samples *= 10.**(gain / 20.) @@ -457,9 +457,9 @@ class AudioSegment(object): audio segments when resample is not allowed. """ if allow_resample and self.sample_rate != impulse_segment.sample_rate: - impulse_segment = impulse_segment.resample(self.sample_rate) + impulse_segment.resample(self.sample_rate) if self.sample_rate != impulse_segment.sample_rate: - raise ValueError("Impulse segment's sample rate (%d Hz) is not" + raise ValueError("Impulse segment's sample rate (%d Hz) is not " "equal to base signal sample rate (%d Hz)." % (impulse_segment.sample_rate, self.sample_rate)) samples = signal.fftconvolve(self.samples, impulse_segment.samples, diff --git a/deep_speech_2/data_utils/augmentor/augmentation.py b/deep_speech_2/data_utils/augmentor/augmentation.py index 9dced47314a81f52dc0eafd6e592e240953f291d..5c30b627ef9a23ff41d1f64f270934f149a793a2 100644 --- a/deep_speech_2/data_utils/augmentor/augmentation.py +++ b/deep_speech_2/data_utils/augmentor/augmentation.py @@ -8,6 +8,8 @@ import random from data_utils.augmentor.volume_perturb import VolumePerturbAugmentor from data_utils.augmentor.shift_perturb import ShiftPerturbAugmentor from data_utils.augmentor.speed_perturb import SpeedPerturbAugmentor +from data_utils.augmentor.noise_perturb import NoisePerturbAugmentor +from data_utils.augmentor.impulse_response import ImpulseResponseAugmentor from data_utils.augmentor.resample import ResampleAugmentor from data_utils.augmentor.online_bayesian_normalization import \ OnlineBayesianNormalizationAugmentor @@ -23,21 +25,46 @@ class AugmentationPipeline(object): string, e.g. .. code-block:: - - '[{"type": "volume", - "params": {"min_gain_dBFS": -15, - "max_gain_dBFS": 15}, - "prob": 0.5}, - {"type": "speed", - "params": {"min_speed_rate": 0.8, - "max_speed_rate": 1.2}, - "prob": 0.5} - ]' + [ { + "type": "noise", + "params": {"min_snr_dB": 10, + "max_snr_dB": 20, + "noise_manifest_path": "datasets/manifest.noise"}, + "prob": 0.0 + }, + { + "type": "speed", + "params": {"min_speed_rate": 0.9, + "max_speed_rate": 1.1}, + "prob": 1.0 + }, + { + "type": "shift", + "params": {"min_shift_ms": -5, + "max_shift_ms": 5}, + "prob": 1.0 + }, + { + "type": "volume", + "params": {"min_gain_dBFS": -10, + "max_gain_dBFS": 10}, + "prob": 0.0 + }, + { + "type": "bayesian_normal", + "params": {"target_db": -20, + "prior_db": -20, + "prior_samples": 100}, + "prob": 0.0 + } + ] + This augmentation configuration inserts two augmentation models into the pipeline, with one is VolumePerturbAugmentor and the other SpeedPerturbAugmentor. "prob" indicates the probability of the current - augmentor to take effect. + augmentor to take effect. If "prob" is zero, the augmentor does not take + effect. :param augmentation_config: Augmentation configuration in json string. :type augmentation_config: str @@ -60,7 +87,7 @@ class AugmentationPipeline(object): :type audio_segment: AudioSegmenet|SpeechSegment """ for augmentor, rate in zip(self._augmentors, self._rates): - if self._rng.uniform(0., 1.) <= rate: + if self._rng.uniform(0., 1.) < rate: augmentor.transform_audio(audio_segment) def _parse_pipeline_from(self, config_json): @@ -89,5 +116,9 @@ class AugmentationPipeline(object): return ResampleAugmentor(self._rng, **params) elif augmentor_type == "bayesian_normal": return OnlineBayesianNormalizationAugmentor(self._rng, **params) + elif augmentor_type == "noise": + return NoisePerturbAugmentor(self._rng, **params) + elif augmentor_type == "impulse": + return ImpulseResponseAugmentor(self._rng, **params) else: raise ValueError("Unknown augmentor type [%s]." % augmentor_type) diff --git a/deep_speech_2/data_utils/augmentor/impulse_response.py b/deep_speech_2/data_utils/augmentor/impulse_response.py new file mode 100644 index 0000000000000000000000000000000000000000..c3de0fdbb2a40150f8cffdef3487ceb4400e52ed --- /dev/null +++ b/deep_speech_2/data_utils/augmentor/impulse_response.py @@ -0,0 +1,35 @@ +"""Contains the impulse response augmentation model.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from data_utils.augmentor.base import AugmentorBase +from data_utils import utils +from data_utils.audio import AudioSegment + + +class ImpulseResponseAugmentor(AugmentorBase): + """Augmentation model for adding impulse response effect. + + :param rng: Random generator object. + :type rng: random.Random + :param impulse_manifest_path: Manifest path for impulse audio data. + :type impulse_manifest_path: basestring + """ + + def __init__(self, rng, impulse_manifest_path): + self._rng = rng + self._impulse_manifest = utils.read_manifest( + manifest_path=impulse_manifest_path) + + def transform_audio(self, audio_segment): + """Add impulse response effect. + + Note that this is an in-place transformation. + + :param audio_segment: Audio segment to add effects to. + :type audio_segment: AudioSegmenet|SpeechSegment + """ + impulse_json = self._rng.sample(self._impulse_manifest, 1)[0] + impulse_segment = AudioSegment.from_file(impulse_json['audio_filepath']) + audio_segment.convolve(impulse_segment, allow_resample=True) diff --git a/deep_speech_2/data_utils/augmentor/noise_perturb.py b/deep_speech_2/data_utils/augmentor/noise_perturb.py new file mode 100644 index 0000000000000000000000000000000000000000..281174af42c2f6d673ead94bd532941769c79c25 --- /dev/null +++ b/deep_speech_2/data_utils/augmentor/noise_perturb.py @@ -0,0 +1,50 @@ +"""Contains the noise perturb augmentation model.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from data_utils.augmentor.base import AugmentorBase +from data_utils import utils +from data_utils.audio import AudioSegment + + +class NoisePerturbAugmentor(AugmentorBase): + """Augmentation model for adding background noise. + + :param rng: Random generator object. + :type rng: random.Random + :param min_snr_dB: Minimal signal noise ratio, in decibels. + :type min_snr_dB: float + :param max_snr_dB: Maximal signal noise ratio, in decibels. + :type max_snr_dB: float + :param noise_manifest_path: Manifest path for noise audio data. + :type noise_manifest_path: basestring + """ + + def __init__(self, rng, min_snr_dB, max_snr_dB, noise_manifest_path): + self._min_snr_dB = min_snr_dB + self._max_snr_dB = max_snr_dB + self._rng = rng + self._noise_manifest = utils.read_manifest( + manifest_path=noise_manifest_path) + + def transform_audio(self, audio_segment): + """Add background noise audio. + + Note that this is an in-place transformation. + + :param audio_segment: Audio segment to add effects to. + :type audio_segment: AudioSegmenet|SpeechSegment + """ + noise_json = self._rng.sample(self._noise_manifest, 1)[0] + if noise_json['duration'] < audio_segment.duration: + raise RuntimeError("The duration of sampled noise audio is smaller " + "than the audio segment to add effects to.") + diff_duration = noise_json['duration'] - audio_segment.duration + start = self._rng.uniform(0, diff_duration) + end = start + audio_segment.duration + noise_segment = AudioSegment.slice_from_file( + noise_json['audio_filepath'], start=start, end=end) + snr_dB = self._rng.uniform(self._min_snr_dB, self._max_snr_dB) + audio_segment.add_noise( + noise_segment, snr_dB, allow_downsampling=True, rng=self._rng) diff --git a/deep_speech_2/data_utils/augmentor/online_bayesian_normalization.py b/deep_speech_2/data_utils/augmentor/online_bayesian_normalization.py old mode 100755 new mode 100644 diff --git a/deep_speech_2/data_utils/augmentor/resample.py b/deep_speech_2/data_utils/augmentor/resample.py old mode 100755 new mode 100644 diff --git a/deep_speech_2/data_utils/data.py b/deep_speech_2/data_utils/data.py index 5a5fa51b2b91b3fa7bb1a5476d462534978fa973..e7369c057d15152136ab9bc1f41bc4fe0aa0c546 100644 --- a/deep_speech_2/data_utils/data.py +++ b/deep_speech_2/data_utils/data.py @@ -72,7 +72,7 @@ class DataGenerator(object): max_freq=None, specgram_type='linear', use_dB_normalization=True, - num_threads=multiprocessing.cpu_count(), + num_threads=multiprocessing.cpu_count() // 2, random_seed=0): self._max_duration = max_duration self._min_duration = min_duration @@ -89,11 +89,27 @@ class DataGenerator(object): self._num_threads = num_threads self._rng = random.Random(random_seed) self._epoch = 0 - # for caching tar files info self.tar2info = {} self.tar2object = {} + def process_utterance(self, filename, transcript): + """Load, augment, featurize and normalize for speech data. + + :param filename: Audio filepath + :type filename: basestring + :param transcript: Transcription text. + :type transcript: basestring + :return: Tuple of audio feature tensor and list of token ids for + transcription. + :rtype: tuple of (2darray, list) + """ + speech_segment = SpeechSegment.from_file(filename, transcript) + self._augmentation_pipeline.transform_audio(speech_segment) + specgram, text_ids = self._speech_featurizer.featurize(speech_segment) + specgram = self._normalizer.apply(specgram) + return specgram, text_ids + def batch_reader_creator(self, manifest_path, batch_size, @@ -163,7 +179,7 @@ class DataGenerator(object): manifest, batch_size, clipped=True) elif shuffle_method == "instance_shuffle": self._rng.shuffle(manifest) - elif not shuffle_method: + elif shuffle_method == None: pass else: raise ValueError("Unknown shuffle method %s." % @@ -263,8 +279,8 @@ class DataGenerator(object): yield instance def mapper(instance): - return self._process_utterance(instance["audio_filepath"], - instance["text"]) + return self.process_utterance(instance["audio_filepath"], + instance["text"]) return paddle.reader.xmap_readers( mapper, reader, self._num_threads, 1024, order=True) diff --git a/deep_speech_2/data_utils/featurizer/audio_featurizer.py b/deep_speech_2/data_utils/featurizer/audio_featurizer.py index 271e535b6a9f1cded27caf4f63adcc51abf3e835..00f0e8a35bc8e67ab285b7d509a0992c02dc54ca 100644 --- a/deep_speech_2/data_utils/featurizer/audio_featurizer.py +++ b/deep_speech_2/data_utils/featurizer/audio_featurizer.py @@ -166,21 +166,18 @@ class AudioFeaturizer(object): "window size.") # compute 13 cepstral coefficients, and the first one is replaced # by log(frame energy) - mfcc_feat = mfcc( - signal=samples, - samplerate=sample_rate, - winlen=0.001 * window_ms, - winstep=0.001 * stride_ms, - highfreq=max_freq) + mfcc_feat = np.transpose( + mfcc( + signal=samples, + samplerate=sample_rate, + winlen=0.001 * window_ms, + winstep=0.001 * stride_ms, + highfreq=max_freq)) # Deltas d_mfcc_feat = delta(mfcc_feat, 2) # Deltas-Deltas dd_mfcc_feat = delta(d_mfcc_feat, 2) # concat above three features - concat_mfcc_feat = [ - np.concatenate((mfcc_feat[i], d_mfcc_feat[i], dd_mfcc_feat[i])) - for i in xrange(len(mfcc_feat)) - ] - # transpose to be consistent with the linear specgram situation - concat_mfcc_feat = np.transpose(concat_mfcc_feat) + concat_mfcc_feat = np.concatenate( + (mfcc_feat, d_mfcc_feat, dd_mfcc_feat)) return concat_mfcc_feat diff --git a/deep_speech_2/data_utils/speech.py b/deep_speech_2/data_utils/speech.py index 568e4443ba557149505dfb4de6f230b4962e332a..17d68f315d04b6cc1aae2346df78cf77982cd7bc 100644 --- a/deep_speech_2/data_utils/speech.py +++ b/deep_speech_2/data_utils/speech.py @@ -115,7 +115,7 @@ class SpeechSegment(AudioSegment): speech file. :rtype: SpeechSegment """ - audio = Audiosegment.slice_from_file(filepath, start, end) + audio = AudioSegment.slice_from_file(filepath, start, end) return cls(audio.samples, audio.sample_rate, transcript) @classmethod diff --git a/deep_speech_2/datasets/librispeech/librispeech.py b/deep_speech_2/datasets/librispeech/librispeech.py index 87e52ae4aa286503d79f1326065831acfe6bf985..7e941f0ea7f260680f60dc706fd9873532e3c8bb 100644 --- a/deep_speech_2/datasets/librispeech/librispeech.py +++ b/deep_speech_2/datasets/librispeech/librispeech.py @@ -11,7 +11,7 @@ from __future__ import print_function import distutils.util import os -import wget +import sys import tarfile import argparse import soundfile @@ -66,7 +66,7 @@ def download(url, md5sum, target_dir): filepath = os.path.join(target_dir, url.split("/")[-1]) if not (os.path.exists(filepath) and md5file(filepath) == md5sum): print("Downloading %s ..." % url) - wget.download(url, target_dir) + os.system("wget -c " + url + " -P " + target_dir) print("\nMD5 Chesksum %s ..." % filepath) if not md5file(filepath) == md5sum: raise RuntimeError("MD5 checksum failed.") diff --git a/deep_speech_2/datasets/noise/chime3_background.py b/deep_speech_2/datasets/noise/chime3_background.py new file mode 100644 index 0000000000000000000000000000000000000000..f79ca7335bda7aec795bc43c32a51519f3363d85 --- /dev/null +++ b/deep_speech_2/datasets/noise/chime3_background.py @@ -0,0 +1,128 @@ +"""Prepare CHiME3 background data. + +Download, unpack and create manifest files. +Manifest file is a json-format file with each line containing the +meta data (i.e. audio filepath, transcript and audio duration) +of each audio file in the data set. +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import distutils.util +import os +import wget +import zipfile +import argparse +import soundfile +import json +from paddle.v2.dataset.common import md5file + +DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset/speech') + +URL = "https://d4s.myairbridge.com/packagev2/AG0Y3DNBE5IWRRTV/?dlid=W19XG7T0NNHB027139H0EQ" +MD5 = "c3ff512618d7a67d4f85566ea1bc39ec" + +parser = argparse.ArgumentParser(description=__doc__) +parser.add_argument( + "--target_dir", + default=DATA_HOME + "/chime3_background", + type=str, + help="Directory to save the dataset. (default: %(default)s)") +parser.add_argument( + "--manifest_filepath", + default="manifest.chime3.background", + type=str, + help="Filepath for output manifests. (default: %(default)s)") +args = parser.parse_args() + + +def download(url, md5sum, target_dir, filename=None): + """Download file from url to target_dir, and check md5sum.""" + if filename == None: + filename = url.split("/")[-1] + if not os.path.exists(target_dir): os.makedirs(target_dir) + filepath = os.path.join(target_dir, filename) + if not (os.path.exists(filepath) and md5file(filepath) == md5sum): + print("Downloading %s ..." % url) + wget.download(url, target_dir) + print("\nMD5 Chesksum %s ..." % filepath) + if not md5file(filepath) == md5sum: + raise RuntimeError("MD5 checksum failed.") + else: + print("File exists, skip downloading. (%s)" % filepath) + return filepath + + +def unpack(filepath, target_dir): + """Unpack the file to the target_dir.""" + print("Unpacking %s ..." % filepath) + if filepath.endswith('.zip'): + zip = zipfile.ZipFile(filepath, 'r') + zip.extractall(target_dir) + zip.close() + elif filepath.endswith('.tar') or filepath.endswith('.tar.gz'): + tar = zipfile.open(filepath) + tar.extractall(target_dir) + tar.close() + else: + raise ValueError("File format is not supported for unpacking.") + + +def create_manifest(data_dir, manifest_path): + """Create a manifest json file summarizing the data set, with each line + containing the meta data (i.e. audio filepath, transcription text, audio + duration) of each audio file within the data set. + """ + print("Creating manifest %s ..." % manifest_path) + json_lines = [] + for subfolder, _, filelist in sorted(os.walk(data_dir)): + for filename in filelist: + if filename.endswith('.wav'): + filepath = os.path.join(data_dir, subfolder, filename) + audio_data, samplerate = soundfile.read(filepath) + duration = float(len(audio_data)) / samplerate + json_lines.append( + json.dumps({ + 'audio_filepath': filepath, + 'duration': duration, + 'text': '' + })) + with open(manifest_path, 'w') as out_file: + for line in json_lines: + out_file.write(line + '\n') + + +def prepare_chime3(url, md5sum, target_dir, manifest_path): + """Download, unpack and create summmary manifest file.""" + if not os.path.exists(os.path.join(target_dir, "CHiME3")): + # download + filepath = download(url, md5sum, target_dir, + "myairbridge-AG0Y3DNBE5IWRRTV.zip") + # unpack + unpack(filepath, target_dir) + unpack( + os.path.join(target_dir, 'CHiME3_background_bus.zip'), target_dir) + unpack( + os.path.join(target_dir, 'CHiME3_background_caf.zip'), target_dir) + unpack( + os.path.join(target_dir, 'CHiME3_background_ped.zip'), target_dir) + unpack( + os.path.join(target_dir, 'CHiME3_background_str.zip'), target_dir) + else: + print("Skip downloading and unpacking. Data already exists in %s." % + target_dir) + # create manifest json file + create_manifest(target_dir, manifest_path) + + +def main(): + prepare_chime3( + url=URL, + md5sum=MD5, + target_dir=args.target_dir, + manifest_path=args.manifest_filepath) + + +if __name__ == '__main__': + main() diff --git a/deep_speech_2/datasets/run_noise.sh b/deep_speech_2/datasets/run_noise.sh new file mode 100644 index 0000000000000000000000000000000000000000..7b27abde47a97b671609f0cd15e81565b3a00d02 --- /dev/null +++ b/deep_speech_2/datasets/run_noise.sh @@ -0,0 +1,10 @@ +cd noise +python chime3_background.py +if [ $? -ne 0 ]; then + echo "Prepare CHiME3 background noise failed. Terminated." + exit 1 +fi +cd - + +cat noise/manifest.* > manifest.noise +echo "All done." diff --git a/deep_speech_2/decoder.py b/deep_speech_2/decoder.py index a1fadc2c81ac5036f5082e1a60b018106ab90277..8f2e0508de79fea30ebc30230e948b15923bdf24 100644 --- a/deep_speech_2/decoder.py +++ b/deep_speech_2/decoder.py @@ -205,9 +205,9 @@ def ctc_beam_search_decoder_batch(probs_split, :type num_processes: int :param cutoff_prob: Cutoff probability in pruning, default 1.0, no pruning. + :type cutoff_prob: float :param num_processes: Number of parallel processes. :type num_processes: int - :type cutoff_prob: float :param ext_scoring_func: External scoring function for partially decoded sentence, e.g. word count or language model. diff --git a/deep_speech_2/demo_client.py b/deep_speech_2/demo_client.py new file mode 100644 index 0000000000000000000000000000000000000000..ddf4dd1bf3f5ea62661e181e0dd2fb3f3b1379c6 --- /dev/null +++ b/deep_speech_2/demo_client.py @@ -0,0 +1,94 @@ +"""Client-end for the ASR demo.""" +from pynput import keyboard +import struct +import socket +import sys +import argparse +import pyaudio + +parser = argparse.ArgumentParser(description=__doc__) +parser.add_argument( + "--host_ip", + default="localhost", + type=str, + help="Server IP address. (default: %(default)s)") +parser.add_argument( + "--host_port", + default=8086, + type=int, + help="Server Port. (default: %(default)s)") +args = parser.parse_args() + +is_recording = False +enable_trigger_record = True + + +def on_press(key): + """On-press keyboard callback function.""" + global is_recording, enable_trigger_record + if key == keyboard.Key.space: + if (not is_recording) and enable_trigger_record: + sys.stdout.write("Start Recording ... ") + sys.stdout.flush() + is_recording = True + + +def on_release(key): + """On-release keyboard callback function.""" + global is_recording, enable_trigger_record + if key == keyboard.Key.esc: + return False + elif key == keyboard.Key.space: + if is_recording == True: + is_recording = False + + +data_list = [] + + +def callback(in_data, frame_count, time_info, status): + """Audio recorder's stream callback function.""" + global data_list, is_recording, enable_trigger_record + if is_recording: + data_list.append(in_data) + enable_trigger_record = False + elif len(data_list) > 0: + # Connect to server and send data + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.connect((args.host_ip, args.host_port)) + sent = ''.join(data_list) + sock.sendall(struct.pack('>i', len(sent)) + sent) + print('Speech[length=%d] Sent.' % len(sent)) + # Receive data from the server and shut down + received = sock.recv(1024) + print "Recognition Results: {}".format(received) + sock.close() + data_list = [] + enable_trigger_record = True + return (in_data, pyaudio.paContinue) + + +def main(): + # prepare audio recorder + p = pyaudio.PyAudio() + stream = p.open( + format=pyaudio.paInt32, + channels=1, + rate=16000, + input=True, + stream_callback=callback) + stream.start_stream() + + # prepare keyboard listener + with keyboard.Listener( + on_press=on_press, on_release=on_release) as listener: + listener.join() + + # close up + stream.stop_stream() + stream.close() + p.terminate() + + +if __name__ == "__main__": + main() diff --git a/deep_speech_2/demo_server.py b/deep_speech_2/demo_server.py new file mode 100644 index 0000000000000000000000000000000000000000..c7e7e94a450121ea3c5c12fbbf7df4dfa3a48262 --- /dev/null +++ b/deep_speech_2/demo_server.py @@ -0,0 +1,245 @@ +"""Server-end for the ASR demo.""" +import os +import time +import random +import argparse +import distutils.util +from time import gmtime, strftime +import SocketServer +import struct +import wave +import paddle.v2 as paddle +from utils import print_arguments +from data_utils.data import DataGenerator +from model import DeepSpeech2Model +from data_utils.utils import read_manifest + +parser = argparse.ArgumentParser(description=__doc__) +parser.add_argument( + "--host_ip", + default="localhost", + type=str, + help="Server IP address. (default: %(default)s)") +parser.add_argument( + "--host_port", + default=8086, + type=int, + help="Server Port. (default: %(default)s)") +parser.add_argument( + "--speech_save_dir", + default="demo_cache", + type=str, + help="Directory for saving demo speech. (default: %(default)s)") +parser.add_argument( + "--vocab_filepath", + default='datasets/vocab/eng_vocab.txt', + type=str, + help="Vocabulary filepath. (default: %(default)s)") +parser.add_argument( + "--mean_std_filepath", + default='mean_std.npz', + type=str, + help="Manifest path for normalizer. (default: %(default)s)") +parser.add_argument( + "--warmup_manifest_path", + default='datasets/manifest.test', + type=str, + help="Manifest path for warmup test. (default: %(default)s)") +parser.add_argument( + "--specgram_type", + default='linear', + type=str, + help="Feature type of audio data: 'linear' (power spectrum)" + " or 'mfcc'. (default: %(default)s)") +parser.add_argument( + "--num_conv_layers", + default=2, + type=int, + help="Convolution layer number. (default: %(default)s)") +parser.add_argument( + "--num_rnn_layers", + default=3, + type=int, + help="RNN layer number. (default: %(default)s)") +parser.add_argument( + "--rnn_layer_size", + default=512, + type=int, + help="RNN layer cell number. (default: %(default)s)") +parser.add_argument( + "--use_gpu", + default=True, + type=distutils.util.strtobool, + help="Use gpu or not. (default: %(default)s)") +parser.add_argument( + "--model_filepath", + default='checkpoints/params.latest.tar.gz', + type=str, + help="Model filepath. (default: %(default)s)") +parser.add_argument( + "--decode_method", + default='beam_search', + type=str, + help="Method for ctc decoding: best_path or beam_search. " + "(default: %(default)s)") +parser.add_argument( + "--beam_size", + default=100, + type=int, + help="Width for beam search decoding. (default: %(default)d)") +parser.add_argument( + "--language_model_path", + default="lm/data/common_crawl_00.prune01111.trie.klm", + type=str, + help="Path for language model. (default: %(default)s)") +parser.add_argument( + "--alpha", + default=0.36, + type=float, + help="Parameter associated with language model. (default: %(default)f)") +parser.add_argument( + "--beta", + default=0.25, + type=float, + help="Parameter associated with word count. (default: %(default)f)") +parser.add_argument( + "--cutoff_prob", + default=0.99, + type=float, + help="The cutoff probability of pruning" + "in beam search. (default: %(default)f)") +args = parser.parse_args() + + +class AsrTCPServer(SocketServer.TCPServer): + """The ASR TCP Server.""" + + def __init__(self, + server_address, + RequestHandlerClass, + speech_save_dir, + audio_process_handler, + bind_and_activate=True): + self.speech_save_dir = speech_save_dir + self.audio_process_handler = audio_process_handler + SocketServer.TCPServer.__init__( + self, server_address, RequestHandlerClass, bind_and_activate=True) + + +class AsrRequestHandler(SocketServer.BaseRequestHandler): + """The ASR request handler.""" + + def handle(self): + # receive data through TCP socket + chunk = self.request.recv(1024) + target_len = struct.unpack('>i', chunk[:4])[0] + data = chunk[4:] + while len(data) < target_len: + chunk = self.request.recv(1024) + data += chunk + # write to file + filename = self._write_to_file(data) + + print("Received utterance[length=%d] from %s, saved to %s." % + (len(data), self.client_address[0], filename)) + start_time = time.time() + transcript = self.server.audio_process_handler(filename) + finish_time = time.time() + print("Response Time: %f, Transcript: %s" % + (finish_time - start_time, transcript)) + self.request.sendall(transcript) + + def _write_to_file(self, data): + # prepare save dir and filename + if not os.path.exists(self.server.speech_save_dir): + os.mkdir(self.server.speech_save_dir) + timestamp = strftime("%Y%m%d%H%M%S", gmtime()) + out_filename = os.path.join( + self.server.speech_save_dir, + timestamp + "_" + self.client_address[0] + ".wav") + # write to wav file + file = wave.open(out_filename, 'wb') + file.setnchannels(1) + file.setsampwidth(4) + file.setframerate(16000) + file.writeframes(data) + file.close() + return out_filename + + +def warm_up_test(audio_process_handler, + manifest_path, + num_test_cases, + random_seed=0): + """Warming-up test.""" + manifest = read_manifest(manifest_path) + rng = random.Random(random_seed) + samples = rng.sample(manifest, num_test_cases) + for idx, sample in enumerate(samples): + print("Warm-up Test Case %d: %s", idx, sample['audio_filepath']) + start_time = time.time() + transcript = audio_process_handler(sample['audio_filepath']) + finish_time = time.time() + print("Response Time: %f, Transcript: %s" % + (finish_time - start_time, transcript)) + + +def start_server(): + """Start the ASR server""" + # prepare data generator + data_generator = DataGenerator( + vocab_filepath=args.vocab_filepath, + mean_std_filepath=args.mean_std_filepath, + augmentation_config='{}', + specgram_type=args.specgram_type, + num_threads=1) + # prepare ASR model + ds2_model = DeepSpeech2Model( + vocab_size=data_generator.vocab_size, + num_conv_layers=args.num_conv_layers, + num_rnn_layers=args.num_rnn_layers, + rnn_layer_size=args.rnn_layer_size, + pretrained_model_path=args.model_filepath) + + # prepare ASR inference handler + def file_to_transcript(filename): + feature = data_generator.process_utterance(filename, "") + result_transcript = ds2_model.infer_batch( + infer_data=[feature], + decode_method=args.decode_method, + beam_alpha=args.alpha, + beam_beta=args.beta, + beam_size=args.beam_size, + cutoff_prob=args.cutoff_prob, + vocab_list=data_generator.vocab_list, + language_model_path=args.language_model_path, + num_processes=1) + return result_transcript[0] + + # warming up with utterrances sampled from Librispeech + print('-----------------------------------------------------------') + print('Warming up ...') + warm_up_test( + audio_process_handler=file_to_transcript, + manifest_path=args.warmup_manifest_path, + num_test_cases=3) + print('-----------------------------------------------------------') + + # start the server + server = AsrTCPServer( + server_address=(args.host_ip, args.host_port), + RequestHandlerClass=AsrRequestHandler, + speech_save_dir=args.speech_save_dir, + audio_process_handler=file_to_transcript) + print("ASR Server Started.") + server.serve_forever() + + +def main(): + print_arguments(args) + paddle.init(use_gpu=args.use_gpu, trainer_count=1) + start_server() + + +if __name__ == "__main__": + main() diff --git a/deep_speech_2/evaluate.py b/deep_speech_2/evaluate.py index 19eabf4e5aff090ed2f529e3ea3cd7f10ae57cb7..592b7b527a692dd7dfa2a93799828fb77066948c 100644 --- a/deep_speech_2/evaluate.py +++ b/deep_speech_2/evaluate.py @@ -5,20 +5,24 @@ from __future__ import print_function import distutils.util import argparse -import gzip +import multiprocessing import paddle.v2 as paddle from data_utils.data import DataGenerator -from model import deep_speech2 -from decoder import * -from lm.lm_scorer import LmScorer +from model import DeepSpeech2Model from error_rate import wer +import utils parser = argparse.ArgumentParser(description=__doc__) parser.add_argument( "--batch_size", - default=100, + default=128, type=int, help="Minibatch size for evaluation. (default: %(default)s)") +parser.add_argument( + "--trainer_count", + default=8, + type=int, + help="Trainer number. (default: %(default)s)") parser.add_argument( "--num_conv_layers", default=2, @@ -41,12 +45,12 @@ parser.add_argument( help="Use gpu or not. (default: %(default)s)") parser.add_argument( "--num_threads_data", - default=multiprocessing.cpu_count(), + default=multiprocessing.cpu_count() // 2, type=int, help="Number of cpu threads for preprocessing data. (default: %(default)s)") parser.add_argument( "--num_processes_beam_search", - default=multiprocessing.cpu_count(), + default=multiprocessing.cpu_count() // 2, type=int, help="Number of cpu processes for beam search. (default: %(default)s)") parser.add_argument( @@ -58,8 +62,8 @@ parser.add_argument( "--decode_method", default='beam_search', type=str, - help="Method for ctc decoding, best_path or beam_search. (default: %(default)s)" -) + help="Method for ctc decoding, best_path or beam_search. " + "(default: %(default)s)") parser.add_argument( "--language_model_path", default="lm/data/common_crawl_00.prune01111.trie.klm", @@ -67,12 +71,12 @@ parser.add_argument( help="Path for language model. (default: %(default)s)") parser.add_argument( "--alpha", - default=0.26, + default=0.36, type=float, help="Parameter associated with language model. (default: %(default)f)") parser.add_argument( "--beta", - default=0.1, + default=0.25, type=float, help="Parameter associated with word count. (default: %(default)f)") parser.add_argument( @@ -112,37 +116,12 @@ args = parser.parse_args() def evaluate(): """Evaluate on whole test data for DeepSpeech2.""" - # initialize data generator data_generator = DataGenerator( vocab_filepath=args.vocab_filepath, mean_std_filepath=args.mean_std_filepath, augmentation_config='{}', specgram_type=args.specgram_type, num_threads=args.num_threads_data) - - # create network config - # paddle.data_type.dense_array is used for variable batch input. - # The size 161 * 161 is only an placeholder value and the real shape - # of input batch data will be induced during training. - audio_data = paddle.layer.data( - name="audio_spectrogram", type=paddle.data_type.dense_array(161 * 161)) - text_data = paddle.layer.data( - name="transcript_text", - type=paddle.data_type.integer_value_sequence(data_generator.vocab_size)) - output_probs = deep_speech2( - audio_data=audio_data, - text_data=text_data, - dict_size=data_generator.vocab_size, - num_conv_layers=args.num_conv_layers, - num_rnn_layers=args.num_rnn_layers, - rnn_size=args.rnn_layer_size, - is_inference=True) - - # load parameters - parameters = paddle.parameters.Parameters.from_tar( - gzip.open(args.model_filepath)) - - # prepare infer data batch_reader = data_generator.batch_reader_creator( manifest_path=args.decode_manifest_path, batch_size=args.batch_size, @@ -150,61 +129,39 @@ def evaluate(): sortagrad=False, shuffle_method=None) - # define inferer - inferer = paddle.inference.Inference( - output_layer=output_probs, parameters=parameters) - - # initialize external scorer for beam search decoding - if args.decode_method == 'beam_search': - ext_scorer = LmScorer(args.alpha, args.beta, args.language_model_path) + ds2_model = DeepSpeech2Model( + vocab_size=data_generator.vocab_size, + num_conv_layers=args.num_conv_layers, + num_rnn_layers=args.num_rnn_layers, + rnn_layer_size=args.rnn_layer_size, + pretrained_model_path=args.model_filepath) - wer_counter, wer_sum = 0, 0.0 + wer_sum, num_ins = 0.0, 0 for infer_data in batch_reader(): - # run inference - infer_results = inferer.infer(input=infer_data) - num_steps = len(infer_results) // len(infer_data) - probs_split = [ - infer_results[i * num_steps:(i + 1) * num_steps] - for i in xrange(0, len(infer_data)) - ] - # target transcription - target_transcription = [ - ''.join([ - data_generator.vocab_list[index] for index in infer_data[i][1] - ]) for i, probs in enumerate(probs_split) + result_transcripts = ds2_model.infer_batch( + infer_data=infer_data, + decode_method=args.decode_method, + beam_alpha=args.alpha, + beam_beta=args.beta, + beam_size=args.beam_size, + cutoff_prob=args.cutoff_prob, + vocab_list=data_generator.vocab_list, + language_model_path=args.language_model_path, + num_processes=args.num_processes_beam_search) + target_transcripts = [ + ''.join([data_generator.vocab_list[token] for token in transcript]) + for _, transcript in infer_data ] - # decode and print - # best path decode - if args.decode_method == "best_path": - for i, probs in enumerate(probs_split): - output_transcription = ctc_best_path_decoder( - probs_seq=probs, vocabulary=data_generator.vocab_list) - wer_sum += wer(target_transcription[i], output_transcription) - wer_counter += 1 - # beam search decode - elif args.decode_method == "beam_search": - # beam search using multiple processes - beam_search_results = ctc_beam_search_decoder_batch( - probs_split=probs_split, - vocabulary=data_generator.vocab_list, - beam_size=args.beam_size, - blank_id=len(data_generator.vocab_list), - num_processes=args.num_processes_beam_search, - ext_scoring_func=ext_scorer, - cutoff_prob=args.cutoff_prob, ) - for i, beam_search_result in enumerate(beam_search_results): - wer_sum += wer(target_transcription[i], - beam_search_result[0][1]) - wer_counter += 1 - else: - raise ValueError("Decoding method [%s] is not supported." % - decode_method) - - print("Final WER = %f" % (wer_sum / wer_counter)) + for target, result in zip(target_transcripts, result_transcripts): + wer_sum += wer(target, result) + num_ins += 1 + print("WER (%d/?) = %f" % (num_ins, wer_sum / num_ins)) + print("Final WER (%d/%d) = %f" % (num_ins, num_ins, wer_sum / num_ins)) def main(): - paddle.init(use_gpu=args.use_gpu, trainer_count=1) + utils.print_arguments(args) + paddle.init(use_gpu=args.use_gpu, trainer_count=args.trainer_count) evaluate() diff --git a/deep_speech_2/infer.py b/deep_speech_2/infer.py index 817526302764b3d6044688da97ad0cc072c14144..df5953e59fd198951babeabe3fa27ae680ef0ad6 100644 --- a/deep_speech_2/infer.py +++ b/deep_speech_2/infer.py @@ -4,14 +4,11 @@ from __future__ import division from __future__ import print_function import argparse -import gzip import distutils.util import multiprocessing import paddle.v2 as paddle from data_utils.data import DataGenerator -from model import deep_speech2 -from decoder import * -from lm.lm_scorer import LmScorer +from model import DeepSpeech2Model from error_rate import wer import utils @@ -43,12 +40,12 @@ parser.add_argument( help="Use gpu or not. (default: %(default)s)") parser.add_argument( "--num_threads_data", - default=multiprocessing.cpu_count(), + default=1, type=int, help="Number of cpu threads for preprocessing data. (default: %(default)s)") parser.add_argument( "--num_processes_beam_search", - default=multiprocessing.cpu_count(), + default=multiprocessing.cpu_count() // 2, type=int, help="Number of cpu processes for beam search. (default: %(default)s)") parser.add_argument( @@ -57,6 +54,11 @@ parser.add_argument( type=str, help="Feature type of audio data: 'linear' (power spectrum)" " or 'mfcc'. (default: %(default)s)") +parser.add_argument( + "--trainer_count", + default=8, + type=int, + help="Trainer number. (default: %(default)s)") parser.add_argument( "--mean_std_filepath", default='mean_std.npz', @@ -81,18 +83,13 @@ parser.add_argument( "--decode_method", default='beam_search', type=str, - help="Method for ctc decoding: best_path or beam_search. (default: %(default)s)" -) + help="Method for ctc decoding: best_path or beam_search. " + "(default: %(default)s)") parser.add_argument( "--beam_size", default=500, type=int, help="Width for beam search decoding. (default: %(default)d)") -parser.add_argument( - "--num_results_per_sample", - default=1, - type=int, - help="Number of output per sample in beam search. (default: %(default)d)") parser.add_argument( "--language_model_path", default="lm/data/common_crawl_00.prune01111.trie.klm", @@ -100,12 +97,12 @@ parser.add_argument( help="Path for language model. (default: %(default)s)") parser.add_argument( "--alpha", - default=0.26, + default=0.36, type=float, help="Parameter associated with language model. (default: %(default)f)") parser.add_argument( "--beta", - default=0.1, + default=0.25, type=float, help="Parameter associated with word count. (default: %(default)f)") parser.add_argument( @@ -119,37 +116,12 @@ args = parser.parse_args() def infer(): """Inference for DeepSpeech2.""" - # initialize data generator data_generator = DataGenerator( vocab_filepath=args.vocab_filepath, mean_std_filepath=args.mean_std_filepath, augmentation_config='{}', specgram_type=args.specgram_type, num_threads=args.num_threads_data) - - # create network config - # paddle.data_type.dense_array is used for variable batch input. - # The size 161 * 161 is only an placeholder value and the real shape - # of input batch data will be induced during training. - audio_data = paddle.layer.data( - name="audio_spectrogram", type=paddle.data_type.dense_array(161 * 161)) - text_data = paddle.layer.data( - name="transcript_text", - type=paddle.data_type.integer_value_sequence(data_generator.vocab_size)) - output_probs = deep_speech2( - audio_data=audio_data, - text_data=text_data, - dict_size=data_generator.vocab_size, - num_conv_layers=args.num_conv_layers, - num_rnn_layers=args.num_rnn_layers, - rnn_size=args.rnn_layer_size, - is_inference=True) - - # load parameters - parameters = paddle.parameters.Parameters.from_tar( - gzip.open(args.model_filepath)) - - # prepare infer data batch_reader = data_generator.batch_reader_creator( manifest_path=args.decode_manifest_path, batch_size=args.num_samples, @@ -158,66 +130,36 @@ def infer(): shuffle_method=None) infer_data = batch_reader().next() - # run inference - infer_results = paddle.infer( - output_layer=output_probs, parameters=parameters, input=infer_data) - num_steps = len(infer_results) // len(infer_data) - probs_split = [ - infer_results[i * num_steps:(i + 1) * num_steps] - for i in xrange(len(infer_data)) - ] + ds2_model = DeepSpeech2Model( + vocab_size=data_generator.vocab_size, + num_conv_layers=args.num_conv_layers, + num_rnn_layers=args.num_rnn_layers, + rnn_layer_size=args.rnn_layer_size, + pretrained_model_path=args.model_filepath) + result_transcripts = ds2_model.infer_batch( + infer_data=infer_data, + decode_method=args.decode_method, + beam_alpha=args.alpha, + beam_beta=args.beta, + beam_size=args.beam_size, + cutoff_prob=args.cutoff_prob, + vocab_list=data_generator.vocab_list, + language_model_path=args.language_model_path, + num_processes=args.num_processes_beam_search) - # targe transcription - target_transcription = [ - ''.join( - [data_generator.vocab_list[index] for index in infer_data[i][1]]) - for i, probs in enumerate(probs_split) + target_transcripts = [ + ''.join([data_generator.vocab_list[token] for token in transcript]) + for _, transcript in infer_data ] - - ## decode and print - # best path decode - wer_sum, wer_counter = 0, 0 - if args.decode_method == "best_path": - for i, probs in enumerate(probs_split): - best_path_transcription = ctc_best_path_decoder( - probs_seq=probs, vocabulary=data_generator.vocab_list) - print("\nTarget Transcription: %s\nOutput Transcription: %s" % - (target_transcription[i], best_path_transcription)) - wer_cur = wer(target_transcription[i], best_path_transcription) - wer_sum += wer_cur - wer_counter += 1 - print("cur wer = %f, average wer = %f" % - (wer_cur, wer_sum / wer_counter)) - # beam search decode - elif args.decode_method == "beam_search": - ext_scorer = LmScorer(args.alpha, args.beta, args.language_model_path) - beam_search_batch_results = ctc_beam_search_decoder_batch( - probs_split=probs_split, - vocabulary=data_generator.vocab_list, - beam_size=args.beam_size, - blank_id=len(data_generator.vocab_list), - num_processes=args.num_processes_beam_search, - cutoff_prob=args.cutoff_prob, - ext_scoring_func=ext_scorer, ) - for i, beam_search_result in enumerate(beam_search_batch_results): - print("\nTarget Transcription:\t%s" % target_transcription[i]) - for index in xrange(args.num_results_per_sample): - result = beam_search_result[index] - #output: index, log prob, beam result - print("Beam %d: %f \t%s" % (index, result[0], result[1])) - wer_cur = wer(target_transcription[i], beam_search_result[0][1]) - wer_sum += wer_cur - wer_counter += 1 - print("cur wer = %f , average wer = %f" % - (wer_cur, wer_sum / wer_counter)) - else: - raise ValueError("Decoding method [%s] is not supported." % - decode_method) + for target, result in zip(target_transcripts, result_transcripts): + print("\nTarget Transcription: %s\nOutput Transcription: %s" % + (target, result)) + print("Current wer = %f" % wer(target, result)) def main(): utils.print_arguments(args) - paddle.init(use_gpu=args.use_gpu, trainer_count=1) + paddle.init(use_gpu=args.use_gpu, trainer_count=args.trainer_count) infer() diff --git a/deep_speech_2/layer.py b/deep_speech_2/layer.py new file mode 100644 index 0000000000000000000000000000000000000000..3b492645d5a42f3f0c61d2646b7d6a19bb0c3e98 --- /dev/null +++ b/deep_speech_2/layer.py @@ -0,0 +1,177 @@ +"""Contains DeepSpeech2 layers.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle.v2 as paddle + + +def conv_bn_layer(input, filter_size, num_channels_in, num_channels_out, stride, + padding, act): + """Convolution layer with batch normalization. + + :param input: Input layer. + :type input: LayerOutput + :param filter_size: The x dimension of a filter kernel. Or input a tuple for + two image dimension. + :type filter_size: int|tuple|list + :param num_channels_in: Number of input channels. + :type num_channels_in: int + :type num_channels_out: Number of output channels. + :type num_channels_in: out + :param padding: The x dimension of the padding. Or input a tuple for two + image dimension. + :type padding: int|tuple|list + :param act: Activation type. + :type act: BaseActivation + :return: Batch norm layer after convolution layer. + :rtype: LayerOutput + """ + conv_layer = paddle.layer.img_conv( + input=input, + filter_size=filter_size, + num_channels=num_channels_in, + num_filters=num_channels_out, + stride=stride, + padding=padding, + act=paddle.activation.Linear(), + bias_attr=False) + return paddle.layer.batch_norm(input=conv_layer, act=act) + + +def bidirectional_simple_rnn_bn_layer(name, input, size, act): + """Bidirectonal simple rnn layer with sequence-wise batch normalization. + The batch normalization is only performed on input-state weights. + + :param name: Name of the layer. + :type name: string + :param input: Input layer. + :type input: LayerOutput + :param size: Number of RNN cells. + :type size: int + :param act: Activation type. + :type act: BaseActivation + :return: Bidirectional simple rnn layer. + :rtype: LayerOutput + """ + # input-hidden weights shared across bi-direcitonal rnn. + input_proj = paddle.layer.fc( + input=input, size=size, act=paddle.activation.Linear(), bias_attr=False) + # batch norm is only performed on input-state projection + input_proj_bn = paddle.layer.batch_norm( + input=input_proj, act=paddle.activation.Linear()) + # forward and backward in time + forward_simple_rnn = paddle.layer.recurrent( + input=input_proj_bn, act=act, reverse=False) + backward_simple_rnn = paddle.layer.recurrent( + input=input_proj_bn, act=act, reverse=True) + return paddle.layer.concat(input=[forward_simple_rnn, backward_simple_rnn]) + + +def conv_group(input, num_stacks): + """Convolution group with stacked convolution layers. + + :param input: Input layer. + :type input: LayerOutput + :param num_stacks: Number of stacked convolution layers. + :type num_stacks: int + :return: Output layer of the convolution group. + :rtype: LayerOutput + """ + conv = conv_bn_layer( + input=input, + filter_size=(11, 41), + num_channels_in=1, + num_channels_out=32, + stride=(3, 2), + padding=(5, 20), + act=paddle.activation.BRelu()) + for i in xrange(num_stacks - 1): + conv = conv_bn_layer( + input=conv, + filter_size=(11, 21), + num_channels_in=32, + num_channels_out=32, + stride=(1, 2), + padding=(5, 10), + act=paddle.activation.BRelu()) + output_num_channels = 32 + output_height = 160 // pow(2, num_stacks) + 1 + return conv, output_num_channels, output_height + + +def rnn_group(input, size, num_stacks): + """RNN group with stacked bidirectional simple RNN layers. + + :param input: Input layer. + :type input: LayerOutput + :param size: Number of RNN cells in each layer. + :type size: int + :param num_stacks: Number of stacked rnn layers. + :type num_stacks: int + :return: Output layer of the RNN group. + :rtype: LayerOutput + """ + output = input + for i in xrange(num_stacks): + output = bidirectional_simple_rnn_bn_layer( + name=str(i), input=output, size=size, act=paddle.activation.BRelu()) + return output + + +def deep_speech2(audio_data, + text_data, + dict_size, + num_conv_layers=2, + num_rnn_layers=3, + rnn_size=256): + """ + The whole DeepSpeech2 model structure (a simplified version). + + :param audio_data: Audio spectrogram data layer. + :type audio_data: LayerOutput + :param text_data: Transcription text data layer. + :type text_data: LayerOutput + :param dict_size: Dictionary size for tokenized transcription. + :type dict_size: int + :param num_conv_layers: Number of stacking convolution layers. + :type num_conv_layers: int + :param num_rnn_layers: Number of stacking RNN layers. + :type num_rnn_layers: int + :param rnn_size: RNN layer size (number of RNN cells). + :type rnn_size: int + :return: A tuple of an output unnormalized log probability layer ( + before softmax) and a ctc cost layer. + :rtype: tuple of LayerOutput + """ + # convolution group + conv_group_output, conv_group_num_channels, conv_group_height = conv_group( + input=audio_data, num_stacks=num_conv_layers) + # convert data form convolution feature map to sequence of vectors + conv2seq = paddle.layer.block_expand( + input=conv_group_output, + num_channels=conv_group_num_channels, + stride_x=1, + stride_y=1, + block_x=1, + block_y=conv_group_height) + # rnn group + rnn_group_output = rnn_group( + input=conv2seq, size=rnn_size, num_stacks=num_rnn_layers) + fc = paddle.layer.fc( + input=rnn_group_output, + size=dict_size + 1, + act=paddle.activation.Linear(), + bias_attr=True) + # probability distribution with softmax + log_probs = paddle.layer.mixed( + input=paddle.layer.identity_projection(input=fc), + act=paddle.activation.Softmax()) + # ctc cost + ctc_loss = paddle.layer.warp_ctc( + input=fc, + label=text_data, + size=dict_size + 1, + blank=dict_size, + norm_by_times=True) + return log_probs, ctc_loss diff --git a/deep_speech_2/model.py b/deep_speech_2/model.py index cb0b4ecbba1a3fb435a5f625a54d6e5bebe689e0..2eb7c3594974239dff68f771e478423414688411 100644 --- a/deep_speech_2/model.py +++ b/deep_speech_2/model.py @@ -3,141 +3,240 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import sys +import os +import time +import gzip +from decoder import * +from lm.lm_scorer import LmScorer import paddle.v2 as paddle +from layer import * -def conv_bn_layer(input, filter_size, num_channels_in, num_channels_out, stride, - padding, act): - """ - Convolution layer with batch normalization. - """ - conv_layer = paddle.layer.img_conv( - input=input, - filter_size=filter_size, - num_channels=num_channels_in, - num_filters=num_channels_out, - stride=stride, - padding=padding, - act=paddle.activation.Linear(), - bias_attr=False) - return paddle.layer.batch_norm(input=conv_layer, act=act) - - -def bidirectional_simple_rnn_bn_layer(name, input, size, act): - """ - Bidirectonal simple rnn layer with sequence-wise batch normalization. - The batch normalization is only performed on input-state weights. - """ - # input-hidden weights shared across bi-direcitonal rnn. - input_proj = paddle.layer.fc( - input=input, size=size, act=paddle.activation.Linear(), bias_attr=False) - # batch norm is only performed on input-state projection - input_proj_bn = paddle.layer.batch_norm( - input=input_proj, act=paddle.activation.Linear()) - # forward and backward in time - forward_simple_rnn = paddle.layer.recurrent( - input=input_proj_bn, act=act, reverse=False) - backward_simple_rnn = paddle.layer.recurrent( - input=input_proj_bn, act=act, reverse=True) - return paddle.layer.concat(input=[forward_simple_rnn, backward_simple_rnn]) - - -def conv_group(input, num_stacks): - """ - Convolution group with several stacking convolution layers. - """ - conv = conv_bn_layer( - input=input, - filter_size=(11, 41), - num_channels_in=1, - num_channels_out=32, - stride=(3, 2), - padding=(5, 20), - act=paddle.activation.BRelu()) - for i in xrange(num_stacks - 1): - conv = conv_bn_layer( - input=conv, - filter_size=(11, 21), - num_channels_in=32, - num_channels_out=32, - stride=(1, 2), - padding=(5, 10), - act=paddle.activation.BRelu()) - output_num_channels = 32 - output_height = 160 // pow(2, num_stacks) + 1 - return conv, output_num_channels, output_height - - -def rnn_group(input, size, num_stacks): - """ - RNN group with several stacking RNN layers. - """ - output = input - for i in xrange(num_stacks): - output = bidirectional_simple_rnn_bn_layer( - name=str(i), input=output, size=size, act=paddle.activation.BRelu()) - return output - - -def deep_speech2(audio_data, - text_data, - dict_size, - num_conv_layers=2, - num_rnn_layers=3, - rnn_size=256, - is_inference=False): - """ - The whole DeepSpeech2 model structure (a simplified version). - - :param audio_data: Audio spectrogram data layer. - :type audio_data: LayerOutput - :param text_data: Transcription text data layer. - :type text_data: LayerOutput - :param dict_size: Dictionary size for tokenized transcription. - :type dict_size: int +class DeepSpeech2Model(object): + """DeepSpeech2Model class. + + :param vocab_size: Decoding vocabulary size. + :type vocab_size: int :param num_conv_layers: Number of stacking convolution layers. :type num_conv_layers: int :param num_rnn_layers: Number of stacking RNN layers. :type num_rnn_layers: int - :param rnn_size: RNN layer size (number of RNN cells). - :type rnn_size: int - :param is_inference: False in the training mode, and True in the - inferene mode. - :type is_inference: bool - :return: If is_inference set False, return a ctc cost layer; - if is_inference set True, return a sequence layer of output - probability distribution. - :rtype: tuple of LayerOutput + :param rnn_layer_size: RNN layer size (number of RNN cells). + :type rnn_layer_size: int + :param pretrained_model_path: Pretrained model path. If None, will train + from stratch. + :type pretrained_model_path: basestring|None """ - # convolution group - conv_group_output, conv_group_num_channels, conv_group_height = conv_group( - input=audio_data, num_stacks=num_conv_layers) - # convert data form convolution feature map to sequence of vectors - conv2seq = paddle.layer.block_expand( - input=conv_group_output, - num_channels=conv_group_num_channels, - stride_x=1, - stride_y=1, - block_x=1, - block_y=conv_group_height) - # rnn group - rnn_group_output = rnn_group( - input=conv2seq, size=rnn_size, num_stacks=num_rnn_layers) - fc = paddle.layer.fc( - input=rnn_group_output, - size=dict_size + 1, - act=paddle.activation.Linear(), - bias_attr=True) - if is_inference: - # probability distribution with softmax - return paddle.layer.mixed( - input=paddle.layer.identity_projection(input=fc), - act=paddle.activation.Softmax()) - else: - # ctc cost - return paddle.layer.warp_ctc( - input=fc, - label=text_data, - size=dict_size + 1, - blank=dict_size, - norm_by_times=True) + + def __init__(self, vocab_size, num_conv_layers, num_rnn_layers, + rnn_layer_size, pretrained_model_path): + self._create_network(vocab_size, num_conv_layers, num_rnn_layers, + rnn_layer_size) + self._create_parameters(pretrained_model_path) + self._inferer = None + self._loss_inferer = None + self._ext_scorer = None + + def train(self, + train_batch_reader, + dev_batch_reader, + feeding_dict, + learning_rate, + gradient_clipping, + num_passes, + output_model_dir, + num_iterations_print=100): + """Train the model. + + :param train_batch_reader: Train data reader. + :type train_batch_reader: callable + :param dev_batch_reader: Validation data reader. + :type dev_batch_reader: callable + :param feeding_dict: Feeding is a map of field name and tuple index + of the data that reader returns. + :type feeding_dict: dict|list + :param learning_rate: Learning rate for ADAM optimizer. + :type learning_rate: float + :param gradient_clipping: Gradient clipping threshold. + :type gradient_clipping: float + :param num_passes: Number of training epochs. + :type num_passes: int + :param num_iterations_print: Number of training iterations for printing + a training loss. + :type rnn_iteratons_print: int + :param output_model_dir: Directory for saving the model (every pass). + :type output_model_dir: basestring + """ + # prepare model output directory + if not os.path.exists(output_model_dir): + os.mkdir(output_model_dir) + + # prepare optimizer and trainer + optimizer = paddle.optimizer.Adam( + learning_rate=learning_rate, + gradient_clipping_threshold=gradient_clipping) + trainer = paddle.trainer.SGD( + cost=self._loss, + parameters=self._parameters, + update_equation=optimizer) + + # create event handler + def event_handler(event): + global start_time, cost_sum, cost_counter + if isinstance(event, paddle.event.EndIteration): + cost_sum += event.cost + cost_counter += 1 + if (event.batch_id + 1) % num_iterations_print == 0: + output_model_path = os.path.join(output_model_dir, + "params.latest.tar.gz") + with gzip.open(output_model_path, 'w') as f: + self._parameters.to_tar(f) + print("\nPass: %d, Batch: %d, TrainCost: %f" % + (event.pass_id, event.batch_id + 1, + cost_sum / cost_counter)) + cost_sum, cost_counter = 0.0, 0 + else: + sys.stdout.write('.') + sys.stdout.flush() + if isinstance(event, paddle.event.BeginPass): + start_time = time.time() + cost_sum, cost_counter = 0.0, 0 + if isinstance(event, paddle.event.EndPass): + result = trainer.test( + reader=dev_batch_reader, feeding=feeding_dict) + output_model_path = os.path.join( + output_model_dir, "params.pass-%d.tar.gz" % event.pass_id) + with gzip.open(output_model_path, 'w') as f: + self._parameters.to_tar(f) + print("\n------- Time: %d sec, Pass: %d, ValidationCost: %s" % + (time.time() - start_time, event.pass_id, result.cost)) + + # run train + trainer.train( + reader=train_batch_reader, + event_handler=event_handler, + num_passes=num_passes, + feeding=feeding_dict) + + def infer_loss_batch(self, infer_data): + """Model inference. Infer the ctc loss for a batch of speech + utterances. + + :param infer_data: List of utterances to infer, with each utterance a + tuple of audio features and transcription text (empty + string). + :type infer_data: list + :return: List of ctc loss. + :rtype: List of float + """ + # define inferer + if self._loss_inferer == None: + self._loss_inferer = paddle.inference.Inference( + output_layer=self._loss, parameters=self._parameters) + # run inference + return self._loss_inferer.infer(input=infer_data) + + def infer_batch(self, infer_data, decode_method, beam_alpha, beam_beta, + beam_size, cutoff_prob, vocab_list, language_model_path, + num_processes): + """Model inference. Infer the transcription for a batch of speech + utterances. + + :param infer_data: List of utterances to infer, with each utterance + consisting of a tuple of audio features and + transcription text (empty string). + :type infer_data: list + :param decode_method: Decoding method name, 'best_path' or + 'beam search'. + :param decode_method: string + :param beam_alpha: Parameter associated with language model. + :type beam_alpha: float + :param beam_beta: Parameter associated with word count. + :type beam_beta: float + :param beam_size: Width for Beam search. + :type beam_size: int + :param cutoff_prob: Cutoff probability in pruning, + default 1.0, no pruning. + :type cutoff_prob: float + :param vocab_list: List of tokens in the vocabulary, for decoding. + :type vocab_list: list + :param language_model_path: Filepath for language model. + :type language_model_path: basestring|None + :param num_processes: Number of processes (CPU) for decoder. + :type num_processes: int + :return: List of transcription texts. + :rtype: List of basestring + """ + # define inferer + if self._inferer == None: + self._inferer = paddle.inference.Inference( + output_layer=self._log_probs, parameters=self._parameters) + # run inference + infer_results = self._inferer.infer(input=infer_data) + num_steps = len(infer_results) // len(infer_data) + probs_split = [ + infer_results[i * num_steps:(i + 1) * num_steps] + for i in xrange(0, len(infer_data)) + ] + # run decoder + results = [] + if decode_method == "best_path": + # best path decode + for i, probs in enumerate(probs_split): + output_transcription = ctc_best_path_decoder( + probs_seq=probs, vocabulary=data_generator.vocab_list) + results.append(output_transcription) + elif decode_method == "beam_search": + # initialize external scorer + if self._ext_scorer == None: + self._ext_scorer = LmScorer(beam_alpha, beam_beta, + language_model_path) + self._loaded_lm_path = language_model_path + else: + self._ext_scorer.reset_params(beam_alpha, beam_beta) + assert self._loaded_lm_path == language_model_path + + # beam search decode + beam_search_results = ctc_beam_search_decoder_batch( + probs_split=probs_split, + vocabulary=vocab_list, + beam_size=beam_size, + blank_id=len(vocab_list), + num_processes=num_processes, + ext_scoring_func=self._ext_scorer, + cutoff_prob=cutoff_prob) + + results = [result[0][1] for result in beam_search_results] + else: + raise ValueError("Decoding method [%s] is not supported." % + decode_method) + return results + + def _create_parameters(self, model_path=None): + """Load or create model parameters.""" + if model_path is None: + self._parameters = paddle.parameters.create(self._loss) + else: + self._parameters = paddle.parameters.Parameters.from_tar( + gzip.open(model_path)) + + def _create_network(self, vocab_size, num_conv_layers, num_rnn_layers, + rnn_layer_size): + """Create data layers and model network.""" + # paddle.data_type.dense_array is used for variable batch input. + # The size 161 * 161 is only an placeholder value and the real shape + # of input batch data will be induced during training. + audio_data = paddle.layer.data( + name="audio_spectrogram", + type=paddle.data_type.dense_array(161 * 161)) + text_data = paddle.layer.data( + name="transcript_text", + type=paddle.data_type.integer_value_sequence(vocab_size)) + self._log_probs, self._loss = deep_speech2( + audio_data=audio_data, + text_data=text_data, + dict_size=vocab_size, + num_conv_layers=num_conv_layers, + num_rnn_layers=num_rnn_layers, + rnn_size=rnn_layer_size) diff --git a/deep_speech_2/requirements.txt b/deep_speech_2/requirements.txt old mode 100755 new mode 100644 index 721fa2811081e530a9cec3b2e403ad2372b59269..131f75ff47e003f3b44f4a62f1431cf13d4f44a4 --- a/deep_speech_2/requirements.txt +++ b/deep_speech_2/requirements.txt @@ -1,5 +1,5 @@ -wget==3.2 scipy==0.13.1 resampy==0.1.5 -https://github.com/kpu/kenlm/archive/master.zip +SoundFile==0.9.0.post1 python_speech_features +https://github.com/luotao1/kenlm/archive/master.zip diff --git a/deep_speech_2/setup.sh b/deep_speech_2/setup.sh index 8cba91ecdb68b42125181331471f9ee323062a24..7f4272550c4efb9cebd5483c4911caed02cd9673 100644 --- a/deep_speech_2/setup.sh +++ b/deep_speech_2/setup.sh @@ -9,25 +9,21 @@ if [ $? != 0 ]; then exit 1 fi -# install package Soundfile -curl -O "http://www.mega-nerd.com/libsndfile/files/libsndfile-1.0.28.tar.gz" +# install package libsndfile +python -c "import soundfile" if [ $? != 0 ]; then - echo "Download libsndfile-1.0.28.tar.gz failed !!!" - exit 1 + echo "Install package libsndfile into default system path." + curl -O "http://www.mega-nerd.com/libsndfile/files/libsndfile-1.0.28.tar.gz" + if [ $? != 0 ]; then + echo "Download libsndfile-1.0.28.tar.gz failed !!!" + exit 1 + fi + tar -zxvf libsndfile-1.0.28.tar.gz + cd libsndfile-1.0.28 + ./configure && make && make install + cd .. + rm -rf libsndfile-1.0.28 + rm libsndfile-1.0.28.tar.gz fi -tar -zxvf libsndfile-1.0.28.tar.gz -cd libsndfile-1.0.28 -./configure && make && make install -cd - -rm -rf libsndfile-1.0.28 -rm libsndfile-1.0.28.tar.gz -pip install SoundFile==0.9.0.post1 -if [ $? != 0 ]; then - echo "Install SoundFile failed !!!" - exit 1 -fi - -# prepare ./checkpoints -mkdir checkpoints echo "Install all dependencies successfully." diff --git a/deep_speech_2/tests/test_setup.py b/deep_speech_2/tests/test_setup.py new file mode 100644 index 0000000000000000000000000000000000000000..18b9c1a0ce5333f559383b18704edf7270457fcf --- /dev/null +++ b/deep_speech_2/tests/test_setup.py @@ -0,0 +1,23 @@ +"""Test Setup.""" +import unittest +import numpy as np +import os + + +class TestSetup(unittest.TestCase): + def test_soundfile(self): + import soundfile as sf + # floating point data is typically limited to the interval [-1.0, 1.0], + # but smaller/larger values are supported as well + data = np.array([[1.75, -1.75], [1.0, -1.0], [0.5, -0.5], + [0.25, -0.25]]) + file = 'test.wav' + sf.write(file, data, 44100, format='WAV', subtype='FLOAT') + read, fs = sf.read(file) + self.assertTrue(np.all(read == data)) + self.assertEqual(fs, 44100) + os.remove(file) + + +if __name__ == '__main__': + unittest.main() diff --git a/deep_speech_2/train.py b/deep_speech_2/train.py index 6481074c6e58f98f57f81c6e42480fa00a261bbe..0d4e2508dddf5cc6834b4f61f0c2cc8deee405af 100644 --- a/deep_speech_2/train.py +++ b/deep_speech_2/train.py @@ -3,15 +3,11 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import sys -import os import argparse -import gzip -import time import distutils.util import multiprocessing import paddle.v2 as paddle -from model import deep_speech2 +from model import DeepSpeech2Model from data_utils.data import DataGenerator import utils @@ -23,6 +19,12 @@ parser.add_argument( default=200, type=int, help="Training pass number. (default: %(default)s)") +parser.add_argument( + "--num_iterations_print", + default=100, + type=int, + help="Number of iterations for every train cost printing. " + "(default: %(default)s)") parser.add_argument( "--num_conv_layers", default=2, @@ -84,7 +86,7 @@ parser.add_argument( help="Trainer number. (default: %(default)s)") parser.add_argument( "--num_threads_data", - default=multiprocessing.cpu_count(), + default=multiprocessing.cpu_count() // 2, type=int, help="Number of cpu threads for preprocessing data. (default: %(default)s)") parser.add_argument( @@ -114,11 +116,14 @@ parser.add_argument( help="If set None, the training will start from scratch. " "Otherwise, the training will resume from " "the existing model of this path. (default: %(default)s)") +parser.add_argument( + "--output_model_dir", + default="./checkpoints", + type=str, + help="Directory for saving models. (default: %(default)s)") parser.add_argument( "--augmentation_config", - default='[{"type": "shift", ' - '"params": {"min_shift_ms": -5, "max_shift_ms": 5},' - '"prob": 1.0}]', + default=open('conf/augmentation.config', 'r').read(), type=str, help="Augmentation configuration in json-format. " "(default: %(default)s)") @@ -127,100 +132,48 @@ args = parser.parse_args() def train(): """DeepSpeech2 training.""" - - # initialize data generator - def data_generator(): - return DataGenerator( - vocab_filepath=args.vocab_filepath, - mean_std_filepath=args.mean_std_filepath, - augmentation_config=args.augmentation_config, - max_duration=args.max_duration, - min_duration=args.min_duration, - specgram_type=args.specgram_type, - num_threads=args.num_threads_data) - - train_generator = data_generator() - test_generator = data_generator() - - # create network config - # paddle.data_type.dense_array is used for variable batch input. - # The size 161 * 161 is only an placeholder value and the real shape - # of input batch data will be induced during training. - audio_data = paddle.layer.data( - name="audio_spectrogram", type=paddle.data_type.dense_array(161 * 161)) - text_data = paddle.layer.data( - name="transcript_text", - type=paddle.data_type.integer_value_sequence( - train_generator.vocab_size)) - cost = deep_speech2( - audio_data=audio_data, - text_data=text_data, - dict_size=train_generator.vocab_size, - num_conv_layers=args.num_conv_layers, - num_rnn_layers=args.num_rnn_layers, - rnn_size=args.rnn_layer_size, - is_inference=False) - - # create/load parameters and optimizer - if args.init_model_path is None: - parameters = paddle.parameters.create(cost) - else: - if not os.path.isfile(args.init_model_path): - raise IOError("Invalid model!") - parameters = paddle.parameters.Parameters.from_tar( - gzip.open(args.init_model_path)) - optimizer = paddle.optimizer.Adam( - learning_rate=args.adam_learning_rate, gradient_clipping_threshold=400) - trainer = paddle.trainer.SGD( - cost=cost, parameters=parameters, update_equation=optimizer) - - # prepare data reader + train_generator = DataGenerator( + vocab_filepath=args.vocab_filepath, + mean_std_filepath=args.mean_std_filepath, + augmentation_config=args.augmentation_config, + max_duration=args.max_duration, + min_duration=args.min_duration, + specgram_type=args.specgram_type, + num_threads=args.num_threads_data) + dev_generator = DataGenerator( + vocab_filepath=args.vocab_filepath, + mean_std_filepath=args.mean_std_filepath, + augmentation_config="{}", + specgram_type=args.specgram_type, + num_threads=args.num_threads_data) train_batch_reader = train_generator.batch_reader_creator( manifest_path=args.train_manifest_path, batch_size=args.batch_size, min_batch_size=args.trainer_count, sortagrad=args.use_sortagrad if args.init_model_path is None else False, shuffle_method=args.shuffle_method) - test_batch_reader = test_generator.batch_reader_creator( + dev_batch_reader = dev_generator.batch_reader_creator( manifest_path=args.dev_manifest_path, batch_size=args.batch_size, min_batch_size=1, # must be 1, but will have errors. sortagrad=False, shuffle_method=None) - # create event handler - def event_handler(event): - global start_time, cost_sum, cost_counter - if isinstance(event, paddle.event.EndIteration): - cost_sum += event.cost - cost_counter += 1 - if (event.batch_id + 1) % 100 == 0: - print("\nPass: %d, Batch: %d, TrainCost: %f" % ( - event.pass_id, event.batch_id + 1, cost_sum / cost_counter)) - cost_sum, cost_counter = 0.0, 0 - with gzip.open("checkpoints/params.latest.tar.gz", 'w') as f: - parameters.to_tar(f) - else: - sys.stdout.write('.') - sys.stdout.flush() - if isinstance(event, paddle.event.BeginPass): - start_time = time.time() - cost_sum, cost_counter = 0.0, 0 - if isinstance(event, paddle.event.EndPass): - result = trainer.test( - reader=test_batch_reader, feeding=test_generator.feeding) - print("\n------- Time: %d sec, Pass: %d, ValidationCost: %s" % - (time.time() - start_time, event.pass_id, result.cost)) - with gzip.open("checkpoints/params.pass-%d.tar.gz" % event.pass_id, - 'w') as f: - parameters.to_tar(f) - - # run train - trainer.train( - reader=train_batch_reader, - event_handler=event_handler, + ds2_model = DeepSpeech2Model( + vocab_size=train_generator.vocab_size, + num_conv_layers=args.num_conv_layers, + num_rnn_layers=args.num_rnn_layers, + rnn_layer_size=args.rnn_layer_size, + pretrained_model_path=args.init_model_path) + ds2_model.train( + train_batch_reader=train_batch_reader, + dev_batch_reader=dev_batch_reader, + feeding_dict=train_generator.feeding, + learning_rate=args.adam_learning_rate, + gradient_clipping=400, num_passes=args.num_passes, - feeding=train_generator.feeding) + num_iterations_print=args.num_iterations_print, + output_model_dir=args.output_model_dir) def main(): diff --git a/deep_speech_2/tune.py b/deep_speech_2/tune.py index 2fcca48628aa0aba7fd2e09a1d9ba90582492f89..328d67a1197634e5f02ad0689056196a8904fc06 100644 --- a/deep_speech_2/tune.py +++ b/deep_speech_2/tune.py @@ -3,14 +3,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import numpy as np import distutils.util import argparse -import gzip +import multiprocessing import paddle.v2 as paddle from data_utils.data import DataGenerator -from model import deep_speech2 -from decoder import * -from lm.lm_scorer import LmScorer +from model import DeepSpeech2Model from error_rate import wer import utils @@ -40,14 +39,19 @@ parser.add_argument( default=True, type=distutils.util.strtobool, help="Use gpu or not. (default: %(default)s)") +parser.add_argument( + "--trainer_count", + default=8, + type=int, + help="Trainer number. (default: %(default)s)") parser.add_argument( "--num_threads_data", - default=multiprocessing.cpu_count(), + default=1, type=int, help="Number of cpu threads for preprocessing data. (default: %(default)s)") parser.add_argument( "--num_processes_beam_search", - default=multiprocessing.cpu_count(), + default=multiprocessing.cpu_count() // 2, type=int, help="Number of cpu processes for beam search. (default: %(default)s)") parser.add_argument( @@ -62,10 +66,10 @@ parser.add_argument( type=str, help="Manifest path for normalizer. (default: %(default)s)") parser.add_argument( - "--decode_manifest_path", - default='datasets/manifest.test', + "--tune_manifest_path", + default='datasets/manifest.dev', type=str, - help="Manifest path for decoding. (default: %(default)s)") + help="Manifest path for tuning. (default: %(default)s)") parser.add_argument( "--model_filepath", default='checkpoints/params.latest.tar.gz', @@ -127,96 +131,64 @@ args = parser.parse_args() def tune(): """Tune parameters alpha and beta on one minibatch.""" - if not args.num_alphas >= 0: raise ValueError("num_alphas must be non-negative!") - if not args.num_betas >= 0: raise ValueError("num_betas must be non-negative!") - # initialize data generator data_generator = DataGenerator( vocab_filepath=args.vocab_filepath, mean_std_filepath=args.mean_std_filepath, augmentation_config='{}', specgram_type=args.specgram_type, num_threads=args.num_threads_data) - - # create network config - # paddle.data_type.dense_array is used for variable batch input. - # The size 161 * 161 is only an placeholder value and the real shape - # of input batch data will be induced during training. - audio_data = paddle.layer.data( - name="audio_spectrogram", type=paddle.data_type.dense_array(161 * 161)) - text_data = paddle.layer.data( - name="transcript_text", - type=paddle.data_type.integer_value_sequence(data_generator.vocab_size)) - output_probs = deep_speech2( - audio_data=audio_data, - text_data=text_data, - dict_size=data_generator.vocab_size, - num_conv_layers=args.num_conv_layers, - num_rnn_layers=args.num_rnn_layers, - rnn_size=args.rnn_layer_size, - is_inference=True) - - # load parameters - parameters = paddle.parameters.Parameters.from_tar( - gzip.open(args.model_filepath)) - - # prepare infer data batch_reader = data_generator.batch_reader_creator( - manifest_path=args.decode_manifest_path, + manifest_path=args.tune_manifest_path, batch_size=args.num_samples, sortagrad=False, shuffle_method=None) - # get one batch data for tuning - infer_data = batch_reader().next() - - # run inference - infer_results = paddle.infer( - output_layer=output_probs, parameters=parameters, input=infer_data) - num_steps = len(infer_results) // len(infer_data) - probs_split = [ - infer_results[i * num_steps:(i + 1) * num_steps] - for i in xrange(0, len(infer_data)) + tune_data = batch_reader().next() + target_transcripts = [ + ''.join([data_generator.vocab_list[token] for token in transcript]) + for _, transcript in tune_data ] + ds2_model = DeepSpeech2Model( + vocab_size=data_generator.vocab_size, + num_conv_layers=args.num_conv_layers, + num_rnn_layers=args.num_rnn_layers, + rnn_layer_size=args.rnn_layer_size, + pretrained_model_path=args.model_filepath) + # create grid for search cand_alphas = np.linspace(args.alpha_from, args.alpha_to, args.num_alphas) cand_betas = np.linspace(args.beta_from, args.beta_to, args.num_betas) params_grid = [(alpha, beta) for alpha in cand_alphas for beta in cand_betas] - ext_scorer = LmScorer(args.alpha_from, args.beta_from, - args.language_model_path) ## tune parameters in loop for alpha, beta in params_grid: - wer_sum, wer_counter = 0, 0 - # reset scorer - ext_scorer.reset_params(alpha, beta) - # beam search using multiple processes - beam_search_results = ctc_beam_search_decoder_batch( - probs_split=probs_split, - vocabulary=data_generator.vocab_list, + result_transcripts = ds2_model.infer_batch( + infer_data=tune_data, + decode_method='beam_search', + beam_alpha=alpha, + beam_beta=beta, beam_size=args.beam_size, cutoff_prob=args.cutoff_prob, - blank_id=len(data_generator.vocab_list), - num_processes=args.num_processes_beam_search, - ext_scoring_func=ext_scorer, ) - for i, beam_search_result in enumerate(beam_search_results): - target_transcription = ''.join([ - data_generator.vocab_list[index] for index in infer_data[i][1] - ]) - wer_sum += wer(target_transcription, beam_search_result[0][1]) - wer_counter += 1 - + vocab_list=data_generator.vocab_list, + language_model_path=args.language_model_path, + num_processes=args.num_processes_beam_search) + wer_sum, num_ins = 0.0, 0 + for target, result in zip(target_transcripts, result_transcripts): + wer_sum += wer(target, result) + num_ins += 1 print("alpha = %f\tbeta = %f\tWER = %f" % - (alpha, beta, wer_sum / wer_counter)) + (alpha, beta, wer_sum / num_ins)) def main(): - paddle.init(use_gpu=args.use_gpu, trainer_count=1) + utils.print_arguments(args) + paddle.init(use_gpu=args.use_gpu, trainer_count=args.trainer_count) tune()