diff --git a/README.md b/README.md
index 96fbb7d09aa310003d83a036d301deac54f3004d..8b0da1ae2bb83fe12669654afac9d65248ae0c0a 100644
--- a/README.md
+++ b/README.md
@@ -1,170 +1,486 @@
-# Deep Speech 2 on PaddlePaddle
+# DeepSpeech2 on PaddlePaddle
+
+*DeepSpeech2 on PaddlePaddle* is an open-source implementation of end-to-end Automatic Speech Recognition (ASR) engine, based on [Baidu's Deep Speech 2 paper](http://proceedings.mlr.press/v48/amodei16.pdf), with [PaddlePaddle](https://github.com/PaddlePaddle/Paddle) platform. Our vision is to empower both industrial application and academic research on speech recognition, via an easy-to-use, efficient and scalable implementation, including training, inference & testing module, distributed [PaddleCloud](https://github.com/PaddlePaddle/cloud) training, and demo deployment. Besides, several pre-trained models for both English and Mandarin are also released.
+
+## Table of Contents
+- [Prerequisites](#prerequisites)
+- [Installation](#installation)
+- [Getting Started](#getting-started)
+- [Data Preparation](#data-preparation)
+- [Training a Model](#training-a-model)
+- [Data Augmentation Pipeline](#data-augmentation-pipeline)
+- [Inference and Evaluation](#inference-and-evaluation)
+- [Distributed Cloud Training](#distributed-cloud-training)
+- [Hyper-parameters Tuning](#hyper-parameters-tuning)
+- [Training for Mandarin Language](#training-for-mandarin-language)
+- [Trying Live Demo with Your Own Voice](#trying-live-demo-with-your-own-voice)
+- [Released Models](#released-models)
+- [Experiments and Benchmarks](#experiments-and-benchmarks)
+- [Questions and Help](#questions-and-help)
+
+## Prerequisites
+- Python 2.7 only supported
+- PaddlePaddle the latest version (please refer to the [Installation Guide](https://github.com/PaddlePaddle/Paddle#installation))
## Installation
-### Prerequisites
+Please make sure the above [prerequisites](#prerequisites) have been satisfied before moving on.
- - **Python = 2.7** only supported;
- - **cuDNN >= 6.0** is required to utilize NVIDIA GPU platform in the installation of PaddlePaddle, and the **CUDA toolkit** with proper version suitable for cuDNN. The cuDNN library below 6.0 is found to yield a fatal error in batch normalization when handling utterances with long duration in inference.
-
-### Setup
-
-```
+```bash
+git clone https://github.com/PaddlePaddle/models.git
+cd models/deep_speech_2
sh setup.sh
-export LD_LIBRARY_PATH=$PADDLE_INSTALL_DIR/Paddle/third_party/install/warpctc/lib:$LD_LIBRARY_PATH
```
-Please replace `$PADDLE_INSTALL_DIR` with your own paddle installation directory.
+## Getting Started
-## Usage
+Several shell scripts provided in `./examples` will help us to quickly give it a try, for most major modules, including data preparation, model training, case inference and model evaluation, with a few public dataset (e.g. [LibriSpeech](http://www.openslr.org/12/), [Aishell](http://www.openslr.org/33)). Reading these examples will also help you to understand how to make it work with your own data.
-### Preparing Data
+Some of the scripts in `./examples` are configured with 8 GPUs. If you don't have 8 GPUs available, please modify `CUDA_VISIBLE_DEVICES` and `--trainer_count`. If you don't have any GPU available, please set `--use_gpu` to False to use CPUs instead. Besides, if out-of-memory problem occurs, just reduce `--batch_size` to fit.
-```
-cd datasets
-sh run_all.sh
-cd ..
-```
+Let's take a tiny sampled subset of [LibriSpeech dataset](http://www.openslr.org/12/) for instance.
-`sh run_all.sh` prepares all ASR datasets (currently, only LibriSpeech available). After running, we have several summarization manifest files in json-format.
+- Go to directory
-A manifest file summarizes a speech data set, with each line containing the meta data (i.e. audio filepath, transcript text, audio duration) of each audio file within the data set, in json format. Manifest file serves as an interface informing our system of where and what to read the speech samples.
+ ```bash
+ cd examples/tiny
+ ```
+ Notice that this is only a toy example with a tiny sampled subset of LibriSpeech. If you would like to try with the complete dataset (would take several days for training), please go to `examples/librispeech` instead.
+- Prepare the data
-More help for arguments:
+ ```bash
+ sh run_data.sh
+ ```
-```
-python datasets/librispeech/librispeech.py --help
-```
+ `run_data.sh` will download dataset, generate manifests, collect normalizer's statistics and build vocabulary. Once the data preparation is done, you will find the data (only part of LibriSpeech) downloaded in `~/.cache/paddle/dataset/speech/libri` and the corresponding manifest files generated in `./data/tiny` as well as a mean stddev file and a vocabulary file. It has to be run for the very first time you run this dataset and is reusable for all further experiments.
+- Train your own ASR model
-### Preparing for Training
+ ```bash
+ sh run_train.sh
+ ```
-```
-python compute_mean_std.py
-```
+ `run_train.sh` will start a training job, with training logs printed to stdout and model checkpoint of every pass/epoch saved to `./checkpoints/tiny`. These checkpoints could be used for training resuming, inference, evaluation and deployment.
+- Case inference with an existing model
-It will compute mean and stdandard deviation for audio features, and save them to a file with a default name `./mean_std.npz`. This file will be used in both training and inferencing. The default feature of audio data is power spectrum, and the mfcc feature is also supported. To train and infer based on mfcc feature, please generate this file by
+ ```bash
+ sh run_infer.sh
+ ```
-```
-python compute_mean_std.py --specgram_type mfcc
-```
+ `run_infer.sh` will show us some speech-to-text decoding results for several (default: 10) samples with the trained model. The performance might not be good now as the current model is only trained with a toy subset of LibriSpeech. To see the results with a better model, you can download a well-trained (trained for several days, with the complete LibriSpeech) model and do the inference:
-and specify ```--specgram_type mfcc``` when running train.py, infer.py, evaluator.py or tune.py.
+ ```bash
+ sh run_infer_golden.sh
+ ```
+- Evaluate an existing model
-More help for arguments:
+ ```bash
+ sh run_test.sh
+ ```
+
+ `run_test.sh` will evaluate the model with Word Error Rate (or Character Error Rate) measurement. Similarly, you can also download a well-trained model and test its performance:
+
+ ```bash
+ sh run_test_golden.sh
+ ```
+
+More detailed information are provided in the following sections. Wish you a happy journey with the *DeepSpeech2 on PaddlePaddle* ASR engine!
-```
-python compute_mean_std.py --help
-```
-### Training
+## Data Preparation
-For GPU Training:
+### Generate Manifest
+
+*DeepSpeech2 on PaddlePaddle* accepts a textual **manifest** file as its data set interface. A manifest file summarizes a set of speech data, with each line containing some meta data (e.g. filepath, transcription, duration) of one audio clip, in [JSON](http://www.json.org/) format, such as:
```
-CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python train.py
+{"audio_filepath": "/home/work/.cache/paddle/Libri/134686/1089-134686-0001.flac", "duration": 3.275, "text": "stuff it into you his belly counselled him"}
+{"audio_filepath": "/home/work/.cache/paddle/Libri/134686/1089-134686-0007.flac", "duration": 4.275, "text": "a cold lucid indifference reigned in his soul"}
```
-For CPU Training:
+To use your custom data, you only need to generate such manifest files to summarize the dataset. Given such summarized manifests, training, inference and all other modules can be aware of where to access the audio files, as well as their meta data including the transcription labels.
+
+For how to generate such manifest files, please refer to `data/librispeech/librispeech.py`, which will download data and generate manifest files for LibriSpeech dataset.
+
+### Compute Mean & Stddev for Normalizer
+
+To perform z-score normalization (zero-mean, unit stddev) upon audio features, we have to estimate in advance the mean and standard deviation of the features, with some training samples:
+```bash
+python tools/compute_mean_std.py \
+--num_samples 2000 \
+--specgram_type linear \
+--manifest_paths data/librispeech/manifest.train \
+--output_path data/librispeech/mean_std.npz
```
-python train.py --use_gpu False
+
+It will compute the mean and standard deviation of power spectrum feature with 2000 random sampled audio clips listed in `data/librispeech/manifest.train` and save the results to `data/librispeech/mean_std.npz` for further usage.
+
+
+### Build Vocabulary
+
+A vocabulary of possible characters is required to convert the transcription into a list of token indices for training, and in decoding, to convert from a list of indices back to text again. Such a character-based vocabulary can be built with `tools/build_vocab.py`.
+
+```bash
+python tools/build_vocab.py \
+--count_threshold 0 \
+--vocab_path data/librispeech/eng_vocab.txt \
+--manifest_paths data/librispeech/manifest.train
```
-More help for arguments:
+It will write a vocabuary file `data/librispeeech/eng_vocab.txt` with all transcription text in `data/librispeech/manifest.train`, without vocabulary truncation (`--count_threshold 0`).
+
+### More Help
+For more help on arguments:
+
+```bash
+python data/librispeech/librispeech.py --help
+python tools/compute_mean_std.py --help
+python tools/build_vocab.py --help
```
+
+## Training a model
+
+`train.py` is the main caller of the training module. Examples of usage are shown below.
+
+- Start training from scratch with 8 GPUs:
+
+ ```
+ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python train.py --trainer_count 8
+ ```
+
+- Start training from scratch with 16 CPUs:
+
+ ```
+ python train.py --use_gpu False --trainer_count 16
+ ```
+- Resume training from a checkpoint:
+
+ ```
+ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
+ python train.py \
+ --init_model_path CHECKPOINT_PATH_TO_RESUME_FROM
+ ```
+
+For more help on arguments:
+
+```bash
python train.py --help
```
+or refer to `example/librispeech/run_train.sh`.
-### Preparing language model
+## Data Augmentation Pipeline
-The following steps, inference, parameters tuning and evaluating, will require a language model during decoding.
-A compressed language model is provided and can be accessed by
+Data augmentation has often been a highly effective technique to boost the deep learning performance. We augment our speech data by synthesizing new audios with small random perturbation (label-invariant transformation) added upon raw audios. You don't have to do the syntheses on your own, as it is already embedded into the data provider and is done on the fly, randomly for each epoch during training.
-```
-cd ./lm
-sh run.sh
-cd ..
-```
+Six optional augmentation components are provided to be selected, configured and inserted into the processing pipeline.
-### Inference
+ - Volume Perturbation
+ - Speed Perturbation
+ - Shifting Perturbation
+ - Online Bayesian normalization
+ - Noise Perturbation (need background noise audio files)
+ - Impulse Response (need impulse audio files)
-For GPU inference
+In order to inform the trainer of what augmentation components are needed and what their processing orders are, it is required to prepare in advance a *augmentation configuration file* in [JSON](http://www.json.org/) format. For example:
```
-CUDA_VISIBLE_DEVICES=0 python infer.py
+[{
+ "type": "speed",
+ "params": {"min_speed_rate": 0.95,
+ "max_speed_rate": 1.05},
+ "prob": 0.6
+},
+{
+ "type": "shift",
+ "params": {"min_shift_ms": -5,
+ "max_shift_ms": 5},
+ "prob": 0.8
+}]
```
-For CPU inference
+When the `--augment_conf_file` argument of `trainer.py` is set to the path of the above example configuration file, every audio clip in every epoch will be processed: with 60% of chance, it will first be speed perturbed with a uniformly random sampled speed-rate between 0.95 and 1.05, and then with 80% of chance it will be shifted in time with a random sampled offset between -5 ms and 5 ms. Finally this newly synthesized audio clip will be feed into the feature extractor for further training.
+For other configuration examples, please refer to `conf/augmenatation.config.example`.
+
+Be careful when utilizing the data augmentation technique, as improper augmentation will do harm to the training, due to the enlarged train-test gap.
+
+## Inference and Evaluation
+
+### Prepare Language Model
+
+A language model is required to improve the decoder's performance. We have prepared two language models (with lossy compression) for users to download and try. One is for English and the other is for Mandarin. Users can simply run this to download the preprared language models:
+
+```bash
+cd models/lm
+sh download_lm_en.sh
+sh download_lm_ch.sh
```
-python infer.py --use_gpu=False
-```
+If you wish to train your own better language model, please refer to [KenLM](https://github.com/kpu/kenlm) for tutorials.
+
+TODO: any other requirements or tips to add?
+
+### Speech-to-text Inference
+
+An inference module caller `infer.py` is provided to infer, decode and visualize speech-to-text results for several given audio clips. It might help to have an intuitive and qualitative evaluation of the ASR model's performance.
-More help for arguments:
+- Inference with GPU:
+
+ ```bash
+ CUDA_VISIBLE_DEVICES=0 python infer.py --trainer_count 1
+ ```
+
+- Inference with CPUs:
+
+ ```bash
+ python infer.py --use_gpu False --trainer_count 12
+ ```
+
+We provide two types of CTC decoders: *CTC greedy decoder* and *CTC beam search decoder*. The *CTC greedy decoder* is an implementation of the simple best-path decoding algorithm, selecting at each timestep the most likely token, thus being greedy and locally optimal. The [*CTC beam search decoder*](https://arxiv.org/abs/1408.2873) otherwise utilizes a heuristic breadth-first graph search for reaching a near global optimality; it also requires a pre-trained KenLM language model for better scoring and ranking. The decoder type can be set with argument `--decoding_method`.
+
+For more help on arguments:
```
python infer.py --help
```
+or refer to `example/librispeech/run_infer.sh`.
-### Evaluating
+### Evaluate a Model
-```
-CUDA_VISIBLE_DEVICES=0 python evaluate.py
-```
+To evaluate a model's performance quantitatively, please run:
-More help for arguments:
+- Evaluation with GPUs:
-```
-python evaluate.py --help
-```
+ ```bash
+ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python test.py --trainer_count 8
+ ```
-### Parameters tuning
+- Evaluation with CPUs:
-Usually, the parameters $\alpha$ and $\beta$ for the CTC [prefix beam search](https://arxiv.org/abs/1408.2873) decoder need to be tuned after retraining the acoustic model.
+ ```bash
+ python test.py --use_gpu False --trainer_count 12
+ ```
-For GPU tuning
+The error rate (default: word error rate; can be set with `--error_rate_type`) will be printed.
+For more help on arguments:
+
+```bash
+python test.py --help
```
-CUDA_VISIBLE_DEVICES=0 python tune.py
-```
+or refer to `example/librispeech/run_test.sh`.
-For CPU tuning
+## Hyper-parameters Tuning
-```
-python tune.py --use_gpu=False
-```
+The hyper-parameters $\alpha$ (language model weight) and $\beta$ (word insertion weight) for the [*CTC beam search decoder*](https://arxiv.org/abs/1408.2873) often have a significant impact on the decoder's performance. It would be better to re-tune them on the validation set when the acoustic model is renewed.
-More help for arguments:
+`tools/tune.py` performs a 2-D grid search over the hyper-parameter $\alpha$ and $\beta$. You must provide the range of $\alpha$ and $\beta$, as well as the number of their attempts.
-```
+- Tuning with GPU:
+
+ ```bash
+ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
+ python tools/tune.py \
+ --trainer_count 8 \
+ --alpha_from 1.0 \
+ --alpha_to 3.2 \
+ --num_alphas 45 \
+ --beta_from 0.1 \
+ --beta_to 0.45 \
+ --num_betas 8
+ ```
+
+- Tuning with CPU:
+
+ ```bash
+ python tools/tune.py --use_gpu False
+ ```
+ The grid search will print the WER (word error rate) or CER (character error rate) at each point in the hyper-parameters space, and draw the error surface optionally. A proper hyper-parameters range should include the global minima of the error surface for WER/CER, as illustrated in the following figure.
+
+
+
+
An example error surface for tuning on the dev-clean set of LibriSpeech
+
+
+Usually, as the figure shows, the variation of language model weight ($\alpha$) significantly affect the performance of CTC beam search decoder. And a better procedure is to first tune on serveral data batches (the number can be specified) to find out the proper range of hyper-parameters, then change to the whole validation set to carray out an accurate tuning.
+
+After tuning, you can reset $\alpha$ and $\beta$ in the inference and evaluation modules to see if they really help improve the ASR performance. For more help
+
+```bash
python tune.py --help
```
+or refer to `example/librispeech/run_tune.sh`.
-Then reset parameters with the tuning result before inference or evaluating.
-### Playing with the ASR Demo
-A real-time ASR demo is built for users to try out the ASR model with their own voice. Please do the following installation on the machine you'd like to run the demo's client (no need for the machine running the demo's server).
+## Distributed Cloud Training
-For example, on MAC OS X:
+We also provide a cloud training module for users to do the distributed cluster training on [PaddleCloud](https://github.com/PaddlePaddle/cloud), to achieve a much faster training speed with multiple machines. To start with this, please first install PaddleCloud client and register a PaddleCloud account, as described in [PaddleCloud Usage](https://github.com/PaddlePaddle/cloud/blob/develop/doc/usage_cn.md#%E4%B8%8B%E8%BD%BD%E5%B9%B6%E9%85%8D%E7%BD%AEpaddlecloud).
+
+Please take the following steps to submit a training job:
+
+- Go to directory:
+
+ ```bash
+ cd cloud
+ ```
+- Upload data:
+
+ Data must be uploaded to PaddleCloud filesystem to be accessed within a cloud job. `pcloud_upload_data.sh` helps do the data packing and uploading:
+
+ ```bash
+ sh pcloud_upload_data.sh
+ ```
+
+ Given input manifests, `pcloud_upload_data.sh` will:
+
+ - Extract the audio files listed in the input manifests.
+ - Pack them into a specified number of tar files.
+ - Upload these tar files to PaddleCloud filesystem.
+ - Create cloud manifests by replacing local filesystem paths with PaddleCloud filesystem paths. New manifests will be used to inform the cloud jobs of audio files' location and their meta information.
+
+ It should be done only once for the very first time to do the cloud training. Later, the data is kept persisitent on the cloud filesystem and reusable for further job submissions.
+
+ For argument details please refer to [Train DeepSpeech2 on PaddleCloud](https://github.com/PaddlePaddle/models/tree/develop/deep_speech_2/cloud).
+
+ - Configure training arguments:
+
+ Configure the cloud job parameters in `pcloud_submit.sh` (e.g. `NUM_NODES`, `NUM_GPUS`, `CLOUD_TRAIN_DIR`, `JOB_NAME` etc.) and then configure other hyper-parameters for training in `pcloud_train.sh` (just as what you do for local training).
+
+ For argument details please refer to [Train DeepSpeech2 on PaddleCloud](https://github.com/PaddlePaddle/models/tree/develop/deep_speech_2/cloud).
+
+ - Submit the job:
+
+ By running:
+
+ ```bash
+ sh pcloud_submit.sh
+ ```
+ a training job has been submitted to PaddleCloud, with the job name printed to the console.
+
+ - Get training logs
+
+ Run this to list all the jobs you have submitted, as well as their running status:
+
+ ```bash
+ paddlecloud get jobs
+ ```
+
+ Run this, the corresponding job's logs will be printed.
+ ```bash
+ paddlecloud logs -n 10000 $REPLACED_WITH_YOUR_ACTUAL_JOB_NAME
+ ```
+
+For more information about the usage of PaddleCloud, please refer to [PaddleCloud Usage](https://github.com/PaddlePaddle/cloud/blob/develop/doc/usage_cn.md#提交任务).
+
+For more information about the DeepSpeech2 training on PaddleCloud, please refer to
+[Train DeepSpeech2 on PaddleCloud](https://github.com/PaddlePaddle/models/tree/develop/deep_speech_2/cloud).
+
+## Training for Mandarin Language
+
+TODO: to be added
+
+## Trying Live Demo with Your Own Voice
+Until now, an ASR model is trained and tested qualitatively (`infer.py`) and quantitatively (`test.py`) with existing audio files. But it is not yet tested with your own speech. `deploy/demo_server.py` and `deploy/demo_client.py` helps quickly build up a real-time demo ASR engine with the trained model, enabling you to test and play around with the demo, with your own voice.
+
+To start the demo's server, please run this in one console:
+
+```bash
+CUDA_VISIBLE_DEVICES=0 \
+python deploy/demo_server.py \
+--trainer_count 1 \
+--host_ip localhost \
+--host_port 8086
```
+
+For the machine (might not be the same machine) to run the demo's client, please do the following installation before moving on.
+
+For example, on MAC OS X:
+
+```bash
brew install portaudio
pip install pyaudio
pip install pynput
```
-After a model and language model is prepared, we can first start the demo's server:
-```
-CUDA_VISIBLE_DEVICES=0 python demo_server.py
-```
-And then in another console, start the demo's client:
+Then to start the client, please run this in another console:
+```bash
+CUDA_VISIBLE_DEVICES=0 \
+python -u deploy/demo_client.py \
+--host_ip 'localhost' \
+--host_port 8086
```
-python demo_client.py
+
+Now, in the client console, press the `whitespace` key, hold, and start speaking. Until finishing your utterance, release the key to let the speech-to-text results shown in the console. To quit the client, just press `ESC` key.
+
+Notice that `deploy/demo_client.py` must be run on a machine with a microphone device, while `deploy/demo_server.py` could be run on one without any audio recording hardware, e.g. any remote server machine. Just be careful to set the `host_ip` and `host_port` argument with the actual accessible IP address and port, if the server and client are running with two separate machines. Nothing should be done if they are running on one single machine.
+
+Please also refer to `examples/mandarin/run_demo_server.sh`, which will first download a pre-trained Mandarin model (trained with 3000 hours of internal speech data) and then start the demo server with the model. With running `examples/mandarin/run_demo_client.sh`, you can speak Mandarin to test it. If you would like to try some other models, just update `--model_path` argument in the script.
+
+For more help on arguments:
+
+```bash
+python deploy/demo_server.py --help
+python deploy/demo_client.py --help
```
-On the client console, press and hold the "white-space" key on the keyboard to start talking, until you finish your speech and then release the "white-space" key. The decoding results (infered transcription) will be displayed.
-It could be possible to start the server and the client in two seperate machines, e.g. `demo_client.py` is usually started in a machine with a microphone hardware, while `demo_server.py` is usually started in a remote server with powerful GPUs. Please first make sure that these two machines have network access to each other, and then use `--host_ip` and `--host_port` to indicate the server machine's actual IP address (instead of the `localhost` as default) and TCP port, in both `demo_server.py` and `demo_client.py`.
+## Released Models
+
+#### Speech Model Released
+
+Language | Model Name | Training Data | Training Hours
+:-----------: | :------------: | :----------: | -------:
+English | [LibriSpeech Model](http://cloud.dlnel.org/filepub/?uuid=17404caf-cf19-492f-9707-1fad07c19aae) | [LibriSpeech Dataset](http://www.openslr.org/12/) | 960 h
+English | [Internal English Model](to-be-added) | Baidu English Dataset | 8628 h
+Mandarin | [Aishell Model](http://cloud.dlnel.org/filepub/?uuid=6c83b9d8-3255-4adf-9726-0fe0be3d0274) | [Aishell Dataset](http://www.openslr.org/33/) | 151 h
+Mandarin | [Internal Mandarin Model](to-be-added) | Baidu Mandarin Dataset | 2917 h
+
+#### Language Model Released
+
+Language Model | Training Data | Token-based | Size | Filter Configuraiton
+:-------------:| :------------:| :-----: | -----: | -----------------:
+[English LM](http://paddlepaddle.bj.bcebos.com/model_zoo/speech/common_crawl_00.prune01111.trie.klm) | To Be Added | Word-based | 8.3 GB | To Be Added
+[Mandarin LM](http://cloud.dlnel.org/filepub/?uuid=d21861e4-4ed6-45bb-ad8e-ae417a43195e) | To Be Added | Character-based | 2.8 GB | To Be Added
+
+## Experiments and Benchmarks
+
+#### English Model Evaluation (Word Error Rate)
+
+Test Set | LibriSpeech Model | Internal English Model
+:---------------------: | ---------------: | -------------------:
+LibriSpeech-Test-Clean | 7.96 | X.X
+LibriSpeech-Test-Other | 23.87 | X.X
+VoxForge-Test | X.X | X.X
+Baidu-English-Test | X.X | X.X
+
+(Beam size=2000)
+
+#### Mandarin Model Evaluation (Character Error Rate)
+
+Test Set | Aishell Model | Internal Mandarin Model
+:---------------------: | :---------------: | :-------------------:
+Aishell-Test | X.X | X.X
+Baidu-Mandarin-Test | X.X | X.X
+
+#### Acceleration with Multi-GPUs
+
+We compare the training time with 1, 2, 4, 8, 16 Tesla K40m GPUs (with a subset of LibriSpeech samples whose audio durations are between 6.0 and 7.0 seconds). And it shows that a **near-linear** acceleration with multiple GPUs has been achieved. In the following figure, the time (in seconds) cost for training is printed on the blue bars.
+
+
+
+| # of GPU | Acceleration Rate |
+| -------- | --------------: |
+| 1 | 1.00 X |
+| 2 | 1.97 X |
+| 4 | 3.74 X |
+| 8 | 6.21 X |
+|16 | 10.70 X |
+
+`tools/profile.sh` provides such a profiling tool.
+
+## Questions and Help
+
+You are welcome to submit questions and bug reports in [Github Issues](https://github.com/PaddlePaddle/models/issues). You are also welcome to contribute to this project.
diff --git a/cloud/README.md b/cloud/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..a5be1c420880d4f32d472cdd23124cbf35033094
--- /dev/null
+++ b/cloud/README.md
@@ -0,0 +1,63 @@
+# Train DeepSpeech2 on PaddleCloud
+
+>Note:
+>Please make sure [PaddleCloud Client](https://github.com/PaddlePaddle/cloud/blob/develop/doc/usage_cn.md#%E4%B8%8B%E8%BD%BD%E5%B9%B6%E9%85%8D%E7%BD%AEpaddlecloud) has be installed and current directory is `deep_speech_2/cloud/`
+
+## Step 1: Upload Data
+
+Provided with several input manifests, `pcloud_upload_data.sh` will pack and upload all the containing audio files to PaddleCloud filesystem, and also generate some corresponding manifest files with updated cloud paths.
+
+Please modify the following arguments in `pcloud_upload_data.sh`:
+
+- `IN_MANIFESTS`: Paths (in local filesystem) of manifest files containing the audio files to be uploaded. Multiple paths can be concatenated with a whitespace delimeter.
+- `OUT_MANIFESTS`: Paths (in local filesystem) to write the updated output manifest files to. Multiple paths can be concatenated with a whitespace delimeter. The values of `audio_filepath` in the output manifests are updated with cloud filesystem paths.
+- `CLOUD_DATA_DIR`: Directory (in PaddleCloud filesystem) to upload the data to. Don't forget to replace `USERNAME` in the default directory and make sure that you have the permission to write it.
+- `NUM_SHARDS`: Number of data shards / parts (in tar files) to be generated when packing and uploading data. Smaller `num_shards` requires larger temoporal local disk space for packing data.
+
+By running:
+
+```
+sh pcloud_upload_data.sh
+```
+all the audio files will be uploaded to PaddleCloud filesystem, and you will get modified manifests files in `OUT_MANIFESTS`.
+
+You have to take this step only once, in the very first time you do the cloud training. Later on, the data is persisitent on the cloud filesystem and reusable for further job submissions.
+
+## Step 2: Configure Training
+
+Configure cloud training arguments in `pcloud_submit.sh`, with the following arguments:
+
+- `TRAIN_MANIFEST`: Manifest filepath (in local filesystem) for training. Notice that the`audio_filepath` should be in cloud filesystem, like those generated by `pcloud_upload_data.sh`.
+- `DEV_MANIFEST`: Manifest filepath (in local filesystem) for validation.
+- `CLOUD_MODEL_DIR`: Directory (in PaddleCloud filesystem) to save the model parameters (checkpoints). Don't forget to replace `USERNAME` in the default directory and make sure that you have the permission to write it.
+- `BATCH_SIZE`: Training batch size for a single node.
+- `NUM_GPU`: Number of GPUs allocated for a single node.
+- `NUM_NODE`: Number of nodes (machines) allocated for this job.
+- `IS_LOCAL`: Set to False to enable parameter server, if using multiple nodes.
+
+Configure other training hyper-parameters in `pcloud_train.sh` as you wish, just as what you can do in local training.
+
+By running:
+
+```
+sh pcloud_submit.sh
+```
+you submit a training job to PaddleCloud. And you will see the job name when the submission is done.
+
+
+## Step 3 Get Job Logs
+
+Run this to list all the jobs you have submitted, as well as their running status:
+
+```
+paddlecloud get jobs
+```
+
+Run this, the corresponding job's logs will be printed.
+```
+paddlecloud logs -n 10000 $REPLACED_WITH_YOUR_ACTUAL_JOB_NAME
+```
+
+## More Help
+
+For more information about the usage of PaddleCloud, please refer to [PaddleCloud Usage](https://github.com/PaddlePaddle/cloud/blob/develop/doc/usage_cn.md#提交任务).
diff --git a/cloud/_init_paths.py b/cloud/_init_paths.py
new file mode 100644
index 0000000000000000000000000000000000000000..3305d7488ff1cfb03db7175a53f70c1a107fe52e
--- /dev/null
+++ b/cloud/_init_paths.py
@@ -0,0 +1,17 @@
+"""Set up paths for DS2"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os.path
+import sys
+
+
+def add_path(path):
+ if path not in sys.path:
+ sys.path.insert(0, path)
+
+
+this_dir = os.path.dirname(__file__)
+proj_path = os.path.join(this_dir, '..')
+add_path(proj_path)
diff --git a/cloud/pcloud_submit.sh b/cloud/pcloud_submit.sh
new file mode 100644
index 0000000000000000000000000000000000000000..99e458db96b819019628a26f05b3597ea951aeea
--- /dev/null
+++ b/cloud/pcloud_submit.sh
@@ -0,0 +1,29 @@
+#! /usr/bin/env bash
+
+TRAIN_MANIFEST="cloud/cloud_manifests/cloud.manifest.train"
+DEV_MANIFEST="cloud/cloud_manifests/cloud.manifest.dev"
+CLOUD_MODEL_DIR="./checkpoints"
+BATCH_SIZE=512
+NUM_GPU=8
+NUM_NODE=1
+IS_LOCAL="True"
+
+JOB_NAME=deepspeech-`date +%Y%m%d%H%M%S`
+DS2_PATH=${PWD%/*}
+cp -f pcloud_train.sh ${DS2_PATH}
+
+paddlecloud submit \
+-image bootstrapper:5000/paddlepaddle/pcloud_ds2:latest \
+-jobname ${JOB_NAME} \
+-cpu ${NUM_GPU} \
+-gpu ${NUM_GPU} \
+-memory 64Gi \
+-parallelism ${NUM_NODE} \
+-pscpu 1 \
+-pservers 1 \
+-psmemory 64Gi \
+-passes 1 \
+-entry "sh pcloud_train.sh ${TRAIN_MANIFEST} ${DEV_MANIFEST} ${CLOUD_MODEL_DIR} ${NUM_GPU} ${BATCH_SIZE} ${IS_LOCAL}" \
+${DS2_PATH}
+
+rm ${DS2_PATH}/pcloud_train.sh
diff --git a/cloud/pcloud_train.sh b/cloud/pcloud_train.sh
new file mode 100644
index 0000000000000000000000000000000000000000..d0c47dece91c43d0cbfde1f6eb2dcc96fce36391
--- /dev/null
+++ b/cloud/pcloud_train.sh
@@ -0,0 +1,46 @@
+#! /usr/bin/env bash
+
+TRAIN_MANIFEST=$1
+DEV_MANIFEST=$2
+MODEL_PATH=$3
+NUM_GPU=$4
+BATCH_SIZE=$5
+IS_LOCAL=$6
+
+python ./cloud/split_data.py \
+--in_manifest_path=${TRAIN_MANIFEST} \
+--out_manifest_path='/local.manifest.train'
+
+python ./cloud/split_data.py \
+--in_manifest_path=${DEV_MANIFEST} \
+--out_manifest_path='/local.manifest.dev'
+
+mkdir ./logs
+
+python -u train.py \
+--batch_size=${BATCH_SIZE} \
+--trainer_count=${NUM_GPU} \
+--num_passes=200 \
+--num_proc_data=${NUM_GPU} \
+--num_conv_layers=2 \
+--num_rnn_layers=3 \
+--rnn_layer_size=2048 \
+--num_iter_print=100 \
+--learning_rate=5e-4 \
+--max_duration=27.0 \
+--min_duration=0.0 \
+--use_sortagrad=True \
+--use_gru=False \
+--use_gpu=True \
+--is_local=${IS_LOCAL} \
+--share_rnn_weights=True \
+--train_manifest='/local.manifest.train' \
+--dev_manifest='/local.manifest.dev' \
+--mean_std_path='data/librispeech/mean_std.npz' \
+--vocab_path='data/librispeech/vocab.txt' \
+--output_model_dir='./checkpoints' \
+--output_model_dir=${MODEL_PATH} \
+--augment_conf_path='conf/augmentation.config' \
+--specgram_type='linear' \
+--shuffle_method='batch_shuffle_clipped' \
+2>&1 | tee ./logs/train.log
diff --git a/cloud/pcloud_upload_data.sh b/cloud/pcloud_upload_data.sh
new file mode 100644
index 0000000000000000000000000000000000000000..71bb4af19b3b30f6efc31cb9b60f4f3b330b46b9
--- /dev/null
+++ b/cloud/pcloud_upload_data.sh
@@ -0,0 +1,22 @@
+#! /usr/bin/env bash
+
+mkdir cloud_manifests
+
+IN_MANIFESTS="../data/librispeech/manifest.train ../data/librispeech/manifest.dev-clean ../data/librispeech/manifest.test-clean"
+OUT_MANIFESTS="cloud_manifests/cloud.manifest.train cloud_manifests/cloud.manifest.dev cloud_manifests/cloud.manifest.test"
+CLOUD_DATA_DIR="/pfs/dlnel/home/USERNAME/deepspeech2/data/librispeech"
+NUM_SHARDS=50
+
+python upload_data.py \
+--in_manifest_paths ${IN_MANIFESTS} \
+--out_manifest_paths ${OUT_MANIFESTS} \
+--cloud_data_dir ${CLOUD_DATA_DIR} \
+--num_shards ${NUM_SHARDS}
+
+if [ $? -ne 0 ]
+then
+ echo "Upload Data Failed!"
+ exit 1
+fi
+
+echo "All Done."
diff --git a/cloud/split_data.py b/cloud/split_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..3496d52bfb5bf6c249c03dfb4df2937625bd55b5
--- /dev/null
+++ b/cloud/split_data.py
@@ -0,0 +1,41 @@
+"""This tool is used for splitting data into each node of
+paddlecloud. This script should be called in paddlecloud.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import json
+import argparse
+
+parser = argparse.ArgumentParser(description=__doc__)
+parser.add_argument(
+ "--in_manifest_path",
+ type=str,
+ required=True,
+ help="Input manifest path for all nodes.")
+parser.add_argument(
+ "--out_manifest_path",
+ type=str,
+ required=True,
+ help="Output manifest file path for current node.")
+args = parser.parse_args()
+
+
+def split_data(in_manifest_path, out_manifest_path):
+ with open("/trainer_id", "r") as f:
+ trainer_id = int(f.readline()[:-1])
+ with open("/trainer_count", "r") as f:
+ trainer_count = int(f.readline()[:-1])
+
+ out_manifest = []
+ for index, json_line in enumerate(open(in_manifest_path, 'r')):
+ if (index % trainer_count) == trainer_id:
+ out_manifest.append("%s\n" % json_line.strip())
+ with open(out_manifest_path, 'w') as f:
+ f.writelines(out_manifest)
+
+
+if __name__ == '__main__':
+ split_data(args.in_manifest_path, args.out_manifest_path)
diff --git a/cloud/upload_data.py b/cloud/upload_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..9973f8c768410fd86a6ded6a74dac24f9f918173
--- /dev/null
+++ b/cloud/upload_data.py
@@ -0,0 +1,129 @@
+"""This script is for uploading data for DeepSpeech2 training on paddlecloud.
+
+Steps:
+1. Read original manifests and extract local sound files.
+2. Tar all local sound files into multiple tar files and upload them.
+3. Modify original manifests with updated paths in cloud filesystem.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import json
+import os
+import tarfile
+import sys
+import argparse
+import shutil
+from subprocess import call
+import _init_paths
+from data_utils.utils import read_manifest
+
+parser = argparse.ArgumentParser(description=__doc__)
+parser.add_argument(
+ "--in_manifest_paths",
+ default=[
+ "../datasets/manifest.train", "../datasets/manifest.dev",
+ "../datasets/manifest.test"
+ ],
+ type=str,
+ nargs='+',
+ help="Local filepaths of input manifests to load, pack and upload."
+ "(default: %(default)s)")
+parser.add_argument(
+ "--out_manifest_paths",
+ default=[
+ "./cloud.manifest.train", "./cloud.manifest.dev",
+ "./cloud.manifest.test"
+ ],
+ type=str,
+ nargs='+',
+ help="Local filepaths of modified manifests to write to. "
+ "(default: %(default)s)")
+parser.add_argument(
+ "--cloud_data_dir",
+ required=True,
+ type=str,
+ help="Destination directory on paddlecloud to upload data to.")
+parser.add_argument(
+ "--num_shards",
+ default=10,
+ type=int,
+ help="Number of parts to split data to. (default: %(default)s)")
+parser.add_argument(
+ "--local_tmp_dir",
+ default="./tmp/",
+ type=str,
+ help="Local directory for storing temporary data. (default: %(default)s)")
+args = parser.parse_args()
+
+
+def upload_data(in_manifest_path_list, out_manifest_path_list, local_tmp_dir,
+ upload_tar_dir, num_shards):
+ """Extract and pack sound files listed in the manifest files into multple
+ tar files and upload them to padldecloud. Besides, generate new manifest
+ files with updated paths in paddlecloud.
+ """
+ # compute total audio number
+ total_line = 0
+ for manifest_path in in_manifest_path_list:
+ with open(manifest_path, 'r') as f:
+ total_line += len(f.readlines())
+ line_per_tar = (total_line // num_shards) + 1
+
+ # pack and upload shard by shard
+ line_count, tar_file = 0, None
+ for manifest_path, out_manifest_path in zip(in_manifest_path_list,
+ out_manifest_path_list):
+ manifest = read_manifest(manifest_path)
+ out_manifest = []
+ for json_data in manifest:
+ sound_filepath = json_data['audio_filepath']
+ sound_filename = os.path.basename(sound_filepath)
+ if line_count % line_per_tar == 0:
+ if tar_file != None:
+ tar_file.close()
+ pcloud_cp(tar_path, upload_tar_dir)
+ os.remove(tar_path)
+ tar_name = 'part-%s-of-%s.tar' % (
+ str(line_count // line_per_tar).zfill(5),
+ str(num_shards).zfill(5))
+ tar_path = os.path.join(local_tmp_dir, tar_name)
+ tar_file = tarfile.open(tar_path, 'w')
+ tar_file.add(sound_filepath, arcname=sound_filename)
+ line_count += 1
+ json_data['audio_filepath'] = "tar:%s#%s" % (
+ os.path.join(upload_tar_dir, tar_name), sound_filename)
+ out_manifest.append("%s\n" % json.dumps(json_data))
+ with open(out_manifest_path, 'w') as f:
+ f.writelines(out_manifest)
+ pcloud_cp(out_manifest_path, upload_tar_dir)
+ tar_file.close()
+ pcloud_cp(tar_path, upload_tar_dir)
+ os.remove(tar_path)
+
+
+def pcloud_mkdir(dir):
+ """Make directory in PaddleCloud filesystem.
+ """
+ if call(['paddlecloud', 'mkdir', dir]) != 0:
+ raise IOError("PaddleCloud mkdir failed: %s." % dir)
+
+
+def pcloud_cp(src, dst):
+ """Copy src from local filesytem to dst in PaddleCloud filesystem,
+ or downlowd src from PaddleCloud filesystem to dst in local filesystem.
+ """
+ if call(['paddlecloud', 'cp', src, dst]) != 0:
+ raise IOError("PaddleCloud cp failed: from [%s] to [%s]." % (src, dst))
+
+
+if __name__ == '__main__':
+ if not os.path.exists(args.local_tmp_dir):
+ os.makedirs(args.local_tmp_dir)
+ pcloud_mkdir(args.cloud_data_dir)
+
+ upload_data(args.in_manifest_paths, args.out_manifest_paths,
+ args.local_tmp_dir, args.cloud_data_dir, args.num_shards)
+
+ shutil.rmtree(args.local_tmp_dir)
diff --git a/compute_mean_std.py b/compute_mean_std.py
deleted file mode 100644
index 0cc84e73022ecb1333b805457cace39adcc68ce4..0000000000000000000000000000000000000000
--- a/compute_mean_std.py
+++ /dev/null
@@ -1,63 +0,0 @@
-"""Compute mean and std for feature normalizer, and save to file."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import argparse
-from data_utils.normalizer import FeatureNormalizer
-from data_utils.augmentor.augmentation import AugmentationPipeline
-from data_utils.featurizer.audio_featurizer import AudioFeaturizer
-
-parser = argparse.ArgumentParser(
- description='Computing mean and stddev for feature normalizer.')
-parser.add_argument(
- "--specgram_type",
- default='linear',
- type=str,
- help="Feature type of audio data: 'linear' (power spectrum)"
- " or 'mfcc'. (default: %(default)s)")
-parser.add_argument(
- "--manifest_path",
- default='datasets/manifest.train',
- type=str,
- help="Manifest path for computing normalizer's mean and stddev."
- "(default: %(default)s)")
-parser.add_argument(
- "--num_samples",
- default=2000,
- type=int,
- help="Number of samples for computing mean and stddev. "
- "(default: %(default)s)")
-parser.add_argument(
- "--augmentation_config",
- default='{}',
- type=str,
- help="Augmentation configuration in json-format. "
- "(default: %(default)s)")
-parser.add_argument(
- "--output_file",
- default='mean_std.npz',
- type=str,
- help="Filepath to write mean and std to (.npz)."
- "(default: %(default)s)")
-args = parser.parse_args()
-
-
-def main():
- augmentation_pipeline = AugmentationPipeline(args.augmentation_config)
- audio_featurizer = AudioFeaturizer(specgram_type=args.specgram_type)
-
- def augment_and_featurize(audio_segment):
- augmentation_pipeline.transform_audio(audio_segment)
- return audio_featurizer.featurize(audio_segment)
-
- normalizer = FeatureNormalizer(
- mean_std_filepath=None,
- manifest_path=args.manifest_path,
- featurize_func=augment_and_featurize,
- num_samples=args.num_samples)
- normalizer.write_to_file(args.output_file)
-
-
-if __name__ == '__main__':
- main()
diff --git a/data/aishell/aishell.py b/data/aishell/aishell.py
new file mode 100644
index 0000000000000000000000000000000000000000..17786b5d42d19fd1300c142b494d78f56e9f26dd
--- /dev/null
+++ b/data/aishell/aishell.py
@@ -0,0 +1,109 @@
+"""Prepare Aishell mandarin dataset
+
+Download, unpack and create manifest files.
+Manifest file is a json-format file with each line containing the
+meta data (i.e. audio filepath, transcript and audio duration)
+of each audio file in the data set.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import codecs
+import soundfile
+import json
+import argparse
+from data_utils.utility import download, unpack
+
+DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset/speech')
+
+URL_ROOT = 'http://www.openslr.org/resources/33'
+DATA_URL = URL_ROOT + '/data_aishell.tgz'
+MD5_DATA = '2f494334227864a8a8fec932999db9d8'
+
+parser = argparse.ArgumentParser(description=__doc__)
+parser.add_argument(
+ "--target_dir",
+ default=DATA_HOME + "/Aishell",
+ type=str,
+ help="Directory to save the dataset. (default: %(default)s)")
+parser.add_argument(
+ "--manifest_prefix",
+ default="manifest",
+ type=str,
+ help="Filepath prefix for output manifests. (default: %(default)s)")
+args = parser.parse_args()
+
+
+def create_manifest(data_dir, manifest_path_prefix):
+ print("Creating manifest %s ..." % manifest_path_prefix)
+ json_lines = []
+ transcript_path = os.path.join(data_dir, 'transcript',
+ 'aishell_transcript_v0.8.txt')
+ transcript_dict = {}
+ for line in codecs.open(transcript_path, 'r', 'utf-8'):
+ line = line.strip()
+ if line == '': continue
+ audio_id, text = line.split(' ', 1)
+ # remove withespace
+ text = ''.join(text.split())
+ transcript_dict[audio_id] = text
+
+ data_types = ['train', 'dev', 'test']
+ for type in data_types:
+ audio_dir = os.path.join(data_dir, 'wav', type)
+ for subfolder, _, filelist in sorted(os.walk(audio_dir)):
+ for fname in filelist:
+ audio_path = os.path.join(subfolder, fname)
+ audio_id = fname[:-4]
+ # if no transcription for audio then skipped
+ if audio_id not in transcript_dict:
+ continue
+ audio_data, samplerate = soundfile.read(audio_path)
+ duration = float(len(audio_data) / samplerate)
+ text = transcript_dict[audio_id]
+ json_lines.append(
+ json.dumps(
+ {
+ 'audio_filepath': audio_path,
+ 'duration': duration,
+ 'text': text
+ },
+ ensure_ascii=False))
+ manifest_path = manifest_path_prefix + '.' + type
+ with codecs.open(manifest_path, 'w', 'utf-8') as fout:
+ for line in json_lines:
+ fout.write(line + '\n')
+
+
+def prepare_dataset(url, md5sum, target_dir, manifest_path):
+ """Download, unpack and create manifest file."""
+ data_dir = os.path.join(target_dir, 'data_aishell')
+ if not os.path.exists(data_dir):
+ filepath = download(url, md5sum, target_dir)
+ unpack(filepath, target_dir)
+ # unpack all audio tar files
+ audio_dir = os.path.join(data_dir, 'wav')
+ for subfolder, _, filelist in sorted(os.walk(audio_dir)):
+ for ftar in filelist:
+ unpack(os.path.join(subfolder, ftar), subfolder, True)
+ else:
+ print("Skip downloading and unpacking. Data already exists in %s." %
+ target_dir)
+ create_manifest(data_dir, manifest_path)
+
+
+def main():
+ if args.target_dir.startswith('~'):
+ args.target_dir = os.path.expanduser(args.target_dir)
+
+ prepare_dataset(
+ url=DATA_URL,
+ md5sum=MD5_DATA,
+ target_dir=args.target_dir,
+ manifest_path=args.manifest_prefix)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/datasets/librispeech/librispeech.py b/data/librispeech/librispeech.py
similarity index 77%
rename from datasets/librispeech/librispeech.py
rename to data/librispeech/librispeech.py
index 7e941f0ea7f260680f60dc706fd9873532e3c8bb..9a8e1c2871f74823b04c5839dd43f08f9a03d1df 100644
--- a/datasets/librispeech/librispeech.py
+++ b/data/librispeech/librispeech.py
@@ -12,13 +12,11 @@ from __future__ import print_function
import distutils.util
import os
import sys
-import tarfile
import argparse
import soundfile
import json
-from paddle.v2.dataset.common import md5file
-
-DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset/speech')
+import codecs
+from data_utils.utility import download, unpack
URL_ROOT = "http://www.openslr.org/resources/12"
URL_TEST_CLEAN = URL_ROOT + "/test-clean.tar.gz"
@@ -40,7 +38,7 @@ MD5_TRAIN_OTHER_500 = "d1a0fd59409feb2c614ce4d30c387708"
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--target_dir",
- default=DATA_HOME + "/Libri",
+ default='~/.cache/paddle/dataset/speech/libri',
type=str,
help="Directory to save the dataset. (default: %(default)s)")
parser.add_argument(
@@ -58,36 +56,8 @@ parser.add_argument(
args = parser.parse_args()
-def download(url, md5sum, target_dir):
- """
- Download file from url to target_dir, and check md5sum.
- """
- if not os.path.exists(target_dir): os.makedirs(target_dir)
- filepath = os.path.join(target_dir, url.split("/")[-1])
- if not (os.path.exists(filepath) and md5file(filepath) == md5sum):
- print("Downloading %s ..." % url)
- os.system("wget -c " + url + " -P " + target_dir)
- print("\nMD5 Chesksum %s ..." % filepath)
- if not md5file(filepath) == md5sum:
- raise RuntimeError("MD5 checksum failed.")
- else:
- print("File exists, skip downloading. (%s)" % filepath)
- return filepath
-
-
-def unpack(filepath, target_dir):
- """
- Unpack the file to the target_dir.
- """
- print("Unpacking %s ..." % filepath)
- tar = tarfile.open(filepath)
- tar.extractall(target_dir)
- tar.close()
-
-
def create_manifest(data_dir, manifest_path):
- """
- Create a manifest json file summarizing the data set, with each line
+ """Create a manifest json file summarizing the data set, with each line
containing the meta data (i.e. audio filepath, transcription text, audio
duration) of each audio file within the data set.
"""
@@ -112,14 +82,13 @@ def create_manifest(data_dir, manifest_path):
'duration': duration,
'text': text
}))
- with open(manifest_path, 'w') as out_file:
+ with codecs.open(manifest_path, 'w', 'utf-8') as out_file:
for line in json_lines:
out_file.write(line + '\n')
def prepare_dataset(url, md5sum, target_dir, manifest_path):
- """
- Download, unpack and create summmary manifest file.
+ """Download, unpack and create summmary manifest file.
"""
if not os.path.exists(os.path.join(target_dir, "LibriSpeech")):
# download
@@ -134,6 +103,9 @@ def prepare_dataset(url, md5sum, target_dir, manifest_path):
def main():
+ if args.target_dir.startswith('~'):
+ args.target_dir = os.path.expanduser(args.target_dir)
+
prepare_dataset(
url=URL_TEST_CLEAN,
md5sum=MD5_TEST_CLEAN,
@@ -144,12 +116,12 @@ def main():
md5sum=MD5_DEV_CLEAN,
target_dir=os.path.join(args.target_dir, "dev-clean"),
manifest_path=args.manifest_prefix + ".dev-clean")
- prepare_dataset(
- url=URL_TRAIN_CLEAN_100,
- md5sum=MD5_TRAIN_CLEAN_100,
- target_dir=os.path.join(args.target_dir, "train-clean-100"),
- manifest_path=args.manifest_prefix + ".train-clean-100")
if args.full_download:
+ prepare_dataset(
+ url=URL_TRAIN_CLEAN_100,
+ md5sum=MD5_TRAIN_CLEAN_100,
+ target_dir=os.path.join(args.target_dir, "train-clean-100"),
+ manifest_path=args.manifest_prefix + ".train-clean-100")
prepare_dataset(
url=URL_TEST_OTHER,
md5sum=MD5_TEST_OTHER,
diff --git a/datasets/noise/chime3_background.py b/data/noise/chime3_background.py
similarity index 100%
rename from datasets/noise/chime3_background.py
rename to data/noise/chime3_background.py
diff --git a/data_utils/augmentor/impulse_response.py b/data_utils/augmentor/impulse_response.py
index c3de0fdbb2a40150f8cffdef3487ceb4400e52ed..536b4d6a4a6666359b90e191a3d593250b44e863 100644
--- a/data_utils/augmentor/impulse_response.py
+++ b/data_utils/augmentor/impulse_response.py
@@ -4,23 +4,22 @@ from __future__ import division
from __future__ import print_function
from data_utils.augmentor.base import AugmentorBase
-from data_utils import utils
+from data_utils.utility import read_manifest
from data_utils.audio import AudioSegment
class ImpulseResponseAugmentor(AugmentorBase):
"""Augmentation model for adding impulse response effect.
-
+
:param rng: Random generator object.
:type rng: random.Random
:param impulse_manifest_path: Manifest path for impulse audio data.
- :type impulse_manifest_path: basestring
+ :type impulse_manifest_path: basestring
"""
def __init__(self, rng, impulse_manifest_path):
self._rng = rng
- self._impulse_manifest = utils.read_manifest(
- manifest_path=impulse_manifest_path)
+ self._impulse_manifest = read_manifest(impulse_manifest_path)
def transform_audio(self, audio_segment):
"""Add impulse response effect.
diff --git a/data_utils/augmentor/noise_perturb.py b/data_utils/augmentor/noise_perturb.py
index 281174af42c2f6d673ead94bd532941769c79c25..96e0ff4deac48063faf76338014e418e3d8ad4ad 100644
--- a/data_utils/augmentor/noise_perturb.py
+++ b/data_utils/augmentor/noise_perturb.py
@@ -4,13 +4,13 @@ from __future__ import division
from __future__ import print_function
from data_utils.augmentor.base import AugmentorBase
-from data_utils import utils
+from data_utils.utility import read_manifest
from data_utils.audio import AudioSegment
class NoisePerturbAugmentor(AugmentorBase):
"""Augmentation model for adding background noise.
-
+
:param rng: Random generator object.
:type rng: random.Random
:param min_snr_dB: Minimal signal noise ratio, in decibels.
@@ -18,15 +18,14 @@ class NoisePerturbAugmentor(AugmentorBase):
:param max_snr_dB: Maximal signal noise ratio, in decibels.
:type max_snr_dB: float
:param noise_manifest_path: Manifest path for noise audio data.
- :type noise_manifest_path: basestring
+ :type noise_manifest_path: basestring
"""
def __init__(self, rng, min_snr_dB, max_snr_dB, noise_manifest_path):
self._min_snr_dB = min_snr_dB
self._max_snr_dB = max_snr_dB
self._rng = rng
- self._noise_manifest = utils.read_manifest(
- manifest_path=noise_manifest_path)
+ self._noise_manifest = read_manifest(manifest_path=noise_manifest_path)
def transform_audio(self, audio_segment):
"""Add background noise audio.
diff --git a/data_utils/data.py b/data_utils/data.py
index 159bf69d582d6418f01ecbea01d716ac4a279207..8bff6826dc51d6caaa420bec5a886e1878f36df4 100644
--- a/data_utils/data.py
+++ b/data_utils/data.py
@@ -6,10 +6,12 @@ from __future__ import division
from __future__ import print_function
import random
-import numpy as np
+import tarfile
import multiprocessing
+import numpy as np
import paddle.v2 as paddle
-from data_utils import utils
+from threading import local
+from data_utils.utility import read_manifest
from data_utils.augmentor.augmentation import AugmentationPipeline
from data_utils.featurizer.speech_featurizer import SpeechFeaturizer
from data_utils.speech import SpeechSegment
@@ -46,7 +48,7 @@ class DataGenerator(object):
:param specgram_type: Specgram feature type. Options: 'linear'.
:type specgram_type: str
:param use_dB_normalization: Whether to normalize the audio to -20 dB
- before extracting the features.
+ before extracting the features.
:type use_dB_normalization: bool
:param num_threads: Number of CPU threads for processing data.
:type num_threads: int
@@ -82,16 +84,20 @@ class DataGenerator(object):
self._num_threads = num_threads
self._rng = random.Random(random_seed)
self._epoch = 0
+ # for caching tar files info
+ self._local_data = local()
+ self._local_data.tar2info = {}
+ self._local_data.tar2object = {}
def process_utterance(self, filename, transcript):
"""Load, augment, featurize and normalize for speech data.
:param filename: Audio filepath
- :type filename: basestring
+ :type filename: basestring | file
:param transcript: Transcription text.
:type transcript: basestring
:return: Tuple of audio feature tensor and list of token ids for
- transcription.
+ transcription.
:rtype: tuple of (2darray, list)
"""
speech_segment = SpeechSegment.from_file(filename, transcript)
@@ -111,7 +117,7 @@ class DataGenerator(object):
"""
Batch data reader creator for audio data. Return a callable generator
function to produce batches of data.
-
+
Audio features within one batch will be padded with zeros to have the
same shape, or a user-defined shape.
@@ -153,7 +159,7 @@ class DataGenerator(object):
def batch_reader():
# read manifest
- manifest = utils.read_manifest(
+ manifest = read_manifest(
manifest_path=manifest_path,
max_duration=self._max_duration,
min_duration=self._min_duration)
@@ -191,9 +197,9 @@ class DataGenerator(object):
@property
def feeding(self):
"""Returns data reader's feeding dict.
-
+
:return: Data feeding dict.
- :rtype: dict
+ :rtype: dict
"""
return {"audio_spectrogram": 0, "transcript_text": 1}
@@ -215,6 +221,38 @@ class DataGenerator(object):
"""
return self._speech_featurizer.vocab_list
+ def _parse_tar(self, file):
+ """Parse a tar file to get a tarfile object
+ and a map containing tarinfoes
+ """
+ result = {}
+ f = tarfile.open(file)
+ for tarinfo in f.getmembers():
+ result[tarinfo.name] = tarinfo
+ return f, result
+
+ def _get_file_object(self, file):
+ """Get file object by file path.
+
+ If file startwith tar, it will return a tar file object
+ and cached tar file info for next reading request.
+ It will return file directly, if the type of file is not str.
+ """
+ if file.startswith('tar:'):
+ tarpath, filename = file.split(':', 1)[1].split('#', 1)
+ if 'tar2info' not in self._local_data.__dict__:
+ self._local_data.tar2info = {}
+ if 'tar2object' not in self._local_data.__dict__:
+ self._local_data.tar2object = {}
+ if tarpath not in self._local_data.tar2info:
+ object, infoes = self._parse_tar(tarpath)
+ self._local_data.tar2info[tarpath] = infoes
+ self._local_data.tar2object[tarpath] = object
+ return self._local_data.tar2object[tarpath].extractfile(
+ self._local_data.tar2info[tarpath][filename])
+ else:
+ return open(file, 'r')
+
def _instance_reader_creator(self, manifest):
"""
Instance reader creator. Create a callable function to produce
@@ -229,8 +267,9 @@ class DataGenerator(object):
yield instance
def mapper(instance):
- return self.process_utterance(instance["audio_filepath"],
- instance["text"])
+ return self.process_utterance(
+ self._get_file_object(instance["audio_filepath"]),
+ instance["text"])
return paddle.reader.xmap_readers(
mapper, reader, self._num_threads, 1024, order=True)
diff --git a/data_utils/featurizer/audio_featurizer.py b/data_utils/featurizer/audio_featurizer.py
index 00f0e8a35bc8e67ab285b7d509a0992c02dc54ca..12f8784a9921e9bd78735db3edda3898c54ee908 100644
--- a/data_utils/featurizer/audio_featurizer.py
+++ b/data_utils/featurizer/audio_featurizer.py
@@ -4,7 +4,7 @@ from __future__ import division
from __future__ import print_function
import numpy as np
-from data_utils import utils
+from data_utils.utility import read_manifest
from data_utils.audio import AudioSegment
from python_speech_features import mfcc
from python_speech_features import delta
@@ -57,7 +57,7 @@ class AudioFeaturizer(object):
def featurize(self,
audio_segment,
allow_downsampling=True,
- allow_upsamplling=True):
+ allow_upsampling=True):
"""Extract audio features from AudioSegment or SpeechSegment.
:param audio_segment: Audio/speech segment to extract features from.
@@ -159,24 +159,27 @@ class AudioFeaturizer(object):
if max_freq is None:
max_freq = sample_rate / 2
if max_freq > sample_rate / 2:
- raise ValueError("max_freq must be greater than half of "
+ raise ValueError("max_freq must not be greater than half of "
"sample rate.")
if stride_ms > window_ms:
raise ValueError("Stride size must not be greater than "
"window size.")
- # compute 13 cepstral coefficients, and the first one is replaced
+ # compute the 13 cepstral coefficients, and the first one is replaced
# by log(frame energy)
- mfcc_feat = np.transpose(
- mfcc(
- signal=samples,
- samplerate=sample_rate,
- winlen=0.001 * window_ms,
- winstep=0.001 * stride_ms,
- highfreq=max_freq))
+ mfcc_feat = mfcc(
+ signal=samples,
+ samplerate=sample_rate,
+ winlen=0.001 * window_ms,
+ winstep=0.001 * stride_ms,
+ highfreq=max_freq)
# Deltas
d_mfcc_feat = delta(mfcc_feat, 2)
# Deltas-Deltas
dd_mfcc_feat = delta(d_mfcc_feat, 2)
+ # transpose
+ mfcc_feat = np.transpose(mfcc_feat)
+ d_mfcc_feat = np.transpose(d_mfcc_feat)
+ dd_mfcc_feat = np.transpose(dd_mfcc_feat)
# concat above three features
concat_mfcc_feat = np.concatenate(
(mfcc_feat, d_mfcc_feat, dd_mfcc_feat))
diff --git a/data_utils/featurizer/text_featurizer.py b/data_utils/featurizer/text_featurizer.py
index 4f9a49b594010f91a64797b9a4b7e9054d4749d5..89202163ca8d8b69f59b858db5451882d7e089b3 100644
--- a/data_utils/featurizer/text_featurizer.py
+++ b/data_utils/featurizer/text_featurizer.py
@@ -4,6 +4,7 @@ from __future__ import division
from __future__ import print_function
import os
+import codecs
class TextFeaturizer(object):
@@ -59,7 +60,7 @@ class TextFeaturizer(object):
def _load_vocabulary_from_file(self, vocab_filepath):
"""Load vocabulary from file."""
vocab_lines = []
- with open(vocab_filepath, 'r') as file:
+ with codecs.open(vocab_filepath, 'r', 'utf-8') as file:
vocab_lines.extend(file.readlines())
vocab_list = [line[:-1] for line in vocab_lines]
vocab_dict = dict(
diff --git a/data_utils/normalizer.py b/data_utils/normalizer.py
index 1f4aae9a0913f323480c46c2d449f9515a65bb7e..7c2e05c9d85fa55c0a91386ebf9ba570b2ec0e3b 100644
--- a/data_utils/normalizer.py
+++ b/data_utils/normalizer.py
@@ -5,7 +5,7 @@ from __future__ import print_function
import numpy as np
import random
-import data_utils.utils as utils
+from data_utils.utility import read_manifest
from data_utils.audio import AudioSegment
@@ -75,7 +75,7 @@ class FeatureNormalizer(object):
def _compute_mean_std(self, manifest_path, featurize_func, num_samples):
"""Compute mean and std from randomly sampled instances."""
- manifest = utils.read_manifest(manifest_path)
+ manifest = read_manifest(manifest_path)
sampled_manifest = self._rng.sample(manifest, num_samples)
features = []
for instance in sampled_manifest:
diff --git a/data_utils/utility.py b/data_utils/utility.py
new file mode 100644
index 0000000000000000000000000000000000000000..da7b66ef2f65699678c09def05ee95fe5c52c52f
--- /dev/null
+++ b/data_utils/utility.py
@@ -0,0 +1,63 @@
+"""Contains data helper functions."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import json
+import codecs
+import os
+import tarfile
+from paddle.v2.dataset.common import md5file
+
+
+def read_manifest(manifest_path, max_duration=float('inf'), min_duration=0.0):
+ """Load and parse manifest file.
+
+ Instances with durations outside [min_duration, max_duration] will be
+ filtered out.
+
+ :param manifest_path: Manifest file to load and parse.
+ :type manifest_path: basestring
+ :param max_duration: Maximal duration in seconds for instance filter.
+ :type max_duration: float
+ :param min_duration: Minimal duration in seconds for instance filter.
+ :type min_duration: float
+ :return: Manifest parsing results. List of dict.
+ :rtype: list
+ :raises IOError: If failed to parse the manifest.
+ """
+ manifest = []
+ for json_line in codecs.open(manifest_path, 'r', 'utf-8'):
+ try:
+ json_data = json.loads(json_line)
+ except Exception as e:
+ raise IOError("Error reading manifest: %s" % str(e))
+ if (json_data["duration"] <= max_duration and
+ json_data["duration"] >= min_duration):
+ manifest.append(json_data)
+ return manifest
+
+
+def download(url, md5sum, target_dir):
+ """Download file from url to target_dir, and check md5sum."""
+ if not os.path.exists(target_dir): os.makedirs(target_dir)
+ filepath = os.path.join(target_dir, url.split("/")[-1])
+ if not (os.path.exists(filepath) and md5file(filepath) == md5sum):
+ print("Downloading %s ..." % url)
+ os.system("wget -c " + url + " -P " + target_dir)
+ print("\nMD5 Chesksum %s ..." % filepath)
+ if not md5file(filepath) == md5sum:
+ raise RuntimeError("MD5 checksum failed.")
+ else:
+ print("File exists, skip downloading. (%s)" % filepath)
+ return filepath
+
+
+def unpack(filepath, target_dir, rm_tar=False):
+ """Unpack the file to the target_dir."""
+ print("Unpacking %s ..." % filepath)
+ tar = tarfile.open(filepath)
+ tar.extractall(target_dir)
+ tar.close()
+ if rm_tar == True:
+ os.remove(filepath)
diff --git a/data_utils/utils.py b/data_utils/utils.py
deleted file mode 100644
index 3f1165718aa0e2a0bf0687b8a613a6447b964ee8..0000000000000000000000000000000000000000
--- a/data_utils/utils.py
+++ /dev/null
@@ -1,34 +0,0 @@
-"""Contains data helper functions."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import json
-
-
-def read_manifest(manifest_path, max_duration=float('inf'), min_duration=0.0):
- """Load and parse manifest file.
-
- Instances with durations outside [min_duration, max_duration] will be
- filtered out.
-
- :param manifest_path: Manifest file to load and parse.
- :type manifest_path: basestring
- :param max_duration: Maximal duration in seconds for instance filter.
- :type max_duration: float
- :param min_duration: Minimal duration in seconds for instance filter.
- :type min_duration: float
- :return: Manifest parsing results. List of dict.
- :rtype: list
- :raises IOError: If failed to parse the manifest.
- """
- manifest = []
- for json_line in open(manifest_path):
- try:
- json_data = json.loads(json_line)
- except Exception as e:
- raise IOError("Error reading manifest: %s" % str(e))
- if (json_data["duration"] <= max_duration and
- json_data["duration"] >= min_duration):
- manifest.append(json_data)
- return manifest
diff --git a/datasets/run_all.sh b/datasets/run_all.sh
deleted file mode 100644
index ef2b721fbdc2a18fcbc208730189604e88d7ef2c..0000000000000000000000000000000000000000
--- a/datasets/run_all.sh
+++ /dev/null
@@ -1,13 +0,0 @@
-cd librispeech
-python librispeech.py
-if [ $? -ne 0 ]; then
- echo "Prepare LibriSpeech failed. Terminated."
- exit 1
-fi
-cd -
-
-cat librispeech/manifest.train* | shuf > manifest.train
-cat librispeech/manifest.dev-clean > manifest.dev
-cat librispeech/manifest.test-clean > manifest.test
-
-echo "All done."
diff --git a/datasets/run_noise.sh b/datasets/run_noise.sh
deleted file mode 100644
index 7b27abde47a97b671609f0cd15e81565b3a00d02..0000000000000000000000000000000000000000
--- a/datasets/run_noise.sh
+++ /dev/null
@@ -1,10 +0,0 @@
-cd noise
-python chime3_background.py
-if [ $? -ne 0 ]; then
- echo "Prepare CHiME3 background noise failed. Terminated."
- exit 1
-fi
-cd -
-
-cat noise/manifest.* > manifest.noise
-echo "All done."
diff --git a/datasets/vocab/eng_vocab.txt b/datasets/vocab/eng_vocab.txt
deleted file mode 100644
index 8268f3f3301047f2b4354d60a4bd1d5ef58619a2..0000000000000000000000000000000000000000
--- a/datasets/vocab/eng_vocab.txt
+++ /dev/null
@@ -1,28 +0,0 @@
-'
-
-a
-b
-c
-d
-e
-f
-g
-h
-i
-j
-k
-l
-m
-n
-o
-p
-q
-r
-s
-t
-u
-v
-w
-x
-y
-z
diff --git a/lm/__init__.py b/decoders/__init__.py
similarity index 100%
rename from lm/__init__.py
rename to decoders/__init__.py
diff --git a/decoder.py b/decoders/decoders_deprecated.py
similarity index 91%
rename from decoder.py
rename to decoders/decoders_deprecated.py
index 8f2e0508de79fea30ebc30230e948b15923bdf24..17b28b0d02a22a2e59856156ccd663324e886aed 100644
--- a/decoder.py
+++ b/decoders/decoders_deprecated.py
@@ -9,8 +9,9 @@ from math import log
import multiprocessing
-def ctc_best_path_decoder(probs_seq, vocabulary):
- """Best path decoder, also called argmax decoder or greedy decoder.
+def ctc_greedy_decoder(probs_seq, vocabulary):
+ """CTC greedy (best path) decoder.
+
Path consisting of the most probable tokens are further post-processed to
remove consecutive repetitions and all blanks.
@@ -41,14 +42,16 @@ def ctc_best_path_decoder(probs_seq, vocabulary):
def ctc_beam_search_decoder(probs_seq,
beam_size,
vocabulary,
- blank_id,
cutoff_prob=1.0,
+ cutoff_top_n=40,
ext_scoring_func=None,
nproc=False):
- """Beam search decoder for CTC-trained network. It utilizes beam search
- to approximately select top best decoding labels and returning results
- in the descending order. The implementation is based on Prefix
- Beam Search (https://arxiv.org/abs/1408.2873), and the unclear part is
+ """CTC Beam search decoder.
+
+ It utilizes beam search to approximately select top best decoding
+ labels and returning results in the descending order.
+ The implementation is based on Prefix Beam Search
+ (https://arxiv.org/abs/1408.2873), and the unclear part is
redesigned. Two important modifications: 1) in the iterative computation
of probabilities, the assignment operation is changed to accumulation for
one prefix may comes from different paths; 2) the if condition "if l^+ not
@@ -63,8 +66,6 @@ def ctc_beam_search_decoder(probs_seq,
:type beam_size: int
:param vocabulary: Vocabulary list.
:type vocabulary: list
- :param blank_id: ID of blank.
- :type blank_id: int
:param cutoff_prob: Cutoff probability in pruning,
default 1.0, no pruning.
:type cutoff_prob: float
@@ -84,9 +85,8 @@ def ctc_beam_search_decoder(probs_seq,
raise ValueError("The shape of prob_seq does not match with the "
"shape of the vocabulary.")
- # blank_id check
- if not blank_id < len(probs_seq[0]):
- raise ValueError("blank_id shouldn't be greater than probs dimension")
+ # blank_id assign
+ blank_id = len(vocabulary)
# If the decoder called in the multiprocesses, then use the global scorer
# instantiated in ctc_beam_search_decoder_batch().
@@ -111,7 +111,7 @@ def ctc_beam_search_decoder(probs_seq,
prob_idx = list(enumerate(probs_seq[time_step]))
cutoff_len = len(prob_idx)
#If pruning is enabled
- if cutoff_prob < 1.0:
+ if cutoff_prob < 1.0 or cutoff_top_n < cutoff_len:
prob_idx = sorted(prob_idx, key=lambda asd: asd[1], reverse=True)
cutoff_len, cum_prob = 0, 0.0
for i in xrange(len(prob_idx)):
@@ -119,6 +119,7 @@ def ctc_beam_search_decoder(probs_seq,
cutoff_len += 1
if cum_prob >= cutoff_prob:
break
+ cutoff_len = min(cutoff_len, cutoff_top_n)
prob_idx = prob_idx[0:cutoff_len]
for l in prefix_set_prev:
@@ -177,6 +178,8 @@ def ctc_beam_search_decoder(probs_seq,
prob = prob * ext_scoring_func(result)
log_prob = log(prob)
beam_result.append((log_prob, result))
+ else:
+ beam_result.append((float('-inf'), ''))
## output top beam_size decoding results
beam_result = sorted(beam_result, key=lambda asd: asd[0], reverse=True)
@@ -186,9 +189,9 @@ def ctc_beam_search_decoder(probs_seq,
def ctc_beam_search_decoder_batch(probs_split,
beam_size,
vocabulary,
- blank_id,
num_processes,
cutoff_prob=1.0,
+ cutoff_top_n=40,
ext_scoring_func=None):
"""CTC beam search decoder using multiple processes.
@@ -199,8 +202,6 @@ def ctc_beam_search_decoder_batch(probs_split,
:type beam_size: int
:param vocabulary: Vocabulary list.
:type vocabulary: list
- :param blank_id: ID of blank.
- :type blank_id: int
:param num_processes: Number of parallel processes.
:type num_processes: int
:param cutoff_prob: Cutoff probability in pruning,
@@ -227,8 +228,8 @@ def ctc_beam_search_decoder_batch(probs_split,
pool = multiprocessing.Pool(processes=num_processes)
results = []
for i, probs_list in enumerate(probs_split):
- args = (probs_list, beam_size, vocabulary, blank_id, cutoff_prob, None,
- nproc)
+ args = (probs_list, beam_size, vocabulary, cutoff_prob, cutoff_top_n,
+ None, nproc)
results.append(pool.apply_async(ctc_beam_search_decoder, args))
pool.close()
diff --git a/lm/lm_scorer.py b/decoders/scorer_deprecated.py
similarity index 98%
rename from lm/lm_scorer.py
rename to decoders/scorer_deprecated.py
index 463e96d6653b29207fb6105527a1f79c41c7fb84..c6a661030d4363727e259da9c7949e59705d55c8 100644
--- a/lm/lm_scorer.py
+++ b/decoders/scorer_deprecated.py
@@ -8,7 +8,7 @@ import kenlm
import numpy as np
-class LmScorer(object):
+class Scorer(object):
"""External scorer to evaluate a prefix or whole sentence in
beam search decoding, including the score from n-gram language
model and word count.
diff --git a/decoders/swig/__init__.py b/decoders/swig/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/decoders/swig/_init_paths.py b/decoders/swig/_init_paths.py
new file mode 100644
index 0000000000000000000000000000000000000000..ddabb535be682d95c3c8b73003ea30eed06ca0b0
--- /dev/null
+++ b/decoders/swig/_init_paths.py
@@ -0,0 +1,19 @@
+"""Set up paths for DS2"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os.path
+import sys
+
+
+def add_path(path):
+ if path not in sys.path:
+ sys.path.insert(0, path)
+
+
+this_dir = os.path.dirname(__file__)
+
+# Add project path to PYTHONPATH
+proj_path = os.path.join(this_dir, '..')
+add_path(proj_path)
diff --git a/decoders/swig/ctc_beam_search_decoder.cpp b/decoders/swig/ctc_beam_search_decoder.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..624784b05e215782f2264cc6ae4db7eed5b28cae
--- /dev/null
+++ b/decoders/swig/ctc_beam_search_decoder.cpp
@@ -0,0 +1,204 @@
+#include "ctc_beam_search_decoder.h"
+
+#include
+#include
+#include
+#include
+#include