提交 94a68116 编写于 作者: Y Yibing Liu

code cleanup for the deployment decoder

上级 2e76f82c
......@@ -6,35 +6,47 @@
#include "ctc_beam_search_decoder.h"
template <typename T1, typename T2>
// Comparator for sorting pairs in DESCENDING order of the first element.
// Used with std::sort to rank (score, decoding-result) pairs, best first.
// Takes const references: the pairs carry std::string payloads, so the
// original by-value signature copied them on every comparison.
bool pair_comp_first_rev(const std::pair<T1, T2> &a, const std::pair<T1, T2> &b)
{
    return a.first > b.first;
}
template <typename T1, typename T2>
// Comparator for sorting pairs in DESCENDING order of the second element.
// Used with std::sort for vocabulary pruning ((id, prob) pairs) and for
// ranking (prefix, prob) pairs during beam pruning.
// Takes const references to avoid copying the pairs on every comparison.
bool pair_comp_second_rev(const std::pair<T1, T2> &a, const std::pair<T1, T2> &b)
{
    return a.second > b.second;
}
/* CTC beam search decoder in C++, the interface is consistent with the original
decoder in Python version.
*/
std::vector<std::pair<double, std::string> >
ctc_beam_search_decoder(std::vector<std::vector<double> > probs_seq,
int beam_size,
std::vector<std::string> vocabulary,
int blank_id,
double cutoff_prob,
Scorer *ext_scorer,
bool nproc
)
{
ctc_beam_search_decoder(std::vector<std::vector<double> > probs_seq,
int beam_size,
std::vector<std::string> vocabulary,
int blank_id,
double cutoff_prob,
Scorer *ext_scorer,
bool nproc) {
// dimension check
int num_time_steps = probs_seq.size();
for (int i=0; i<num_time_steps; i++) {
if (probs_seq[i].size() != vocabulary.size()+1) {
std::cout<<"The shape of probs_seq does not match"
<<" with the shape of the vocabulary!"<<std::endl;
exit(1);
}
}
// blank_id check
if (blank_id > vocabulary.size()) {
std::cout<<"Invalid blank_id!"<<std::endl;
exit(1);
}
// assign space ID
std::vector<std::string>::iterator it = std::find(vocabulary.begin(), vocabulary.end(), " ");
int space_id = it-vocabulary.begin();
std::vector<std::string>::iterator it = std::find(vocabulary.begin(),
vocabulary.end(), " ");
int space_id = it - vocabulary.begin();
if(space_id >= vocabulary.size()) {
std::cout<<"The character space is not in the vocabulary!";
std::cout<<"The character space is not in the vocabulary!"<<std::endl;
exit(1);
}
......@@ -60,7 +72,8 @@ std::vector<std::pair<double, std::string> >
}
// pruning of vocabulary
if (cutoff_prob < 1.0) {
std::sort(prob_idx.begin(), prob_idx.end(), pair_comp_second_rev<int, double>);
std::sort(prob_idx.begin(), prob_idx.end(),
pair_comp_second_rev<int, double>);
float cum_prob = 0.0;
int cutoff_len = 0;
for (int i=0; i<prob_idx.size(); i++) {
......@@ -68,7 +81,8 @@ std::vector<std::pair<double, std::string> >
cutoff_len += 1;
if (cum_prob >= cutoff_prob) break;
}
prob_idx = std::vector<std::pair<int, double> >(prob_idx.begin(), prob_idx.begin()+cutoff_len);
prob_idx = std::vector<std::pair<int, double> >( prob_idx.begin(),
prob_idx.begin() + cutoff_len);
}
// extend prefix
for (std::map<std::string, double>::iterator it = prefix_set_prev.begin();
......@@ -82,11 +96,11 @@ std::vector<std::pair<double, std::string> >
int c = prob_idx[index].first;
double prob_c = prob_idx[index].second;
if (c == blank_id) {
probs_b_cur[l] += prob_c*(probs_b_prev[l]+probs_nb_prev[l]);
probs_b_cur[l] += prob_c * (probs_b_prev[l] + probs_nb_prev[l]);
} else {
std::string last_char = l.substr(l.size()-1, 1);
std::string new_char = vocabulary[c];
std::string l_plus = l+new_char;
std::string l_plus = l + new_char;
if( prefix_set_next.find(l_plus) == prefix_set_next.end()) {
probs_b_cur[l_plus] = probs_nb_cur[l_plus] = 0.0;
......@@ -105,19 +119,22 @@ std::vector<std::pair<double, std::string> >
probs_nb_cur[l_plus] += prob_c * (
probs_b_prev[l] + probs_nb_prev[l]);
}
prefix_set_next[l_plus] = probs_nb_cur[l_plus]+probs_b_cur[l_plus];
prefix_set_next[l_plus] = probs_nb_cur[l_plus] + probs_b_cur[l_plus];
}
}
prefix_set_next[l] = probs_b_cur[l]+probs_nb_cur[l];
prefix_set_next[l] = probs_b_cur[l] + probs_nb_cur[l];
}
probs_b_prev = probs_b_cur;
probs_nb_prev = probs_nb_cur;
std::vector<std::pair<std::string, double> >
prefix_vec_next(prefix_set_next.begin(), prefix_set_next.end());
std::sort(prefix_vec_next.begin(), prefix_vec_next.end(), pair_comp_second_rev<std::string, double>);
int k = beam_size<prefix_vec_next.size() ? beam_size : prefix_vec_next.size();
prefix_vec_next(prefix_set_next.begin(),
prefix_set_next.end());
std::sort(prefix_vec_next.begin(),
prefix_vec_next.end(),
pair_comp_second_rev<std::string, double>);
int k = beam_size<prefix_vec_next.size() ? beam_size:prefix_vec_next.size();
prefix_set_prev = std::map<std::string, double>
(prefix_vec_next.begin(), prefix_vec_next.begin()+k);
}
......@@ -138,6 +155,7 @@ std::vector<std::pair<double, std::string> >
}
}
// sort the result and return
std::sort(beam_result.begin(), beam_result.end(), pair_comp_first_rev<double, std::string>);
std::sort(beam_result.begin(), beam_result.end(),
pair_comp_first_rev<double, std::string>);
return beam_result;
}
......@@ -6,14 +6,30 @@
#include <utility>
#include "scorer.h"
/* CTC Beam Search Decoder, the interface is consistent with the
 * original decoder in Python version.
 * Parameters:
 *     probs_seq: 2-D vector that each element is a vector of probabilities
 *                over vocabulary of one time step.
 *     beam_size: The width of beam search.
 *     vocabulary: A vector of vocabulary.
 *     blank_id: ID of blank.
 *     cutoff_prob: Cutoff probability of pruning.
 *     ext_scorer: External scorer to evaluate a prefix.
 *     nproc: Whether this function used in multiprocessing.
 * Return:
 *     A vector that each element is a pair of score and decoding result,
 *     in descending order.
 */
// NOTE: exactly one declaration may carry the default arguments; a second
// declaration repeating them is ill-formed (redefinition of default args).
std::vector<std::pair<double, std::string> >
    ctc_beam_search_decoder(std::vector<std::vector<double> > probs_seq,
                            int beam_size,
                            std::vector<std::string> vocabulary,
                            int blank_id,
                            double cutoff_prob=1.0,
                            Scorer *ext_scorer=NULL,
                            bool nproc=false);
#endif // CTC_BEAM_SEARCH_DECODER_H_
......@@ -10,8 +10,8 @@ def compile_test(header, library):
return os.system(command) == 0
FILES = glob.glob('util/*.cc') + glob.glob('lm/*.cc') + glob.glob(
'util/double-conversion/*.cc')
FILES = glob.glob('kenlm/util/*.cc') + glob.glob('kenlm/lm/*.cc') + glob.glob(
'kenlm/util/double-conversion/*.cc')
FILES = [
fn for fn in FILES if not (fn.endswith('main.cc') or fn.endswith('test.cc'))
]
......@@ -44,7 +44,7 @@ ctc_beam_search_decoder_module = [
'ctc_beam_search_decoder.cpp'
],
language='C++',
include_dirs=['.'],
include_dirs=['.', './kenlm'],
libraries=LIBS,
extra_compile_args=ARGS)
]
......@@ -52,7 +52,6 @@ ctc_beam_search_decoder_module = [
setup(
name='swig_ctc_beam_search_decoder',
version='0.1',
author='Yibing Liu',
description="""CTC beam search decoder""",
ext_modules=ctc_beam_search_decoder_module,
py_modules=['swig_ctc_beam_search_decoder'], )
#include <iostream>
#include "scorer.h"
#include "lm/model.hh"
#include "util/tokenize_piece.hh"
......@@ -17,6 +16,13 @@ Scorer::~Scorer(){
delete (Model *)this->_language_model;
}
/* Strip an input sentence
* Parameters:
* str: A reference to the objective string
* ch: The character to prune
* Return:
* void
*/
inline void strip(std::string &str, char ch=' ') {
if (str.size() == 0) return;
int start = 0;
......@@ -69,10 +75,14 @@ double Scorer::language_model_score(std::string sentence) {
}
//log10 prob
double log_prob = ret.prob;
return log_prob;
}
void Scorer::reset_params(float alpha, float beta) {
this->_alpha = alpha;
this->_beta = beta;
}
double Scorer::get_score(std::string sentence) {
double lm_score = language_model_score(sentence);
int word_cnt = word_count(sentence);
......
......@@ -3,20 +3,34 @@
#include <string>
/* External scorer to evaluate a prefix or a complete sentence
* when a new word appended during decoding, consisting of word
* count and language model scoring.
* Example:
* Scorer ext_scorer(alpha, beta, "path_to_language_model.klm");
* double score = ext_scorer.get_score("sentence_to_score");
*/
/* External scorer combining a word-insertion term with an n-gram
 * language-model score; internals are hidden behind get_score().
 * NOTE: word_count and language_model_score must be declared exactly
 * once — duplicate member declarations are ill-formed C++.
 */
class Scorer{
private:
    float _alpha;           // weight of the language model score
    float _beta;            // weight of the word insertion term
    void *_language_model;  // opaque handle to the loaded LM (kenlm model)

    // word insertion term
    int word_count(std::string);
    // n-gram language model scoring
    double language_model_score(std::string);

public:
    Scorer(){}
    Scorer(float alpha, float beta, std::string lm_model_path);
    ~Scorer();

    // reset params alpha & beta
    void reset_params(float alpha, float beta);
    // get the final score
    double get_score(std::string);
};
#endif
#endif //SCORER_H_
......@@ -10,8 +10,8 @@ def compile_test(header, library):
return os.system(command) == 0
FILES = glob.glob('util/*.cc') + glob.glob('lm/*.cc') + glob.glob(
'util/double-conversion/*.cc')
FILES = glob.glob('kenlm/util/*.cc') + glob.glob('kenlm/lm/*.cc') + glob.glob(
'kenlm/util/double-conversion/*.cc')
FILES = [
fn for fn in FILES if not (fn.endswith('main.cc') or fn.endswith('test.cc'))
]
......@@ -41,7 +41,7 @@ ext_modules = [
name='_swig_scorer',
sources=FILES + ['scorer_wrap.cxx', 'scorer.cpp'],
language='C++',
include_dirs=['.'],
include_dirs=['.', './kenlm'],
libraries=LIBS,
extra_compile_args=ARGS)
]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册