提交 f3f5dad8 编写于 作者: Y Yibing Liu

format header includes & update setup info

上级 7d0458c7
......@@ -82,6 +82,16 @@ sh run.sh
cd ..
```
### Setup decoders
```shell
cd models/swig_decoders
sh setup.sh
cd ../..
```
These commands will install the decoders that translate the ouptut probability vectors of DS2 model to text data, incuding CTC greedy decoder, CTC beam search decoder and its batch version.
### Inference
For GPU inference
......
"""Deployment for DeepSpeech2 model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import gzip
import distutils.util
import multiprocessing
import paddle.v2 as paddle
from data_utils.data import DataGenerator
from layer import deep_speech2
from deploy.swig_decoders_wrapper import *
from error_rate import wer
import utils
import time
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--num_samples",
default=10,
type=int,
help="Number of samples for inference. (default: %(default)s)")
parser.add_argument(
"--num_conv_layers",
default=2,
type=int,
help="Convolution layer number. (default: %(default)s)")
parser.add_argument(
"--num_rnn_layers",
default=3,
type=int,
help="RNN layer number. (default: %(default)s)")
parser.add_argument(
"--rnn_layer_size",
default=512,
type=int,
help="RNN layer cell number. (default: %(default)s)")
parser.add_argument(
"--use_gpu",
default=True,
type=distutils.util.strtobool,
help="Use gpu or not. (default: %(default)s)")
parser.add_argument(
"--num_threads_data",
default=multiprocessing.cpu_count(),
type=int,
help="Number of cpu threads for preprocessing data. (default: %(default)s)")
parser.add_argument(
"--num_processes_beam_search",
default=multiprocessing.cpu_count(),
type=int,
help="Number of cpu processes for beam search. (default: %(default)s)")
parser.add_argument(
"--mean_std_filepath",
default='mean_std.npz',
type=str,
help="Manifest path for normalizer. (default: %(default)s)")
parser.add_argument(
"--decode_manifest_path",
default='datasets/manifest.test',
type=str,
help="Manifest path for decoding. (default: %(default)s)")
parser.add_argument(
"--model_filepath",
default='checkpoints/params.latest.tar.gz',
type=str,
help="Model filepath. (default: %(default)s)")
parser.add_argument(
"--vocab_filepath",
default='datasets/vocab/eng_vocab.txt',
type=str,
help="Vocabulary filepath. (default: %(default)s)")
parser.add_argument(
"--decode_method",
default='beam_search',
type=str,
help="Method for ctc decoding: beam_search or beam_search_batch. "
"(default: %(default)s)")
parser.add_argument(
"--beam_size",
default=500,
type=int,
help="Width for beam search decoding. (default: %(default)d)")
parser.add_argument(
"--num_results_per_sample",
default=1,
type=int,
help="Number of output per sample in beam search. (default: %(default)d)")
parser.add_argument(
"--language_model_path",
default="lm/data/common_crawl_00.prune01111.trie.klm",
type=str,
help="Path for language model. (default: %(default)s)")
parser.add_argument(
"--alpha",
default=1.5,
type=float,
help="Parameter associated with language model. (default: %(default)f)")
parser.add_argument(
"--beta",
default=0.3,
type=float,
help="Parameter associated with word count. (default: %(default)f)")
parser.add_argument(
"--cutoff_prob",
default=1.0,
type=float,
help="The cutoff probability of pruning"
"in beam search. (default: %(default)f)")
parser.add_argument(
"--cutoff_top_n",
default=40,
type=int,
help="The cutoff number of pruning"
"in beam search. (default: %(default)f)")
args = parser.parse_args()
def infer():
"""Deployment for DeepSpeech2."""
# initialize data generator
data_generator = DataGenerator(
vocab_filepath=args.vocab_filepath,
mean_std_filepath=args.mean_std_filepath,
augmentation_config='{}',
num_threads=args.num_threads_data)
# create network config
# paddle.data_type.dense_array is used for variable batch input.
# The size 161 * 161 is only an placeholder value and the real shape
# of input batch data will be induced during training.
audio_data = paddle.layer.data(
name="audio_spectrogram", type=paddle.data_type.dense_array(161 * 161))
text_data = paddle.layer.data(
name="transcript_text",
type=paddle.data_type.integer_value_sequence(data_generator.vocab_size))
output_probs, _ = deep_speech2(
audio_data=audio_data,
text_data=text_data,
dict_size=data_generator.vocab_size,
num_conv_layers=args.num_conv_layers,
num_rnn_layers=args.num_rnn_layers,
rnn_size=args.rnn_layer_size)
# load parameters
parameters = paddle.parameters.Parameters.from_tar(
gzip.open(args.model_filepath))
# prepare infer data
batch_reader = data_generator.batch_reader_creator(
manifest_path=args.decode_manifest_path,
batch_size=args.num_samples,
min_batch_size=1,
sortagrad=False,
shuffle_method=None)
infer_data = batch_reader().next()
# run inference
inferer = paddle.inference.Inference(
output_layer=output_probs, parameters=parameters)
infer_results = inferer.infer(input=infer_data)
num_steps = len(infer_results) // len(infer_data)
probs_split = [
infer_results[i * num_steps:(i + 1) * num_steps]
for i in xrange(len(infer_data))
]
# targe transcription
target_transcription = [
''.join(
[data_generator.vocab_list[index] for index in infer_data[i][1]])
for i, probs in enumerate(probs_split)
]
# external scorer
ext_scorer = Scorer(
alpha=args.alpha, beta=args.beta, model_path=args.language_model_path)
# from unicode to string
vocab_list = [chars.encode("utf-8") for chars in data_generator.vocab_list]
# The below two steps, i.e. setting char map and filling dictionary of
# FST will be completed implicitly when ext_scorer first used.But to save
# the time of decoding the first audio sample, they are done in advance.
ext_scorer.set_char_map(vocab_list)
# only for ward based language model
ext_scorer.fill_dictionary(True)
# for word error rate metric
wer_sum, wer_counter = 0.0, 0
## decode and print
time_begin = time.time()
batch_beam_results = []
if args.decode_method == 'beam_search':
for i, probs in enumerate(probs_split):
beam_result = ctc_beam_search_decoder(
probs_seq=probs,
beam_size=args.beam_size,
vocabulary=vocab_list,
blank_id=len(vocab_list),
cutoff_prob=args.cutoff_prob,
cutoff_top_n=args.cutoff_top_n,
ext_scoring_func=ext_scorer, )
batch_beam_results += [beam_result]
else:
batch_beam_results = ctc_beam_search_decoder_batch(
probs_split=probs_split,
beam_size=args.beam_size,
vocabulary=vocab_list,
blank_id=len(vocab_list),
num_processes=args.num_processes_beam_search,
cutoff_prob=args.cutoff_prob,
cutoff_top_n=args.cutoff_top_n,
ext_scoring_func=ext_scorer, )
for i, beam_result in enumerate(batch_beam_results):
print("\nTarget Transcription:\t%s" % target_transcription[i])
print("Beam %d: %f \t%s" % (0, beam_result[0][0], beam_result[0][1]))
wer_cur = wer(target_transcription[i], beam_result[0][1])
wer_sum += wer_cur
wer_counter += 1
print("cur wer = %f , average wer = %f" %
(wer_cur, wer_sum / wer_counter))
print("time for decoding = %f" % (time.time() - time_begin))
def main():
utils.print_arguments(args)
paddle.init(use_gpu=args.use_gpu, trainer_count=1)
infer()
if __name__ == '__main__':
main()
The decoders for deployment developed in C++ are a better alternative for the prototype decoders in Pytthon, with more powerful performance in both speed and accuracy.
### Installation
The build depends on several open-sourced projects, first clone or download them to current directory (i.e., `deep_speech_2/deploy`)
- [**KenLM**](https://github.com/kpu/kenlm/): Faster and Smaller Language Model Queries
```shell
git clone https://github.com/kpu/kenlm.git
```
- [**OpenFst**](http://www.openfst.org/twiki/bin/view/FST/WebHome): A library for finite-state transducers
```shell
wget http://www.openfst.org/twiki/pub/FST/FstDownload/openfst-1.6.3.tar.gz
tar -xzvf openfst-1.6.3.tar.gz
```
- [**ThreadPool**](http://progsch.net/wordpress/): A library for C++ thread pool
```shell
git clone https://github.com/progschj/ThreadPool.git
```
- [**SWIG**](http://www.swig.org): A tool that provides the Python interface for the decoders, please make sure it being installed.
Then run the setup
```shell
python setup.py install --num_processes 4
cd ..
```
### Usage
The decoders for deployment share almost the same interface with the prototye decoders in Python. After the installation succeeds, these decoders are very convenient for call in Python, and a complete example in ```deploy.py``` can be refered.
For GPU deployment
```
CUDA_VISIBLE_DEVICES=0 python deploy.py
```
For CPU deployment
```
python deploy.py --use_gpu=False
```
More help for arguments
```
python deploy.py --help
```
#include "ctc_decoders.h"
#include <algorithm>
#include <cmath>
#include <iostream>
#include <limits>
#include <map>
#include <utility>
#include "fst/fstlib.h"
#include "ThreadPool.h"
#include "decoder_utils.h"
#include "fst/fstlib.h"
#include "path_trie.h"
std::string ctc_greedy_decoder(std::vector<std::vector<double>> probs_seq,
std::vector<std::string> vocabulary) {
std::string ctc_greedy_decoder(
const std::vector<std::vector<double>>& probs_seq,
const std::vector<std::string>& vocabulary) {
// dimension check
int num_time_steps = probs_seq.size();
for (int i = 0; i < num_time_steps; i++) {
......@@ -56,7 +60,7 @@ std::string ctc_greedy_decoder(std::vector<std::vector<double>> probs_seq,
}
std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
std::vector<std::vector<double>> probs_seq,
const std::vector<std::vector<double>>& probs_seq,
int beam_size,
std::vector<std::string> vocabulary,
int blank_id,
......@@ -64,7 +68,7 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
int cutoff_top_n,
Scorer *extscorer) {
// dimension check
int num_time_steps = probs_seq.size();
size_t num_time_steps = probs_seq.size();
for (int i = 0; i < num_time_steps; i++) {
if (probs_seq[i].size() != vocabulary.size() + 1) {
std::cout << " The shape of probs_seq does not match"
......@@ -278,9 +282,9 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
std::vector<std::vector<std::pair<double, std::string>>>
ctc_beam_search_decoder_batch(
std::vector<std::vector<std::vector<double>>> probs_split,
const std::vector<std::vector<std::vector<double>>>& probs_split,
int beam_size,
std::vector<std::string> vocabulary,
const std::vector<std::string>& vocabulary,
int blank_id,
int num_processes,
double cutoff_prob,
......
......@@ -4,6 +4,7 @@
#include <string>
#include <utility>
#include <vector>
#include "scorer.h"
/* CTC Best Path Decoder
......@@ -16,8 +17,9 @@
* A vector that each element is a pair of score and decoding result,
* in desending order.
*/
std::string ctc_greedy_decoder(std::vector<std::vector<double>> probs_seq,
std::vector<std::string> vocabulary);
std::string ctc_greedy_decoder(
const std::vector<std::vector<double>>& probs_seq,
const std::vector<std::string>& vocabulary);
/* CTC Beam Search Decoder
......@@ -35,7 +37,7 @@ std::string ctc_greedy_decoder(std::vector<std::vector<double>> probs_seq,
* in desending order.
*/
std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
std::vector<std::vector<double>> probs_seq,
const std::vector<std::vector<double>>& probs_seq,
int beam_size,
std::vector<std::string> vocabulary,
int blank_id,
......@@ -43,8 +45,7 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
int cutoff_top_n = 40,
Scorer *ext_scorer = NULL);
/* CTC Beam Search Decoder for batch data, the interface is consistent with the
* original decoder in Python version.
/* CTC Beam Search Decoder for batch data
* Parameters:
* probs_seq: 3-D vector that each element is a 2-D vector that can be used
......@@ -63,9 +64,9 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
*/
std::vector<std::vector<std::pair<double, std::string>>>
ctc_beam_search_decoder_batch(
std::vector<std::vector<std::vector<double>>> probs_split,
const std::vector<std::vector<std::vector<double>>>& probs_split,
int beam_size,
std::vector<std::string> vocabulary,
const std::vector<std::string>& vocabulary,
int blank_id,
int num_processes,
double cutoff_prob = 1.0,
......
#include "decoder_utils.h"
#include <algorithm>
#include <cmath>
#include <limits>
......
#include "path_trie.h"
#include <algorithm>
#include <limits>
#include <memory>
......@@ -5,7 +7,6 @@
#include <vector>
#include "decoder_utils.h"
#include "path_trie.h"
PathTrie::PathTrie() {
log_prob_b_prev = -NUM_FLT_INF;
......
#ifndef PATH_TRIE_H
#define PATH_TRIE_H
#pragma once
#include <fst/fstlib.h>
#include <algorithm>
#include <limits>
#include <memory>
#include <utility>
#include <vector>
#include <fst/fstlib.h>
using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;
......@@ -45,12 +45,12 @@ public:
private:
int _ROOT;
bool _exists;
bool _has_dictionary;
std::vector<std::pair<int, PathTrie*>> _children;
fst::StdVectorFst* _dictionary;
fst::StdVectorFst::StateId _dictionary_state;
bool _has_dictionary;
std::shared_ptr<FSTMATCH> _matcher;
};
......
#include "scorer.h"
#include <unistd.h>
#include <iostream>
#include "decoder_utils.h"
#include "lm/config.hh"
#include "lm/model.hh"
#include "lm/state.hh"
#include "util/string_piece.hh"
#include "util/tokenize_piece.hh"
#include "decoder_utils.h"
using namespace lm::ngram;
Scorer::Scorer(double alpha, double beta, const std::string& lm_path) {
......@@ -122,7 +125,7 @@ std::vector<std::string> Scorer::split_labels(const std::vector<int>& labels) {
return words;
}
void Scorer::set_char_map(std::vector<std::string> char_list) {
void Scorer::set_char_map(const std::vector<std::string>& char_list) {
_char_list = char_list;
_char_map.clear();
......
......@@ -5,12 +5,14 @@
#include <string>
#include <unordered_map>
#include <vector>
#include "lm/enumerate_vocab.hh"
#include "lm/virtual_interface.hh"
#include "lm/word_index.hh"
#include "path_trie.h"
#include "util/string_piece.hh"
#include "path_trie.h"
const double OOV_SCORE = -1000.0;
const std::string START_TOKEN = "<s>";
const std::string UNK_TOKEN = "<unk>";
......@@ -28,11 +30,13 @@ public:
std::vector<std::string> vocabulary;
};
// External scorer to query languange score for n-gram or sentence.
// Example:
// Scorer scorer(alpha, beta, "path_of_language_model");
// scorer.get_log_cond_prob({ "WORD1", "WORD2", "WORD3" });
// scorer.get_sent_log_prob({ "WORD1", "WORD2", "WORD3" });
/* External scorer to query languange score for n-gram or sentence.
*
* Example:
* Scorer scorer(alpha, beta, "path_of_language_model");
* scorer.get_log_cond_prob({ "WORD1", "WORD2", "WORD3" });
* scorer.get_sent_log_prob({ "WORD1", "WORD2", "WORD3" });
*/
class Scorer {
public:
Scorer(double alpha, double beta, const std::string& lm_path);
......@@ -58,7 +62,7 @@ public:
void fill_dictionary(bool add_space);
// set char map
void set_char_map(std::vector<std::string> char_list);
void set_char_map(const std::vector<std::string>& char_list);
std::vector<std::string> split_labels(const std::vector<int>& labels);
......
#!/bin/bash
if [ ! -d kenlm ]; then
git clone https://github.com/luotao1/kenlm.git
echo -e "\n"
fi
if [ ! -d openfst-1.6.3 ]; then
echo "Download and extract openfst ..."
wget http://www.openfst.org/twiki/pub/FST/FstDownload/openfst-1.6.3.tar.gz
tar -xzvf openfst-1.6.3.tar.gz
echo -e "\n"
fi
if [ ! -d ThreadPool ]; then
git clone https://github.com/progschj/ThreadPool.git
echo -e "\n"
fi
echo "Install decoders ..."
python setup.py install --num_processes 4
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册