diff --git a/deploy.py b/deploy.py
index 02152b499fc322f9bc8e82315142117da77fec3a..70a9b9efee73762d4e5635a73226a0c5f3d3d84e 100644
--- a/deploy.py
+++ b/deploy.py
@@ -10,8 +10,8 @@ import multiprocessing
 import paddle.v2 as paddle
 from data_utils.data import DataGenerator
 from model import deep_speech2
-from swig_ctc_beam_search_decoder import *
-from swig_scorer import Scorer
+from deploy.swig_decoders import *
+from swig_scorer import LmScorer
 from error_rate import wer
 import utils
 import time
@@ -85,7 +85,7 @@ parser.add_argument(
     help="Number of output per sample in beam search. (default: %(default)d)")
 parser.add_argument(
     "--language_model_path",
-    default="lm/data/en.00.UNKNOWN.klm",
+    default="lm/data/common_crawl_00.prune01111.trie.klm",
     type=str,
     help="Path for language model. (default: %(default)s)")
 parser.add_argument(
@@ -164,19 +164,19 @@ def infer():
     ]
 
     # external scorer
-    ext_scorer = Scorer(args.alpha, args.beta, args.language_model_path)
+    ext_scorer = LmScorer(args.alpha, args.beta, args.language_model_path)
 
     ## decode and print
     time_begin = time.time()
     wer_sum, wer_counter = 0, 0
     for i, probs in enumerate(probs_split):
         beam_result = ctc_beam_search_decoder(
-            probs.tolist(),
-            args.beam_size,
-            data_generator.vocab_list,
-            len(data_generator.vocab_list),
-            args.cutoff_prob,
-            ext_scorer, )
+            probs_seq=probs,
+            beam_size=args.beam_size,
+            vocabulary=data_generator.vocab_list,
+            blank_id=len(data_generator.vocab_list),
+            cutoff_prob=args.cutoff_prob,
+            ext_scoring_func=ext_scorer, )
 
         print("\nTarget Transcription:\t%s" % target_transcription[i])
         print("Beam %d: %f \t%s" % (0, beam_result[0][0], beam_result[0][1]))
diff --git a/deploy/ctc_beam_search_decoder.cpp b/deploy/ctc_decoders.cpp
similarity index 94%
rename from deploy/ctc_beam_search_decoder.cpp
rename to deploy/ctc_decoders.cpp
index af6414a97c05307afada28da7e7df1f239f8e631..4cff6d5e544ce04583935331474776b1714f9ad2 100644
--- a/deploy/ctc_beam_search_decoder.cpp
+++ b/deploy/ctc_decoders.cpp
@@ -4,9 +4,9 @@
 #include <utility>
 #include <cmath>
 #include <limits>
-#include "ctc_beam_search_decoder.h"
+#include "ctc_decoders.h"
 
-typedef float log_prob_type;
+typedef double log_prob_type;
 
 template <typename T>
 bool pair_comp_first_rev(const std::pair<T, std::string> a, const std::pair<T, std::string> b)
@@ -24,8 +24,8 @@ template <typename T>
 T log_sum_exp(T x, T y)
 {
     static T num_min = -std::numeric_limits<T>::max();
-    if (x <= -num_min) return y;
-    if (y <= -num_min) return x;
+    if (x <= num_min) return y;
+    if (y <= num_min) return x;
     T xmax = std::max(x, y);
     return std::log(std::exp(x-xmax) + std::exp(y-xmax)) + xmax;
 }
@@ -55,17 +55,13 @@ std::string ctc_best_path_decoder(std::vector<std::vector<double> > probs_seq,
             if (max_prob < probs_seq[i][j]) {
                 max_idx = j;
                 max_prob = probs_seq[i][j];
             }
         }
         max_idx_vec.push_back(max_idx);
-        std::cout << max_idx << ",";
     }
-    std::cout << std::endl;
     std::vector<int> idx_vec;
     for (int i=0; i<max_idx_vec.size(); i++) {
         if ((i == 0) || ((i > 0) && max_idx_vec[i] != max_idx_vec[i-1])) {
-            std::cout << max_idx_vec[i] << ",";
             idx_vec.push_back(max_idx_vec[i]);
         }
     }
-    std::cout << std::endl;
@@ -80,6 +76,7 @@ std::string ctc_best_path_decoder(std::vector<std::vector<double> > probs_seq,
     std::string best_path_result;
     for (int i=0; i<idx_vec.size(); i++) {
         if (idx_vec[i] != vocabulary.size()) {
             best_path_result += vocabulary[idx_vec[i]];
         }
     }
+    return best_path_result;
@@ -272,7 +264,7 @@ std::vector<std::pair<double, std::string> >
             std::vector<std::string> vocabulary,
             int blank_id,
             double cutoff_prob,
-            Scorer *ext_scorer,
+            LmScorer *ext_scorer,
             bool nproc) {
     // dimension check
     int num_time_steps = probs_seq.size();
@@ -283,14 +275,12 @@ std::vector<std::pair<double, std::string> >
     // blank_id check
     if (blank_id > vocabulary.size()) {
-        std::cout << "Invalid blank_id!" << std::endl;
         exit(1);
     }
 
     // assign space ID
     std::vector<std::string>::iterator it = std::find(vocabulary.begin(),
                                                       vocabulary.end(), " ");
     int space_id = it - vocabulary.begin();
     if (space_id >= vocabulary.size()) {
-        std::cout << "The character space is not in the vocabulary!" << std::endl;
         exit(1);
     }
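Note on the log_sum_exp fix above: with num_min set to -std::numeric_limits<T>::max(), the old comparison x <= -num_min tested against +max and was therefore always true, so the function always returned y. The corrected sentinel test maps to Python as follows (a minimal sketch, not part of the patch; NUM_MIN stands in for the C++ sentinel):

    import math

    NUM_MIN = -float("inf")  # stand-in for -std::numeric_limits<T>::max()

    def log_sum_exp(x, y):
        # treat NUM_MIN as log(0): adding probability 0 just returns the other term
        if x <= NUM_MIN:
            return y
        if y <= NUM_MIN:
            return x
        xmax = max(x, y)  # factor out the max for numerical stability
        return math.log(math.exp(x - xmax) + math.exp(y - xmax)) + xmax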
diff --git a/deploy/ctc_beam_search_decoder.h b/deploy/ctc_decoders.h
rename from deploy/ctc_beam_search_decoder.h
rename to deploy/ctc_decoders.h
--- a/deploy/ctc_beam_search_decoder.h
+++ b/deploy/ctc_decoders.h
@@ -29,9 +29,18 @@
         std::vector<std::string> vocabulary,
         int blank_id,
         double cutoff_prob=1.0,
-        Scorer *ext_scorer=NULL,
+        LmScorer *ext_scorer=NULL,
         bool nproc=false
     );
 
+/* CTC Best Path Decoder
+ *
+ * Parameters:
+ *     probs_seq: 2-D vector in which each element is a vector of
+ *                probabilities over the vocabulary of one time step.
+ *     vocabulary: A vector of vocabulary.
+ * Return:
+ *     The decoding result in string format.
+ */
 std::string ctc_best_path_decoder(std::vector<std::vector<double> > probs_seq,
                                   std::vector<std::string> vocabulary);
diff --git a/deploy/ctc_beam_search_decoder.i b/deploy/ctc_decoders.i
similarity index 84%
rename from deploy/ctc_beam_search_decoder.i
rename to deploy/ctc_decoders.i
index 09e893d38e415fbed7884c39f33e5230058b76a5..c7d05238e5b0849b4db6217084ac1352fa919f82 100644
--- a/deploy/ctc_beam_search_decoder.i
+++ b/deploy/ctc_decoders.i
@@ -1,6 +1,6 @@
-%module swig_ctc_beam_search_decoder
+%module swig_ctc_decoders
 %{
-#include "ctc_beam_search_decoder.h"
+#include "ctc_decoders.h"
 %}
 
 %include "std_vector.i"
@@ -19,4 +19,4 @@ namespace std{
 }
 
 %import scorer.h
-%include "ctc_beam_search_decoder.h"
+%include "ctc_decoders.h"
diff --git a/deploy/decoder_setup.py b/deploy/decoder_setup.py
index 4ed603b25239f1a19611e9ecb347c80fc58568a8..aed45faafc3f7ec84ec3f4949cbc413d36ef2a6d 100644
--- a/deploy/decoder_setup.py
+++ b/deploy/decoder_setup.py
@@ -34,15 +34,13 @@ if compile_test('lzma.h', 'lzma'):
     ARGS.append('-DHAVE_XZLIB')
     LIBS.append('lzma')
 
-os.system('swig -python -c++ ./ctc_beam_search_decoder.i')
+os.system('swig -python -c++ ./ctc_decoders.i')
 
 ctc_beam_search_decoder_module = [
     Extension(
-        name='_swig_ctc_beam_search_decoder',
-        sources=FILES + [
-            'scorer.cpp', 'ctc_beam_search_decoder_wrap.cxx',
-            'ctc_beam_search_decoder.cpp'
-        ],
+        name='_swig_ctc_decoders',
+        sources=FILES +
+        ['scorer.cpp', 'ctc_decoders_wrap.cxx', 'ctc_decoders.cpp'],
         language='C++',
         include_dirs=['.', './kenlm'],
         libraries=LIBS,
@@ -50,8 +48,8 @@ ctc_beam_search_decoder_module = [
 ]
 
 setup(
-    name='swig_ctc_beam_search_decoder',
+    name='swig_ctc_decoders',
     version='0.1',
-    description="""CTC beam search decoder""",
+    description="""CTC decoders""",
     ext_modules=ctc_beam_search_decoder_module,
-    py_modules=['swig_ctc_beam_search_decoder'], )
+    py_modules=['swig_ctc_decoders'], )
diff --git a/deploy/scorer.cpp b/deploy/scorer.cpp
index e9a74b989a589ed920ec69b7e5bd5cb15f4a6d11..7a66daad9c3c0783dabf7cd72f6d68b91d621555 100644
--- a/deploy/scorer.cpp
+++ b/deploy/scorer.cpp
@@ -7,7 +7,7 @@
 
 using namespace lm::ngram;
 
-Scorer::Scorer(float alpha, float beta, std::string lm_model_path) {
+LmScorer::LmScorer(float alpha, float beta, std::string lm_model_path) {
     this->_alpha = alpha;
     this->_beta = beta;
 
@@ -18,7 +18,7 @@ Scorer::Scorer(float alpha, float beta, std::string lm_model_path) {
     this->_language_model = LoadVirtual(lm_model_path.c_str());
 }
 
-Scorer::~Scorer(){
+LmScorer::~LmScorer(){
     delete (lm::base::Model *)this->_language_model;
 }
 
@@ -57,7 +57,7 @@ inline void strip(std::string &str, char ch=' ') {
     }
 }
 
-int Scorer::word_count(std::string sentence) {
+int LmScorer::word_count(std::string sentence) {
     strip(sentence);
     int cnt = 1;
     for (int i=0; i<sentence.size(); i++) {
@@ -68,8 +68,8 @@ int Scorer::word_count(std::string sentence) {
     return cnt;
 }
 
-double Scorer::language_model_score(std::string sentence) {
+double LmScorer::language_model_score(std::string sentence) {
     lm::base::Model *model = (lm::base::Model *)this->_language_model;
     State state, out_state;
     lm::FullScoreReturn ret;
@@ -84,12 +84,12 @@ double Scorer::language_model_score(std::string sentence) {
     return log_prob;
 }
 
-void Scorer::reset_params(float alpha, float beta) {
+void LmScorer::reset_params(float alpha, float beta) {
     this->_alpha = alpha;
     this->_beta = beta;
 }
 
-double Scorer::get_score(std::string sentence, bool log) {
+double LmScorer::get_score(std::string sentence, bool log) {
     double lm_score = language_model_score(sentence);
     int word_cnt = word_count(sentence);
 
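For orientation, get_score combines the two signals roughly as follows (a hypothetical Python sketch inferred from the visible language_model_score and word_count calls, not the literal C++ body):

    import math

    def get_score(lm_log10_prob, word_cnt, alpha, beta, log=True):
        # weighted LM score plus a weighted word-count bonus;
        # the two branches agree up to exponentiation
        if log:
            return alpha * lm_log10_prob * math.log(10) + beta * math.log(word_cnt)
        return (10.0 ** (alpha * lm_log10_prob)) * (word_cnt ** beta)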
diff --git a/deploy/scorer.h b/deploy/scorer.h
index a18e119bcf155d2380972c8560d588d0ccf43efc..90a1a84a0a06314e6457b48344ea7487af41dd11 100644
--- a/deploy/scorer.h
+++ b/deploy/scorer.h
@@ -8,10 +8,10 @@
  * count and language model scoring.
  *
  * Example:
- *     Scorer ext_scorer(alpha, beta, "path_to_language_model.klm");
+ *     LmScorer ext_scorer(alpha, beta, "path_to_language_model.klm");
  *     double score = ext_scorer.get_score("sentence_to_score");
  */
-class Scorer{
+class LmScorer{
 private:
     float _alpha;
     float _beta;
@@ -23,9 +23,9 @@ private:
     double language_model_score(std::string);
 
 public:
-    Scorer(){}
-    Scorer(float alpha, float beta, std::string lm_model_path);
-    ~Scorer();
+    LmScorer(){}
+    LmScorer(float alpha, float beta, std::string lm_model_path);
+    ~LmScorer();
 
     // reset params alpha & beta
     void reset_params(float alpha, float beta);
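The header's usage example translates directly to the Python side of the SWIG binding; a minimal sketch, with the model path and the alpha/beta values as placeholders:

    from swig_scorer import LmScorer

    # construct once; loading the KenLM model is the expensive part
    ext_scorer = LmScorer(0.26, 0.1, "path_to_language_model.klm")
    score = ext_scorer.get_score("sentence_to_score")
    # retune alpha & beta later without reloading the model
    ext_scorer.reset_params(0.3, 0.2)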
diff --git a/deploy/swig_decoders.py b/deploy/swig_decoders.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e4a39252b2012d39ac500f76471335e7764f291
--- /dev/null
+++ b/deploy/swig_decoders.py
@@ -0,0 +1,86 @@
+"""Wrapper for various CTC decoders in SWIG."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import swig_ctc_decoders
+import multiprocessing
+
+
+def ctc_best_path_decoder(probs_seq, vocabulary):
+    """Wrapper for the CTC best path decoder in SWIG.
+
+    :param probs_seq: 2-D list of probability distributions over each time
+                      step, with each element being a list of normalized
+                      probabilities over vocabulary and blank.
+    :type probs_seq: 2-D list
+    :param vocabulary: Vocabulary list.
+    :type vocabulary: list
+    :return: Decoding result string.
+    :rtype: basestring
+    """
+    return swig_ctc_decoders.ctc_best_path_decoder(probs_seq.tolist(),
+                                                   vocabulary)
+
+
+def ctc_beam_search_decoder(
+        probs_seq,
+        beam_size,
+        vocabulary,
+        blank_id,
+        cutoff_prob=1.0,
+        ext_scoring_func=None, ):
+    """Wrapper for the CTC beam search decoder in SWIG.
+
+    :param probs_seq: 2-D list of probability distributions over each time
+                      step, with each element being a list of normalized
+                      probabilities over vocabulary and blank.
+    :type probs_seq: 2-D list
+    :param beam_size: Width for beam search.
+    :type beam_size: int
+    :param vocabulary: Vocabulary list.
+    :type vocabulary: list
+    :param blank_id: ID of blank.
+    :type blank_id: int
+    :param cutoff_prob: Cutoff probability in pruning,
+                        default 1.0, no pruning.
+    :type cutoff_prob: float
+    :param ext_scoring_func: External scoring function for
+                             partially decoded sentence, e.g. word count
+                             or language model.
+    :type ext_scoring_func: callable
+    :return: List of tuples of log probability and sentence as decoding
+             results, in descending order of the probability.
+    :rtype: list
+    """
+    return swig_ctc_decoders.ctc_beam_search_decoder(
+        probs_seq.tolist(), beam_size, vocabulary, blank_id, cutoff_prob,
+        ext_scoring_func)
+
+
+def ctc_beam_search_decoder_batch(probs_split,
+                                  beam_size,
+                                  vocabulary,
+                                  blank_id,
+                                  num_processes,
+                                  cutoff_prob=1.0,
+                                  ext_scoring_func=None):
+    """Wrapper for the CTC beam search decoder in batch mode, decoding
+    multiple utterances in parallel with a process pool.
+    """
+    # TODO: resolve PicklingError
+
+    if not num_processes > 0:
+        raise ValueError("Number of processes must be positive!")
+
+    pool = multiprocessing.Pool(processes=num_processes)
+    results = []
+    for i, probs_list in enumerate(probs_split):
+        args = (probs_list, beam_size, vocabulary, blank_id, cutoff_prob,
+                ext_scoring_func)
+        results.append(pool.apply_async(ctc_beam_search_decoder, args))
+
+    pool.close()
+    pool.join()
+    beam_search_results = [result.get() for result in results]
+    return beam_search_results
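A quick smoke test of the new wrappers could look like this (toy probabilities and shapes, and ext_scoring_func left unset to sidestep the PicklingError mentioned in the TODO):

    import numpy as np
    from deploy.swig_decoders import (ctc_best_path_decoder,
                                      ctc_beam_search_decoder_batch)

    # 3 time steps over a 2-character vocabulary plus blank (last column)
    probs = np.array([[0.6, 0.3, 0.1],
                      [0.2, 0.7, 0.1],
                      [0.1, 0.2, 0.7]])
    vocab = ["a", "b"]

    print(ctc_best_path_decoder(probs, vocab))
    results = ctc_beam_search_decoder_batch(
        probs_split=[probs, probs],
        beam_size=4,
        vocabulary=vocab,
        blank_id=len(vocab),
        num_processes=2)
    # each entry is a list of (log probability, sentence) tuples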