提交 94a68116 编写于 作者: Y Yibing Liu

code cleanup for the deployment decoder

上级 2e76f82c
...@@ -6,35 +6,47 @@ ...@@ -6,35 +6,47 @@
#include "ctc_beam_search_decoder.h" #include "ctc_beam_search_decoder.h"
template <typename T1, typename T2> template <typename T1, typename T2>
bool pair_comp_first_rev(const std::pair<T1, T2> a, const std::pair<T1, T2> b) { bool pair_comp_first_rev(const std::pair<T1, T2> a, const std::pair<T1, T2> b)
{
return a.first > b.first; return a.first > b.first;
} }
template <typename T1, typename T2> template <typename T1, typename T2>
bool pair_comp_second_rev(const std::pair<T1, T2> a, const std::pair<T1, T2> b) { bool pair_comp_second_rev(const std::pair<T1, T2> a, const std::pair<T1, T2> b)
{
return a.second > b.second; return a.second > b.second;
} }
/* CTC beam search decoder in C++, the interface is consistent with the original
decoder in Python version.
*/
std::vector<std::pair<double, std::string> > std::vector<std::pair<double, std::string> >
ctc_beam_search_decoder(std::vector<std::vector<double> > probs_seq, ctc_beam_search_decoder(std::vector<std::vector<double> > probs_seq,
int beam_size, int beam_size,
std::vector<std::string> vocabulary, std::vector<std::string> vocabulary,
int blank_id, int blank_id,
double cutoff_prob, double cutoff_prob,
Scorer *ext_scorer, Scorer *ext_scorer,
bool nproc bool nproc) {
) // dimension check
{
int num_time_steps = probs_seq.size(); int num_time_steps = probs_seq.size();
for (int i=0; i<num_time_steps; i++) {
if (probs_seq[i].size() != vocabulary.size()+1) {
std::cout<<"The shape of probs_seq does not match"
<<" with the shape of the vocabulary!"<<std::endl;
exit(1);
}
}
// blank_id check
if (blank_id > vocabulary.size()) {
std::cout<<"Invalid blank_id!"<<std::endl;
exit(1);
}
// assign space ID // assign space ID
std::vector<std::string>::iterator it = std::find(vocabulary.begin(), vocabulary.end(), " "); std::vector<std::string>::iterator it = std::find(vocabulary.begin(),
int space_id = it-vocabulary.begin(); vocabulary.end(), " ");
int space_id = it - vocabulary.begin();
if(space_id >= vocabulary.size()) { if(space_id >= vocabulary.size()) {
std::cout<<"The character space is not in the vocabulary!"; std::cout<<"The character space is not in the vocabulary!"<<std::endl;
exit(1); exit(1);
} }
...@@ -60,7 +72,8 @@ std::vector<std::pair<double, std::string> > ...@@ -60,7 +72,8 @@ std::vector<std::pair<double, std::string> >
} }
// pruning of vocabulary // pruning of vocabulary
if (cutoff_prob < 1.0) { if (cutoff_prob < 1.0) {
std::sort(prob_idx.begin(), prob_idx.end(), pair_comp_second_rev<int, double>); std::sort(prob_idx.begin(), prob_idx.end(),
pair_comp_second_rev<int, double>);
float cum_prob = 0.0; float cum_prob = 0.0;
int cutoff_len = 0; int cutoff_len = 0;
for (int i=0; i<prob_idx.size(); i++) { for (int i=0; i<prob_idx.size(); i++) {
...@@ -68,7 +81,8 @@ std::vector<std::pair<double, std::string> > ...@@ -68,7 +81,8 @@ std::vector<std::pair<double, std::string> >
cutoff_len += 1; cutoff_len += 1;
if (cum_prob >= cutoff_prob) break; if (cum_prob >= cutoff_prob) break;
} }
prob_idx = std::vector<std::pair<int, double> >(prob_idx.begin(), prob_idx.begin()+cutoff_len); prob_idx = std::vector<std::pair<int, double> >( prob_idx.begin(),
prob_idx.begin() + cutoff_len);
} }
// extend prefix // extend prefix
for (std::map<std::string, double>::iterator it = prefix_set_prev.begin(); for (std::map<std::string, double>::iterator it = prefix_set_prev.begin();
...@@ -82,11 +96,11 @@ std::vector<std::pair<double, std::string> > ...@@ -82,11 +96,11 @@ std::vector<std::pair<double, std::string> >
int c = prob_idx[index].first; int c = prob_idx[index].first;
double prob_c = prob_idx[index].second; double prob_c = prob_idx[index].second;
if (c == blank_id) { if (c == blank_id) {
probs_b_cur[l] += prob_c*(probs_b_prev[l]+probs_nb_prev[l]); probs_b_cur[l] += prob_c * (probs_b_prev[l] + probs_nb_prev[l]);
} else { } else {
std::string last_char = l.substr(l.size()-1, 1); std::string last_char = l.substr(l.size()-1, 1);
std::string new_char = vocabulary[c]; std::string new_char = vocabulary[c];
std::string l_plus = l+new_char; std::string l_plus = l + new_char;
if( prefix_set_next.find(l_plus) == prefix_set_next.end()) { if( prefix_set_next.find(l_plus) == prefix_set_next.end()) {
probs_b_cur[l_plus] = probs_nb_cur[l_plus] = 0.0; probs_b_cur[l_plus] = probs_nb_cur[l_plus] = 0.0;
...@@ -105,19 +119,22 @@ std::vector<std::pair<double, std::string> > ...@@ -105,19 +119,22 @@ std::vector<std::pair<double, std::string> >
probs_nb_cur[l_plus] += prob_c * ( probs_nb_cur[l_plus] += prob_c * (
probs_b_prev[l] + probs_nb_prev[l]); probs_b_prev[l] + probs_nb_prev[l]);
} }
prefix_set_next[l_plus] = probs_nb_cur[l_plus]+probs_b_cur[l_plus]; prefix_set_next[l_plus] = probs_nb_cur[l_plus] + probs_b_cur[l_plus];
} }
} }
prefix_set_next[l] = probs_b_cur[l]+probs_nb_cur[l]; prefix_set_next[l] = probs_b_cur[l] + probs_nb_cur[l];
} }
probs_b_prev = probs_b_cur; probs_b_prev = probs_b_cur;
probs_nb_prev = probs_nb_cur; probs_nb_prev = probs_nb_cur;
std::vector<std::pair<std::string, double> > std::vector<std::pair<std::string, double> >
prefix_vec_next(prefix_set_next.begin(), prefix_set_next.end()); prefix_vec_next(prefix_set_next.begin(),
std::sort(prefix_vec_next.begin(), prefix_vec_next.end(), pair_comp_second_rev<std::string, double>); prefix_set_next.end());
int k = beam_size<prefix_vec_next.size() ? beam_size : prefix_vec_next.size(); std::sort(prefix_vec_next.begin(),
prefix_vec_next.end(),
pair_comp_second_rev<std::string, double>);
int k = beam_size<prefix_vec_next.size() ? beam_size:prefix_vec_next.size();
prefix_set_prev = std::map<std::string, double> prefix_set_prev = std::map<std::string, double>
(prefix_vec_next.begin(), prefix_vec_next.begin()+k); (prefix_vec_next.begin(), prefix_vec_next.begin()+k);
} }
...@@ -138,6 +155,7 @@ std::vector<std::pair<double, std::string> > ...@@ -138,6 +155,7 @@ std::vector<std::pair<double, std::string> >
} }
} }
// sort the result and return // sort the result and return
std::sort(beam_result.begin(), beam_result.end(), pair_comp_first_rev<double, std::string>); std::sort(beam_result.begin(), beam_result.end(),
pair_comp_first_rev<double, std::string>);
return beam_result; return beam_result;
} }
...@@ -6,14 +6,30 @@ ...@@ -6,14 +6,30 @@
#include <utility> #include <utility>
#include "scorer.h" #include "scorer.h"
std::vector<std::pair<double, std::string> > /* CTC Beam Search Decoder, the interface is consistent with the
ctc_beam_search_decoder(std::vector<std::vector<double> > probs_seq, * original decoder in Python version.
int beam_size,
std::vector<std::string> vocabulary, * Parameters:
int blank_id=0, * probs_seq: 2-D vector that each element is a vector of probabilities
double cutoff_prob=1.0, * over vocabulary of one time step.
Scorer *ext_scorer=NULL, * beam_size: The width of beam search.
bool nproc=false * vocabulary: A vector of vocabulary.
); * blank_id: ID of blank.
* cutoff_prob: Cutoff probability of pruning
* ext_scorer: External scorer to evaluate a prefix.
* nproc: Whether this function is used in multiprocessing.
* Return:
* A vector in which each element is a pair of score and decoding result,
* in descending order.
*/
std::vector<std::pair<double, std::string> >
ctc_beam_search_decoder(std::vector<std::vector<double> > probs_seq,
int beam_size,
std::vector<std::string> vocabulary,
int blank_id,
double cutoff_prob=1.0,
Scorer *ext_scorer=NULL,
bool nproc=false
);
#endif // CTC_BEAM_SEARCH_DECODER_H_ #endif // CTC_BEAM_SEARCH_DECODER_H_
...@@ -10,8 +10,8 @@ def compile_test(header, library): ...@@ -10,8 +10,8 @@ def compile_test(header, library):
return os.system(command) == 0 return os.system(command) == 0
FILES = glob.glob('util/*.cc') + glob.glob('lm/*.cc') + glob.glob( FILES = glob.glob('kenlm/util/*.cc') + glob.glob('kenlm/lm/*.cc') + glob.glob(
'util/double-conversion/*.cc') 'kenlm/util/double-conversion/*.cc')
FILES = [ FILES = [
fn for fn in FILES if not (fn.endswith('main.cc') or fn.endswith('test.cc')) fn for fn in FILES if not (fn.endswith('main.cc') or fn.endswith('test.cc'))
] ]
...@@ -44,7 +44,7 @@ ctc_beam_search_decoder_module = [ ...@@ -44,7 +44,7 @@ ctc_beam_search_decoder_module = [
'ctc_beam_search_decoder.cpp' 'ctc_beam_search_decoder.cpp'
], ],
language='C++', language='C++',
include_dirs=['.'], include_dirs=['.', './kenlm'],
libraries=LIBS, libraries=LIBS,
extra_compile_args=ARGS) extra_compile_args=ARGS)
] ]
...@@ -52,7 +52,6 @@ ctc_beam_search_decoder_module = [ ...@@ -52,7 +52,6 @@ ctc_beam_search_decoder_module = [
setup( setup(
name='swig_ctc_beam_search_decoder', name='swig_ctc_beam_search_decoder',
version='0.1', version='0.1',
author='Yibing Liu',
description="""CTC beam search decoder""", description="""CTC beam search decoder""",
ext_modules=ctc_beam_search_decoder_module, ext_modules=ctc_beam_search_decoder_module,
py_modules=['swig_ctc_beam_search_decoder'], ) py_modules=['swig_ctc_beam_search_decoder'], )
#include <iostream> #include <iostream>
#include "scorer.h" #include "scorer.h"
#include "lm/model.hh" #include "lm/model.hh"
#include "util/tokenize_piece.hh" #include "util/tokenize_piece.hh"
...@@ -17,6 +16,13 @@ Scorer::~Scorer(){ ...@@ -17,6 +16,13 @@ Scorer::~Scorer(){
delete (Model *)this->_language_model; delete (Model *)this->_language_model;
} }
/* Strip an input sentence
* Parameters:
* str: A reference to the objective string
* ch: The character to prune
* Return:
* void
*/
inline void strip(std::string &str, char ch=' ') { inline void strip(std::string &str, char ch=' ') {
if (str.size() == 0) return; if (str.size() == 0) return;
int start = 0; int start = 0;
...@@ -69,10 +75,14 @@ double Scorer::language_model_score(std::string sentence) { ...@@ -69,10 +75,14 @@ double Scorer::language_model_score(std::string sentence) {
} }
//log10 prob //log10 prob
double log_prob = ret.prob; double log_prob = ret.prob;
return log_prob; return log_prob;
} }
void Scorer::reset_params(float alpha, float beta) {
this->_alpha = alpha;
this->_beta = beta;
}
double Scorer::get_score(std::string sentence) { double Scorer::get_score(std::string sentence) {
double lm_score = language_model_score(sentence); double lm_score = language_model_score(sentence);
int word_cnt = word_count(sentence); int word_cnt = word_count(sentence);
......
...@@ -3,20 +3,34 @@ ...@@ -3,20 +3,34 @@
#include <string> #include <string>
/* External scorer to evaluate a prefix or a complete sentence
* when a new word is appended during decoding, consisting of word
* count and language model scoring.
* Example:
* Scorer ext_scorer(alpha, beta, "path_to_language_model.klm");
* double score = ext_scorer.get_score("sentence_to_score");
*/
class Scorer{ class Scorer{
private: private:
float _alpha; float _alpha;
float _beta; float _beta;
void *_language_model; void *_language_model;
// word insertion term
int word_count(std::string);
// n-gram language model scoring
double language_model_score(std::string);
public: public:
Scorer(){} Scorer(){}
Scorer(float alpha, float beta, std::string lm_model_path); Scorer(float alpha, float beta, std::string lm_model_path);
~Scorer(); ~Scorer();
int word_count(std::string);
double language_model_score(std::string); // reset params alpha & beta
void reset_params(float alpha, float beta);
// get the final score
double get_score(std::string); double get_score(std::string);
}; };
#endif #endif //SCORER_H_
...@@ -10,8 +10,8 @@ def compile_test(header, library): ...@@ -10,8 +10,8 @@ def compile_test(header, library):
return os.system(command) == 0 return os.system(command) == 0
FILES = glob.glob('util/*.cc') + glob.glob('lm/*.cc') + glob.glob( FILES = glob.glob('kenlm/util/*.cc') + glob.glob('kenlm/lm/*.cc') + glob.glob(
'util/double-conversion/*.cc') 'kenlm/util/double-conversion/*.cc')
FILES = [ FILES = [
fn for fn in FILES if not (fn.endswith('main.cc') or fn.endswith('test.cc')) fn for fn in FILES if not (fn.endswith('main.cc') or fn.endswith('test.cc'))
] ]
...@@ -41,7 +41,7 @@ ext_modules = [ ...@@ -41,7 +41,7 @@ ext_modules = [
name='_swig_scorer', name='_swig_scorer',
sources=FILES + ['scorer_wrap.cxx', 'scorer.cpp'], sources=FILES + ['scorer_wrap.cxx', 'scorer.cpp'],
language='C++', language='C++',
include_dirs=['.'], include_dirs=['.', './kenlm'],
libraries=LIBS, libraries=LIBS,
extra_compile_args=ARGS) extra_compile_args=ARGS)
] ]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册