提交 a2ddfe8d 编写于 作者: Y Yibing Liu

clean up code & update README for decoder in deployment

上级 221a597f
......@@ -9,7 +9,7 @@ import distutils.util
import multiprocessing
import paddle.v2 as paddle
from data_utils.data import DataGenerator
from model import deep_speech2
from layer import deep_speech2
from deploy.swig_decoders_wrapper import *
from error_rate import wer
import utils
......@@ -79,7 +79,7 @@ parser.add_argument(
"(default: %(default)s)")
parser.add_argument(
"--beam_size",
default=20,
default=500,
type=int,
help="Width for beam search decoding. (default: %(default)d)")
parser.add_argument(
......@@ -89,8 +89,7 @@ parser.add_argument(
help="Number of output per sample in beam search. (default: %(default)d)")
parser.add_argument(
"--language_model_path",
default="/home/work/liuyibing/lm_bak/common_crawl_00.prune01111.trie.klm",
#default="ptb_all.arpa",
default="lm/data/common_crawl_00.prune01111.trie.klm",
type=str,
help="Path for language model. (default: %(default)s)")
parser.add_argument(
......@@ -136,14 +135,13 @@ def infer():
text_data = paddle.layer.data(
name="transcript_text",
type=paddle.data_type.integer_value_sequence(data_generator.vocab_size))
output_probs = deep_speech2(
output_probs, _ = deep_speech2(
audio_data=audio_data,
text_data=text_data,
dict_size=data_generator.vocab_size,
num_conv_layers=args.num_conv_layers,
num_rnn_layers=args.num_rnn_layers,
rnn_size=args.rnn_layer_size,
is_inference=True)
rnn_size=args.rnn_layer_size)
# load parameters
parameters = paddle.parameters.Parameters.from_tar(
......@@ -159,8 +157,10 @@ def infer():
infer_data = batch_reader().next()
# run inference
infer_results = paddle.infer(
output_layer=output_probs, parameters=parameters, input=infer_data)
inferer = paddle.inference.Inference(
output_layer=output_probs, parameters=parameters)
infer_results = inferer.infer(input=infer_data)
num_steps = len(infer_results) // len(infer_data)
probs_split = [
infer_results[i * num_steps:(i + 1) * num_steps]
......@@ -178,17 +178,29 @@ def infer():
ext_scorer = Scorer(
alpha=args.alpha, beta=args.beta, model_path=args.language_model_path)
# from unicode to string
vocab_list = [chars.encode("utf-8") for chars in data_generator.vocab_list]
# The below two steps, i.e. setting char map and filling dictionary of
# FST will be completed implicitly when ext_scorer first used.But to save
# the time of decoding the first audio sample, they are done in advance.
ext_scorer.set_char_map(vocab_list)
# only for ward based language model
ext_scorer.fill_dictionary(True)
# for word error rate metric
wer_sum, wer_counter = 0.0, 0
## decode and print
time_begin = time.time()
wer_sum, wer_counter = 0, 0
batch_beam_results = []
if args.decode_method == 'beam_search':
for i, probs in enumerate(probs_split):
beam_result = ctc_beam_search_decoder(
probs_seq=probs,
beam_size=args.beam_size,
vocabulary=data_generator.vocab_list,
blank_id=len(data_generator.vocab_list),
vocabulary=vocab_list,
blank_id=len(vocab_list),
cutoff_prob=args.cutoff_prob,
cutoff_top_n=args.cutoff_top_n,
ext_scoring_func=ext_scorer, )
......@@ -197,8 +209,8 @@ def infer():
batch_beam_results = ctc_beam_search_decoder_batch(
probs_split=probs_split,
beam_size=args.beam_size,
vocabulary=data_generator.vocab_list,
blank_id=len(data_generator.vocab_list),
vocabulary=vocab_list,
blank_id=len(vocab_list),
num_processes=args.num_processes_beam_search,
cutoff_prob=args.cutoff_prob,
cutoff_top_n=args.cutoff_top_n,
......@@ -213,8 +225,7 @@ def infer():
print("cur wer = %f , average wer = %f" %
(wer_cur, wer_sum / wer_counter))
time_end = time.time()
print("total time = %f" % (time_end - time_begin))
print("time for decoding = %f" % (time.time() - time_begin))
def main():
......
The decoders for deployment developed in C++ are a better alternative for the prototype decoders in Pytthon, with more powerful performance in both speed and accuracy.
### Installation
The build of the decoder for deployment depends on several open-sourced projects, first clone or download them to current directory (i.e., `deep_speech_2/deploy`)
The build depends on several open-sourced projects, first clone or download them to current directory (i.e., `deep_speech_2/deploy`)
- [**KenLM**](https://github.com/kpu/kenlm/): Faster and Smaller Language Model Queries
......@@ -14,7 +18,6 @@ wget http://www.openfst.org/twiki/pub/FST/FstDownload/openfst-1.6.3.tar.gz
tar -xzvf openfst-1.6.3.tar.gz
```
- [**SWIG**](http://www.swig.org): Compiling for python interface requires swig, please make sure swig being installed.
- [**ThreadPool**](http://progsch.net/wordpress/): A library for C++ thread pool
......@@ -22,6 +25,8 @@ tar -xzvf openfst-1.6.3.tar.gz
git clone https://github.com/progschj/ThreadPool.git
```
- [**SWIG**](http://www.swig.org): A tool that provides the Python interface for the decoders, please make sure it being installed.
Then run the setup
```shell
......@@ -29,7 +34,9 @@ python setup.py install --num_processes 4
cd ..
```
### Deployment
### Usage
The decoders for deployment share almost the same interface with the prototye decoders in Python. After the installation succeeds, these decoders are very convenient for call in Python, and a complete example in ```deploy.py``` can be refered.
For GPU deployment
......
......@@ -90,26 +90,32 @@ std::vector<std::pair<double, std::string> >
space_id = -2;
}
// init
// init prefixes' root
PathTrie root;
root._score = root._log_prob_b_prev = 0.0;
std::vector<PathTrie*> prefixes;
prefixes.push_back(&root);
if ( ext_scorer != nullptr && !ext_scorer->is_character_based()) {
if (ext_scorer->dictionary == nullptr) {
// TODO: init dictionary
if ( ext_scorer != nullptr) {
if (ext_scorer->is_char_map_empty()) {
ext_scorer->set_char_map(vocabulary);
// add_space should be true?
ext_scorer->fill_dictionary(true);
}
auto fst_dict = static_cast<fst::StdVectorFst*>(ext_scorer->dictionary);
fst::StdVectorFst* dict_ptr = fst_dict->Copy(true);
root.set_dictionary(dict_ptr);
auto matcher = std::make_shared<FSTMATCH>(*dict_ptr, fst::MATCH_INPUT);
root.set_matcher(matcher);
if (!ext_scorer->is_character_based()) {
if (ext_scorer->dictionary == nullptr) {
// fill dictionary for fst
ext_scorer->fill_dictionary(true);
}
auto fst_dict = static_cast<fst::StdVectorFst*>
(ext_scorer->dictionary);
fst::StdVectorFst* dict_ptr = fst_dict->Copy(true);
root.set_dictionary(dict_ptr);
auto matcher = std::make_shared<FSTMATCH>
(*dict_ptr, fst::MATCH_INPUT);
root.set_matcher(matcher);
}
}
// prefix search over time
for (int time_step = 0; time_step < num_time_steps; time_step++) {
std::vector<double> prob = probs_seq[time_step];
std::vector<std::pair<int, double> > prob_idx;
......@@ -147,12 +153,12 @@ std::vector<std::pair<double, std::string> >
prob_idx = std::vector<std::pair<int, double> >( prob_idx.begin(),
prob_idx.begin() + cutoff_len);
}
std::vector<std::pair<int, float> > log_prob_idx;
for (int i = 0; i < cutoff_len; i++) {
log_prob_idx.push_back(std::pair<int, float>
(prob_idx[i].first, log(prob_idx[i].second + NUM_FLT_MIN)));
}
// loop over chars
for (int index = 0; index < log_prob_idx.size(); index++) {
auto c = log_prob_idx[index].first;
......@@ -214,15 +220,14 @@ std::vector<std::pair<double, std::string> >
prefix_new->_log_prob_nb_cur = log_sum_exp(
prefix_new->_log_prob_nb_cur, log_p);
}
}
} // end of loop over prefix
} // end of loop over chars
prefixes.clear();
// update log probs
root.iterate_to_vec(prefixes);
// preserve top beam_size prefixes
// only preserve top beam_size prefixes
if (prefixes.size() >= beam_size) {
std::nth_element(prefixes.begin(),
prefixes.begin() + beam_size,
......@@ -233,7 +238,7 @@ std::vector<std::pair<double, std::string> >
prefixes[i]->remove();
}
}
}
} // end of loop over time
// compute aproximate ctc score as the return score
for (size_t i = 0; i < beam_size && i < prefixes.size(); i++) {
......@@ -300,14 +305,19 @@ std::vector<std::vector<std::pair<double, std::string> > >
ThreadPool pool(num_processes);
// number of samples
int batch_size = probs_split.size();
// dictionary init
if ( ext_scorer != nullptr
&& !ext_scorer->is_character_based()
&& ext_scorer->dictionary == nullptr) {
// init dictionary
ext_scorer->set_char_map(vocabulary);
ext_scorer->fill_dictionary(true);
// scorer filling up
if ( ext_scorer != nullptr) {
if (ext_scorer->is_char_map_empty()) {
ext_scorer->set_char_map(vocabulary);
}
if(!ext_scorer->is_character_based()
&& ext_scorer->dictionary == nullptr) {
// init dictionary
ext_scorer->fill_dictionary(true);
}
}
// enqueue the tasks of decoding
std::vector<std::future<std::vector<std::pair<double, std::string>>>> res;
for (int i = 0; i < batch_size; i++) {
......@@ -317,6 +327,7 @@ std::vector<std::vector<std::pair<double, std::string> > >
cutoff_top_n, ext_scorer)
);
}
// get decoding results
std::vector<std::vector<std::pair<double, std::string> > > batch_results;
for (int i = 0; i < batch_size; i++) {
......
......@@ -27,7 +27,8 @@ std::string ctc_best_path_decoder(std::vector<std::vector<double> > probs_seq,
* beam_size: The width of beam search.
* vocabulary: A vector of vocabulary.
* blank_id: ID of blank.
* cutoff_prob: Cutoff probability of pruning
* cutoff_prob: Cutoff probability for pruning.
* cutoff_top_n: Cutoff number for pruning.
* ext_scorer: External scorer to evaluate a prefix.
* Return:
* A vector that each element is a pair of score and decoding result,
......@@ -54,7 +55,8 @@ std::vector<std::pair<double, std::string> >
* vocabulary: A vector of vocabulary.
* blank_id: ID of blank.
* num_processes: Number of threads for beam search.
* cutoff_prob: Cutoff probability of pruning
* cutoff_prob: Cutoff probability for pruning.
* cutoff_top_n: Cutoff number for pruning.
* ext_scorer: External scorer to evaluate a prefix.
* Return:
* A 2-D vector that each element is a vector of decoding result for one
......
......@@ -11,10 +11,6 @@ size_t get_utf8_str_len(const std::string& str) {
return str_len;
}
//------------------------------------------------------
//Splits string into vector of strings representing
//UTF-8 characters (not same as chars)
//------------------------------------------------------
std::vector<std::string> split_utf8_str(const std::string& str)
{
std::vector<std::string> result;
......@@ -37,9 +33,6 @@ std::vector<std::string> split_utf8_str(const std::string& str)
return result;
}
// Split a string into a list of strings on a given string
// delimiter. NB: delimiters on beginning / end of string are
// trimmed. Eg, "FooBarFoo" split on "Foo" returns ["Bar"].
std::vector<std::string> split_str(const std::string &s,
const std::string &delim) {
std::vector<std::string> result;
......@@ -60,9 +53,6 @@ std::vector<std::string> split_str(const std::string &s,
return result;
}
//-------------------------------------------------------
// Overriding less than operator for sorting
//-------------------------------------------------------
bool prefix_compare(const PathTrie* x, const PathTrie* y) {
if (x->_score == y->_score) {
if (x->_character == y->_character) {
......@@ -73,11 +63,8 @@ bool prefix_compare(const PathTrie* x, const PathTrie* y) {
} else {
return x->_score > y->_score;
}
} //---------- End path_compare ---------------------------
}
// --------------------------------------------------------------
// Adds word to fst without copying entire dictionary
// --------------------------------------------------------------
void add_word_to_fst(const std::vector<int>& word,
fst::StdVectorFst* dictionary) {
if (dictionary->NumStates() == 0) {
......@@ -93,15 +80,12 @@ void add_word_to_fst(const std::vector<int>& word,
src = dst;
}
dictionary->SetFinal(dst, fst::StdArc::Weight::One());
} // ------------ End of add_word_to_fst -----------------------
}
// ---------------------------------------------------------
// Adds a word to the dictionary FST based on char_map
// ---------------------------------------------------------
bool add_word_to_dictionary(const std::string& word,
const std::unordered_map<std::string, int>& char_map,
bool add_space,
int SPACE,
int SPACE_ID,
fst::StdVectorFst* dictionary) {
auto characters = split_utf8_str(word);
......@@ -109,7 +93,7 @@ bool add_word_to_dictionary(const std::string& word,
for (auto& c : characters) {
if (c == " ") {
int_word.push_back(SPACE);
int_word.push_back(SPACE_ID);
} else {
auto int_c = char_map.find(c);
if (int_c != char_map.end()) {
......@@ -121,9 +105,9 @@ bool add_word_to_dictionary(const std::string& word,
}
if (add_space) {
int_word.push_back(SPACE);
int_word.push_back(SPACE_ID);
}
add_word_to_fst(int_word, dictionary);
return true;
} // -------------- End of addWordToDictionary ------------
}
......@@ -7,6 +7,7 @@
const float NUM_FLT_INF = std::numeric_limits<float>::max();
const float NUM_FLT_MIN = std::numeric_limits<float>::min();
// Function template for comparing two pairs
template <typename T1, typename T2>
bool pair_comp_first_rev(const std::pair<T1, T2> &a,
const std::pair<T1, T2> &b)
......@@ -31,7 +32,6 @@ T log_sum_exp(const T &x, const T &y)
return std::log(std::exp(x-xmax) + std::exp(y-xmax)) + xmax;
}
// Functor for prefix comparsion
bool prefix_compare(const PathTrie* x, const PathTrie* y);
......@@ -39,17 +39,24 @@ bool prefix_compare(const PathTrie* x, const PathTrie* y);
// See: http://stackoverflow.com/a/4063229
size_t get_utf8_str_len(const std::string& str);
// Split a string into a list of strings on a given string
// delimiter. NB: delimiters on beginning / end of string are
// trimmed. Eg, "FooBarFoo" split on "Foo" returns ["Bar"].
std::vector<std::string> split_str(const std::string &s,
const std::string &delim);
// Splits string into vector of strings representing
// UTF-8 characters (not same as chars)
std::vector<std::string> split_utf8_str(const std::string &str);
// Add a word in index to the dicionary of fst
void add_word_to_fst(const std::vector<int>& word,
fst::StdVectorFst* dictionary);
// Add a word in string to dictionary
bool add_word_to_dictionary(const std::string& word,
const std::unordered_map<std::string, int>& char_map,
bool add_space,
int SPACE,
int SPACE_ID,
fst::StdVectorFst* dictionary);
#endif // DECODER_UTILS_H
......@@ -86,7 +86,7 @@ PathTrie* PathTrie::get_path_vec(std::vector<int>& output) {
PathTrie* PathTrie::get_path_vec(std::vector<int>& output,
int stop,
size_t max_steps /*= std::numeric_limits<size_t>::max() */) {
size_t max_steps) {
if (_character == stop ||
_character == _ROOT ||
output.size() == max_steps) {
......
......@@ -32,34 +32,48 @@ public:
// Example:
// Scorer scorer(alpha, beta, "path_of_language_model");
// scorer.get_log_cond_prob({ "WORD1", "WORD2", "WORD3" });
// scorer.get_log_cond_prob("this a sentence");
// scorer.get_sent_log_prob({ "WORD1", "WORD2", "WORD3" });
class Scorer{
public:
Scorer(double alpha, double beta, const std::string& lm_path);
~Scorer();
double get_log_cond_prob(const std::vector<std::string>& words);
double get_sent_log_prob(const std::vector<std::string>& words);
size_t get_max_order() { return _max_order; }
bool is_char_map_empty() {return _char_map.size() == 0; }
bool is_character_based() { return _is_character_based; }
// reset params alpha & beta
void reset_params(float alpha, float beta);
// make ngram
std::vector<std::string> make_ngram(PathTrie* prefix);
// fill dictionary for fst
void fill_dictionary(bool add_space);
// set char map
void set_char_map(std::vector<std::string> char_list);
std::vector<std::string> split_labels(const std::vector<int> &labels);
// expose to decoder
double alpha;
double beta;
// fst dictionary
void* dictionary;
protected:
void load_LM(const char* filename);
double get_log_prob(const std::vector<std::string>& words);
std::string vec2str(const std::vector<int> &input);
private:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册