提交 dad406a4 编写于 作者: Y Yibing Liu

add the support of parallel beam search decoding in deployment

上级 d1189a79
......@@ -18,7 +18,7 @@ import time
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--num_samples",
default=10,
default=32,
type=int,
help="Number of samples for inference. (default: %(default)s)")
parser.add_argument(
......@@ -46,6 +46,11 @@ parser.add_argument(
default=multiprocessing.cpu_count(),
type=int,
help="Number of cpu threads for preprocessing data. (default: %(default)s)")
parser.add_argument(
"--num_processes_beam_search",
default=multiprocessing.cpu_count(),
type=int,
help="Number of cpu processes for beam search. (default: %(default)s)")
parser.add_argument(
"--mean_std_filepath",
default='mean_std.npz',
......@@ -70,8 +75,8 @@ parser.add_argument(
"--decode_method",
default='beam_search',
type=str,
help="Method for ctc decoding: best_path or beam_search. (default: %(default)s)"
)
help="Method for ctc decoding: beam_search or beam_search_batch. "
"(default: %(default)s)")
parser.add_argument(
"--beam_size",
default=200,
......@@ -169,15 +174,28 @@ def infer():
## decode and print
time_begin = time.time()
wer_sum, wer_counter = 0, 0
for i, probs in enumerate(probs_split):
beam_result = ctc_beam_search_decoder(
probs_seq=probs,
batch_beam_results = []
if args.decode_method == 'beam_search':
for i, probs in enumerate(probs_split):
beam_result = ctc_beam_search_decoder(
probs_seq=probs,
beam_size=args.beam_size,
vocabulary=data_generator.vocab_list,
blank_id=len(data_generator.vocab_list),
cutoff_prob=args.cutoff_prob,
ext_scoring_func=ext_scorer, )
batch_beam_results += [beam_result]
else:
batch_beam_results = ctc_beam_search_decoder_batch(
probs_split=probs_split,
beam_size=args.beam_size,
vocabulary=data_generator.vocab_list,
blank_id=len(data_generator.vocab_list),
num_processes=args.num_processes_beam_search,
cutoff_prob=args.cutoff_prob,
ext_scoring_func=ext_scorer, )
for i, beam_result in enumerate(batch_beam_results):
print("\nTarget Transcription:\t%s" % target_transcription[i])
print("Beam %d: %f \t%s" % (0, beam_result[0][0], beam_result[0][1]))
wer_cur = wer(target_transcription[i], beam_result[0][1])
......@@ -185,6 +203,7 @@ def infer():
wer_counter += 1
print("cur wer = %f , average wer = %f" %
(wer_cur, wer_sum / wer_counter))
time_end = time.time()
print("total time = %f" % (time_end - time_begin))
......
### Installation
The setup of the decoder for deployment depends on the source code of [kenlm](https://github.com/kpu/kenlm/) and [openfst](http://www.openfst.org/twiki/bin/view/FST/WebHome), first clone kenlm and download openfst to current directory (i.e., `deep_speech_2/deploy`)
The build of the decoder for deployment depends on several open-sourced projects, first clone or download them to current directory (i.e., `deep_speech_2/deploy`)
- [**KenLM**](https://github.com/kpu/kenlm/): Faster and Smaller Language Model Queries
```shell
git clone https://github.com/kpu/kenlm.git
```
- [**OpenFst**](http://www.openfst.org/twiki/bin/view/FST/WebHome): A library for finite-state transducers
```shell
wget http://www.openfst.org/twiki/pub/FST/FstDownload/openfst-1.6.3.tar.gz
tar -xzvf openfst-1.6.3.tar.gz
```
- [**ThreadPool**](http://progsch.net/wordpress/): A library for C++ thread pool
```shell
git clone https://github.com/progschj/ThreadPool.git
```
Then run the setup
```shell
......
......@@ -6,6 +6,7 @@
#include <limits>
#include "ctc_decoders.h"
#include "decoder_utils.h"
#include "ThreadPool.h"
typedef double log_prob_type;
......@@ -33,7 +34,8 @@ T log_sum_exp(T x, T y)
}
std::string ctc_best_path_decoder(std::vector<std::vector<double> > probs_seq,
std::vector<std::string> vocabulary) {
std::vector<std::string> vocabulary)
{
// dimension check
int num_time_steps = probs_seq.size();
for (int i=0; i<num_time_steps; i++) {
......@@ -83,8 +85,8 @@ std::vector<std::pair<double, std::string> >
std::vector<std::string> vocabulary,
int blank_id,
double cutoff_prob,
Scorer *ext_scorer,
bool nproc) {
Scorer *ext_scorer)
{
// dimension check
int num_time_steps = probs_seq.size();
for (int i=0; i<num_time_steps; i++) {
......@@ -260,3 +262,39 @@ std::vector<std::pair<double, std::string> >
pair_comp_first_rev<double, std::string>);
return beam_result;
}
/* CTC Beam Search Decoder for batch data.
 *
 * Decodes every sample in the batch in parallel: one task per sample is
 * enqueued on a thread pool, and each task runs ctc_beam_search_decoder().
 *
 * Parameters:
 *     probs_split: 3-D vector; each element is one sample's 2-D probability
 *                  sequence, consumable by ctc_beam_search_decoder().
 *     beam_size: The width of beam search.
 *     vocabulary: A vector of vocabulary.
 *     blank_id: ID of blank.
 *     num_processes: Number of threads in the pool; must be positive.
 *     cutoff_prob: Cutoff probability of pruning.
 *     ext_scorer: External scorer to evaluate a prefix.
 * Return:
 *     A 2-D vector; each element is the vector of (score, transcript)
 *     decoding results for one sample.
 */
std::vector<std::vector<std::pair<double, std::string>>>
    ctc_beam_search_decoder_batch(
        std::vector<std::vector<std::vector<double>>> probs_split,
        int beam_size,
        std::vector<std::string> vocabulary,
        int blank_id,
        int num_processes,
        double cutoff_prob,
        Scorer *ext_scorer
        )
{
    // The guard requires a strictly positive thread count, so the message
    // must say "positive" (the old text said "nonnegative", contradicting
    // the condition); errors go to stderr, not stdout.
    if (num_processes <= 0) {
        std::cerr << "num_processes must be positive!" << std::endl;
        exit(1);
    }
    // thread pool shared by all decoding tasks
    ThreadPool pool(num_processes);
    // number of samples
    int batch_size = probs_split.size();
    // enqueue one decoding task per sample
    std::vector<std::future<std::vector<std::pair<double, std::string>>>> res;
    for (int i = 0; i < batch_size; i++) {
        res.emplace_back(
            pool.enqueue(ctc_beam_search_decoder, probs_split[i],
                beam_size, vocabulary, blank_id, cutoff_prob, ext_scorer)
        );
    }
    // collect decoding results; future::get() blocks until each task is done
    std::vector<std::vector<std::pair<double, std::string>>> batch_results;
    for (int i = 0; i < batch_size; i++) {
        batch_results.emplace_back(res[i].get());
    }
    return batch_results;
}
......@@ -6,8 +6,20 @@
#include <utility>
#include "scorer.h"
/* CTC Beam Search Decoder, the interface is consistent with the
* original decoder in Python version.
/* CTC Best Path Decoder
*
* Parameters:
* probs_seq: 2-D vector that each element is a vector of probabilities
* over vocabulary of one time step.
* vocabulary: A vector of vocabulary.
 * Return:
 *     The decoding result in string format.
*/
std::string ctc_best_path_decoder(std::vector<std::vector<double> > probs_seq,
std::vector<std::string> vocabulary);
/* CTC Beam Search Decoder
* Parameters:
* probs_seq: 2-D vector that each element is a vector of probabilities
......@@ -17,7 +29,6 @@
* blank_id: ID of blank.
* cutoff_prob: Cutoff probability of pruning
* ext_scorer: External scorer to evaluate a prefix.
* nproc: Whether this function used in multiprocessing.
* Return:
* A vector that each element is a pair of score and decoding result,
 * in descending order.
......@@ -28,21 +39,35 @@ std::vector<std::pair<double, std::string> >
std::vector<std::string> vocabulary,
int blank_id,
double cutoff_prob=1.0,
Scorer *ext_scorer=NULL,
bool nproc=false
Scorer *ext_scorer=NULL
);
/* CTC Best Path Decoder
*
/* CTC Beam Search Decoder for batch data, the interface is consistent with the
* original decoder in Python version.
* Parameters:
* probs_seq: 2-D vector that each element is a vector of probabilities
* over vocabulary of one time step.
 * probs_split: 3-D vector that each element is a 2-D vector that can be used
 * by ctc_beam_search_decoder().
* beam_size: The width of beam search.
* vocabulary: A vector of vocabulary.
* blank_id: ID of blank.
* num_processes: Number of threads for beam search.
* cutoff_prob: Cutoff probability of pruning
* ext_scorer: External scorer to evaluate a prefix.
* Return:
* A vector that each element is a pair of score and decoding result,
 * in descending order.
*/
std::string ctc_best_path_decoder(std::vector<std::vector<double> > probs_seq,
std::vector<std::string> vocabulary);
* A 2-D vector that each element is a vector of decoding result for one
* sample.
*/
std::vector<std::vector<std::pair<double, std::string>>>
ctc_beam_search_decoder_batch(std::vector<std::vector<std::vector<double>>> probs_split,
int beam_size,
std::vector<std::string> vocabulary,
int blank_id,
int num_processes,
double cutoff_prob=1.0,
Scorer *ext_scorer=NULL
);
#endif // CTC_BEAM_SEARCH_DECODER_H_
......@@ -17,6 +17,8 @@ namespace std{
%template(Pair) std::pair<float, std::string>;
%template(PairFloatStringVector) std::vector<std::pair<float, std::string> >;
%template(PairDoubleStringVector) std::vector<std::pair<double, std::string> >;
%template(PairDoubleStringVector2) std::vector<std::vector<std::pair<double, std::string> > >;
%template(DoubleVector3) std::vector<std::vector<std::vector<double> > >;
}
%import decoder_utils.h
......
......@@ -36,12 +36,12 @@ if compile_test('lzma.h', 'lzma'):
os.system('swig -python -c++ ./decoders.i')
ctc_beam_search_decoder_module = [
decoders_module = [
Extension(
name='_swig_decoders',
sources=FILES + glob.glob('*.cxx') + glob.glob('*.cpp'),
language='C++',
include_dirs=['.', './kenlm', './openfst-1.6.3/src/include'],
include_dirs=['.', 'kenlm', 'openfst-1.6.3/src/include', 'ThreadPool'],
libraries=LIBS,
extra_compile_args=ARGS)
]
......@@ -50,5 +50,5 @@ setup(
name='swig_decoders',
version='0.1',
description="""CTC decoders""",
ext_modules=ctc_beam_search_decoder_module,
ext_modules=decoders_module,
py_modules=['swig_decoders'], )
......@@ -4,7 +4,6 @@ from __future__ import division
from __future__ import print_function
import swig_decoders
import multiprocessing
class Scorer(swig_decoders.Scorer):
......@@ -39,14 +38,13 @@ def ctc_best_path_decoder(probs_seq, vocabulary):
return swig_decoders.ctc_best_path_decoder(probs_seq.tolist(), vocabulary)
def ctc_beam_search_decoder(
probs_seq,
beam_size,
vocabulary,
blank_id,
cutoff_prob=1.0,
ext_scoring_func=None, ):
"""Wrapper for CTC Beam Search Decoder.
def ctc_beam_search_decoder(probs_seq,
beam_size,
vocabulary,
blank_id,
cutoff_prob=1.0,
ext_scoring_func=None):
"""Wrapper for the CTC Beam Search Decoder.
:param probs_seq: 2-D list of probability distributions over each time
step, with each element being a list of normalized
......@@ -81,24 +79,34 @@ def ctc_beam_search_decoder_batch(probs_split,
num_processes,
cutoff_prob=1.0,
ext_scoring_func=None):
"""Wrapper for CTC beam search decoder in batch
"""
# TODO: to resolve PicklingError
if not num_processes > 0:
raise ValueError("Number of processes must be positive!")
"""Wrapper for the batched CTC beam search decoder.
pool = Pool(processes=num_processes)
results = []
args_list = []
for i, probs_list in enumerate(probs_split):
args = (probs_list, beam_size, vocabulary, blank_id, cutoff_prob,
ext_scoring_func)
args_list.append(args)
results.append(pool.apply_async(ctc_beam_search_decoder, args))
:param probs_split: 3-D list with each element as an instance of 2-D list
of probabilities used by ctc_beam_search_decoder().
:type probs_split: 3-D list
:param beam_size: Width for beam search.
:type beam_size: int
:param vocabulary: Vocabulary list.
:type vocabulary: list
:param blank_id: ID of blank.
:type blank_id: int
:param num_processes: Number of parallel processes.
:type num_processes: int
:param cutoff_prob: Cutoff probability in pruning,
default 1.0, no pruning.
:type cutoff_prob: float
:param ext_scoring_func: External scoring function for
partially decoded sentence, e.g. word count
or language model.
:type ext_scoring_func: callable
:return: List of tuples of log probability and sentence as decoding
results, in descending order of the probability.
:rtype: list
"""
probs_split = [probs_seq.tolist() for probs_seq in probs_split]
pool.close()
pool.join()
beam_search_results = [result.get() for result in results]
return beam_search_results
return swig_decoders.ctc_beam_search_decoder_batch(
probs_split, beam_size, vocabulary, blank_id, num_processes,
cutoff_prob, ext_scoring_func)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册