提交 2f15a787 编写于 作者: Y Yibing Liu

add initial files for deployment

上级
#include <iostream>
#include <map>
#include <algorithm>
#include <utility>
#include <cmath>
#include "ctc_beam_search_decoder.h"
/* Comparator that orders pairs by their `first` member, largest first.
   Returns true when `a.first` is strictly greater than `b.first`.
   Arguments are taken by const reference so that sorting pairs holding
   strings does not copy both operands on every comparison.
*/
template <typename T1, typename T2>
bool pair_comp_first_rev(const std::pair<T1, T2> &a, const std::pair<T1, T2> &b) {
    return a.first > b.first;
}
/* Comparator that orders pairs by their `second` member, largest first.
   Returns true when `a.second` is strictly greater than `b.second`.
   Arguments are taken by const reference so that sorting pairs holding
   strings does not copy both operands on every comparison.
*/
template <typename T1, typename T2>
bool pair_comp_second_rev(const std::pair<T1, T2> &a, const std::pair<T1, T2> &b) {
    return a.second > b.second;
}
/* CTC beam search decoder in C++, the interface is consistent with the original
   decoder in Python version.

   Parameters:
     probs_seq    One row of per-symbol probabilities for each time step.
     beam_size    Maximum number of prefixes kept after each time step.
     vocabulary   Strings indexed by symbol id; it must contain " ".
     blank_id     Symbol id treated as the CTC blank.
     cutoff_prob  When < 1.0, each step keeps only the most probable symbols
                  whose cumulative probability first reaches this value.
     ext_scorer   Optional external scorer (may be NULL). It rescales a
                  prefix when a space is appended and once more for the
                  final word of each result.
     nproc        Not used by this implementation.

   Returns pairs of (log probability, decoded sentence), sorted by descending
   log probability. Prefixes with zero probability or no symbols are dropped.

   Prints a message and calls exit(1) if " " is not in the vocabulary.
*/
std::vector<std::pair<double, std::string> >
    ctc_beam_search_decoder(std::vector<std::vector<double> > probs_seq,
                            int beam_size,
                            std::vector<std::string> vocabulary,
                            int blank_id,
                            double cutoff_prob,
                            Scorer *ext_scorer,
                            bool nproc
                            )
{
    typedef std::map<std::string, double> ProbMap;
    typedef std::vector<double>::size_type size_type;

    (void)nproc;  // accepted for interface compatibility only

    const int num_time_steps = probs_seq.size();

    // the space character must be present in the vocabulary
    if (std::find(vocabulary.begin(), vocabulary.end(), " ") == vocabulary.end()) {
        std::cout<<"The character space is not in the vocabulary!";
        exit(1);
    }

    // initialize
    // two sets containing selected and candidate prefixes respectively
    ProbMap prefix_set_prev, prefix_set_next;
    // probability of prefixes ending with blank and non-blank
    ProbMap probs_b_prev, probs_nb_prev;
    ProbMap probs_b_cur, probs_nb_cur;
    // "\t" marks the empty prefix; it is stripped from the returned sentences
    prefix_set_prev["\t"] = 1.0;
    probs_b_prev["\t"] = 1.0;
    probs_nb_prev["\t"] = 0.0;

    for (int time_step=0; time_step<num_time_steps; time_step++) {
        prefix_set_next.clear();
        probs_b_cur.clear();
        probs_nb_cur.clear();

        const std::vector<double> &prob = probs_seq[time_step];

        std::vector<std::pair<int, double> > prob_idx;
        prob_idx.reserve(prob.size());
        for (size_type i=0; i<prob.size(); i++) {
            prob_idx.push_back(std::pair<int, double>(static_cast<int>(i), prob[i]));
        }

        // pruning of vocabulary
        if (cutoff_prob < 1.0) {
            std::sort(prob_idx.begin(), prob_idx.end(),
                      pair_comp_second_rev<int, double>);
            // accumulate in double: the inputs are double, and a float sum
            // can cross the cutoff one symbol early or late
            double cum_prob = 0.0;
            size_type cutoff_len = 0;
            while (cutoff_len < prob_idx.size()) {
                cum_prob += prob_idx[cutoff_len].second;
                cutoff_len += 1;
                if (cum_prob >= cutoff_prob) break;
            }
            prob_idx.erase(prob_idx.begin()+cutoff_len, prob_idx.end());
        }

        // extend prefix
        for (ProbMap::iterator it = prefix_set_prev.begin();
             it != prefix_set_prev.end(); ++it) {
            const std::string &l = it->first;
            if (prefix_set_next.find(l) == prefix_set_next.end()) {
                probs_b_cur[l] = probs_nb_cur[l] = 0.0;
            }

            // values that do not change while this prefix is being extended
            const double prob_b_l = probs_b_prev[l];
            const double prob_nb_l = probs_nb_prev[l];
            const std::string last_char = l.substr(l.size()-1, 1);

            for (size_type index=0; index<prob_idx.size(); index++) {
                const int c = prob_idx[index].first;
                const double prob_c = prob_idx[index].second;
                if (c == blank_id) {
                    probs_b_cur[l] += prob_c*(prob_b_l+prob_nb_l);
                } else {
                    const std::string &new_char = vocabulary[c];
                    const std::string l_plus = l+new_char;
                    if (prefix_set_next.find(l_plus) == prefix_set_next.end()) {
                        probs_b_cur[l_plus] = probs_nb_cur[l_plus] = 0.0;
                    }
                    if (last_char == new_char) {
                        // a repeated symbol extends the prefix only after a
                        // blank; otherwise it collapses into the same prefix
                        probs_nb_cur[l_plus] += prob_c * prob_b_l;
                        probs_nb_cur[l] += prob_c * prob_nb_l;
                    } else if (new_char == " ") {
                        // a word boundary: let the external scorer weigh in
                        double score = 1.0;
                        if (ext_scorer != NULL && l.size() > 1) {
                            score = ext_scorer->get_score(l.substr(1));
                        }
                        probs_nb_cur[l_plus] += score * prob_c * (
                            prob_b_l + prob_nb_l);
                    } else {
                        probs_nb_cur[l_plus] += prob_c * (
                            prob_b_l + prob_nb_l);
                    }
                    prefix_set_next[l_plus] = probs_nb_cur[l_plus]+probs_b_cur[l_plus];
                }
            }
            prefix_set_next[l] = probs_b_cur[l]+probs_nb_cur[l];
        }

        // the *_cur maps are cleared at the top of the next step, so swapping
        // avoids copying both maps
        probs_b_prev.swap(probs_b_cur);
        probs_nb_prev.swap(probs_nb_cur);

        // keep the beam_size most probable prefixes
        std::vector<std::pair<std::string, double> >
            prefix_vec_next(prefix_set_next.begin(), prefix_set_next.end());
        std::sort(prefix_vec_next.begin(), prefix_vec_next.end(),
                  pair_comp_second_rev<std::string, double>);
        // a negative beam_size keeps every prefix, as the original
        // signed/unsigned comparison did
        size_type k = prefix_vec_next.size();
        if (beam_size >= 0 && static_cast<size_type>(beam_size) < k) {
            k = static_cast<size_type>(beam_size);
        }
        prefix_set_prev = ProbMap(prefix_vec_next.begin(), prefix_vec_next.begin()+k);
    }

    // post processing
    std::vector<std::pair<double, std::string> > beam_result;
    for (ProbMap::iterator it = prefix_set_prev.begin();
         it != prefix_set_prev.end(); ++it) {
        if (it->second > 0.0 && it->first.size() > 1) {
            double prob = it->second;
            // drop the leading "\t" marker
            const std::string sentence = it->first.substr(1);
            // scoring the last word
            if (ext_scorer != NULL && sentence[sentence.size()-1] != ' ') {
                prob = prob * ext_scorer->get_score(sentence);
            }
            // report the rescored probability and the stripped sentence; the
            // original computed both and then returned the unscored prefix
            const double log_prob = std::log(prob);
            beam_result.push_back(std::pair<double, std::string>(log_prob, sentence));
        }
    }

    // sort the result and return
    std::sort(beam_result.begin(), beam_result.end(),
              pair_comp_first_rev<double, std::string>);
    return beam_result;
}
#ifndef CTC_BEAM_SEARCH_DECODER_H_
#define CTC_BEAM_SEARCH_DECODER_H_
#include <vector>
#include <string>
#include <utility>
#include "scorer.h"
/* CTC beam search decoder.
 *
 * probs_seq    one vector of per-symbol probabilities for each time step
 * beam_size    maximum number of prefixes kept after each time step
 * vocabulary   strings indexed by symbol id
 * blank_id     symbol id treated as the CTC blank (default 0)
 * cutoff_prob  cumulative-probability threshold for pruning symbols at each
 *              step; the default 1.0 disables pruning
 * ext_scorer   optional external scorer; NULL (the default) disables it
 * nproc        defaults to false
 *
 * Returns (log probability, decoded string) pairs.
 */
std::vector<std::pair<double, std::string> >
ctc_beam_search_decoder(std::vector<std::vector<double> > probs_seq,
int beam_size,
std::vector<std::string> vocabulary,
int blank_id=0,
double cutoff_prob=1.0,
Scorer *ext_scorer=NULL,
bool nproc=false
);
#endif // CTC_BEAM_SEARCH_DECODER_H_
// SWIG interface that wraps the declarations of ctc_beam_search_decoder.h
// as the module `swig_ctc_beam_search_decoder`.
%module swig_ctc_beam_search_decoder
%{
#include "ctc_beam_search_decoder.h"
%}
// SWIG library typemaps for the STL types used in the decoder signature.
%include "std_vector.i"
%include "std_pair.i"
%include "std_string.i"
// Template instantiations for the container types exposed to the target language.
namespace std{
%template(DoubleVector) std::vector<double>;
%template(IntVector) std::vector<int>;
%template(StringVector) std::vector<std::string>;
%template(VectorOfStructVector) std::vector<std::vector<double> >;
%template(FloatVector) std::vector<float>;
// NOTE(review): the decoder returns std::pair<double, std::string>, but only
// the float pair is instantiated here. Confirm whether a double pair template
// is also needed.
%template(Pair) std::pair<float, std::string>;
%template(PairFloatStringVector) std::vector<std::pair<float, std::string> >;
%template(PairDoubleStringVector) std::vector<std::pair<double, std::string> >;
}
// scorer.h is imported for its type information only; the decoder header is
// the one whose declarations get wrapped.
%import scorer.h
%include "ctc_beam_search_decoder.h"
from setuptools import setup, Extension
import glob
import platform
import os
def compile_test(header, library):
    """Return True if g++ can build a trivial program with `header` and `library`.

    A one-line `int main() {}` program is piped to g++ with `-include header`
    and `-l<library>`. The output binary is written next to this script as
    "dummy" and removed afterwards.

    The compiler is invoked with an argument list rather than a shell string,
    so a script directory containing spaces or shell metacharacters no longer
    breaks the probe, and bash is not required.

    Returns False when g++ cannot be started, when compilation or linking
    fails, or when the temporary binary cannot be removed.
    """
    import subprocess

    dummy_path = os.path.join(os.path.dirname(__file__), "dummy")
    command = [
        "g++", "-include", header, "-l" + library, "-x", "c++", "-", "-o",
        dummy_path
    ]
    try:
        with open(os.devnull, "w") as devnull:
            process = subprocess.Popen(
                command,
                stdin=subprocess.PIPE,
                stdout=devnull,
                stderr=devnull)
            process.communicate(b"int main() {}\n")
    except OSError:
        # g++ is missing or could not be executed
        return False
    if process.returncode != 0:
        return False
    try:
        os.remove(dummy_path)
    except OSError:
        return False
    return True
# Sources under util/ and lm/ that are compiled into the extension; files
# ending in main.cc or test.cc are left out.
FILES = glob.glob('util/*.cc') + glob.glob('lm/*.cc') + glob.glob(
    'util/double-conversion/*.cc')
FILES = [
    fn for fn in FILES if not (fn.endswith('main.cc') or fn.endswith('test.cc'))
]

LIBS = ['stdc++']
if platform.system() != 'Darwin':
    LIBS.append('rt')

ARGS = ['-O3', '-DNDEBUG', '-DKENLM_MAX_ORDER=6']

# Enable each optional compression backend only when its header and library
# are usable on this machine.
if compile_test('zlib.h', 'z'):
    ARGS.append('-DHAVE_ZLIB')
    LIBS.append('z')

if compile_test('bzlib.h', 'bz2'):
    ARGS.append('-DHAVE_BZLIB')
    LIBS.append('bz2')

if compile_test('lzma.h', 'lzma'):
    ARGS.append('-DHAVE_XZLIB')
    LIBS.append('lzma')

# Generate the wrapper source. A failure used to be ignored silently and only
# surfaced later as a missing-file error; report it here but carry on, because
# a previously generated wrapper may still be present.
if os.system('swig -python -c++ ./ctc_beam_search_decoder.i') != 0:
    import sys
    sys.stderr.write(
        "warning: swig failed on ./ctc_beam_search_decoder.i; the build "
        "needs ctc_beam_search_decoder_wrap.cxx to exist already\n")

ctc_beam_search_decoder_module = [
    Extension(
        name='_swig_ctc_beam_search_decoder',
        sources=FILES + [
            'scorer.cpp', 'ctc_beam_search_decoder_wrap.cxx',
            'ctc_beam_search_decoder.cpp'
        ],
        language='C++',
        include_dirs=['.'],
        libraries=LIBS,
        extra_compile_args=ARGS)
]

setup(
    name='swig_ctc_beam_search_decoder',
    version='0.1',
    author='Yibing Liu',
    description="""CTC beam search decoder""",
    ext_modules=ctc_beam_search_decoder_module,
    py_modules=['swig_ctc_beam_search_decoder'], )
#include <iostream>
#include "scorer.h"
#include "lm/model.hh"
#include "util/tokenize_piece.hh"
#include "util/string_piece.hh"
using namespace lm::ngram;
/* Stores the two weights and loads the language model found at
   `lm_model_path`. The model is heap-allocated and released by the
   destructor.
*/
Scorer::Scorer(float alpha, float beta, std::string lm_model_path)
    : _alpha(alpha),
      _beta(beta),
      _language_model(new Model(lm_model_path.c_str())) {}
/* Releases the language model allocated by the constructor. The member is
   stored as void*, so it is converted back to Model* before deletion.
*/
Scorer::~Scorer(){
    Model *owned_model = static_cast<Model *>(_language_model);
    delete owned_model;
}
/* Removes every leading and trailing occurrence of `ch` from `str`, in place.
   A string consisting only of `ch` becomes empty; a string with nothing to
   trim is left untouched.
*/
inline void strip(std::string &str, char ch=' ') {
    const std::string::size_type first_kept = str.find_first_not_of(ch);
    if (first_kept == std::string::npos) {
        // empty input, or nothing but `ch`
        str.clear();
        return;
    }
    const std::string::size_type last_kept = str.find_last_not_of(ch);
    const bool nothing_to_trim = (first_kept == 0) && (last_kept + 1 == str.size());
    if (!nothing_to_trim) {
        str = str.substr(first_kept, last_kept - first_kept + 1);
    }
}
/* Counts the space-separated words in `sentence`.
   Leading and trailing spaces are ignored and consecutive spaces count as a
   single separator. Returns 0 for an empty or all-space sentence.

   The previous version returned 0 for a single-word sentence, which made the
   word-count factor in get_score() collapse to zero for any positive beta.
*/
int Scorer::word_count(std::string sentence) {
    strip(sentence);
    if (sentence.empty()) return 0;
    // after stripping, the sentence starts and ends with a word, so the
    // number of words is one more than the number of separator runs
    int cnt = 1;
    for (std::string::size_type i = 1; i < sentence.size(); i++) {
        if (sentence[i] == ' ' && sentence[i-1] != ' ') {
            cnt ++;
        }
    }
    return cnt;
}
double Scorer::language_model_score(std::string sentence) {
Model *model = (Model *)this->_language_model;
State state, out_state;
lm::FullScoreReturn ret;
state = model->BeginSentenceState();
for (util::TokenIter<util::SingleCharacter, true> it(sentence, ' '); it; ++it){
lm::WordIndex vocab = model->GetVocabulary().Index(*it);
ret = model->FullScore(state, vocab, out_state);
state = out_state;
}
double score = ret.prob;
return pow(10, score);
}
double Scorer::get_score(std::string sentence) {
double lm_score = language_model_score(sentence);
int word_cnt = word_count(sentence);
double final_score = pow(lm_score, _alpha) * pow(word_cnt, _beta);
return final_score;
}
#ifndef SCORER_H_
#define SCORER_H_
#include <string>
/* External scorer combining a language model score with a word count.
 *
 * The language model is held through an opaque pointer so that this header
 * does not depend on the model library's headers.
 *
 * NOTE(review): the class owns a raw pointer but is still copyable; copying
 * an instance would make two objects release the same model.
 */
class Scorer{
private:
    float _alpha;           // exponent applied to the language model score
    float _beta;            // exponent applied to the word count
    void *_language_model;  // opaque handle to the loaded model; null if none

public:
    // Creates a scorer without a model. All members are initialized so that
    // the destructor does not delete an indeterminate pointer. The scoring
    // methods must not be called on such an instance.
    Scorer() : _alpha(0), _beta(0), _language_model(0) {}
    // Creates a scorer with the given weights and the model at lm_model_path.
    Scorer(float alpha, float beta, std::string lm_model_path);
    ~Scorer();
    // Word count of the sentence, used for the beta term of get_score().
    int word_count(std::string);
    // Language model score of the sentence, used for the alpha term.
    double language_model_score(std::string);
    // Combined score: language_model_score^alpha * word_count^beta.
    double get_score(std::string);
};
#endif
// SWIG interface that wraps the Scorer class declared in scorer.h as the
// module `swig_scorer`.
%module swig_scorer
%{
#include "scorer.h"
%}
// SWIG library typemaps for std::string, which appears in the Scorer methods.
%include "std_string.i"
%include "scorer.h"
from setuptools import setup, Extension
import glob
import platform
import os
def compile_test(header, library):
    """Return True if g++ can build a trivial program with `header` and `library`.

    A one-line `int main() {}` program is piped to g++ with `-include header`
    and `-l<library>`. The output binary is written next to this script as
    "dummy" and removed afterwards.

    The compiler is invoked with an argument list rather than a shell string,
    so a script directory containing spaces or shell metacharacters no longer
    breaks the probe, and bash is not required.

    Returns False when g++ cannot be started, when compilation or linking
    fails, or when the temporary binary cannot be removed.
    """
    import subprocess

    dummy_path = os.path.join(os.path.dirname(__file__), "dummy")
    command = [
        "g++", "-include", header, "-l" + library, "-x", "c++", "-", "-o",
        dummy_path
    ]
    try:
        with open(os.devnull, "w") as devnull:
            process = subprocess.Popen(
                command,
                stdin=subprocess.PIPE,
                stdout=devnull,
                stderr=devnull)
            process.communicate(b"int main() {}\n")
    except OSError:
        # g++ is missing or could not be executed
        return False
    if process.returncode != 0:
        return False
    try:
        os.remove(dummy_path)
    except OSError:
        return False
    return True
# Sources under util/ and lm/ that are compiled into the extension; files
# ending in main.cc or test.cc are left out.
FILES = glob.glob('util/*.cc') + glob.glob('lm/*.cc') + glob.glob(
    'util/double-conversion/*.cc')
FILES = [
    fn for fn in FILES if not (fn.endswith('main.cc') or fn.endswith('test.cc'))
]

LIBS = ['stdc++']
if platform.system() != 'Darwin':
    LIBS.append('rt')

ARGS = ['-O3', '-DNDEBUG', '-DKENLM_MAX_ORDER=6']

# Enable each optional compression backend only when its header and library
# are usable on this machine.
if compile_test('zlib.h', 'z'):
    ARGS.append('-DHAVE_ZLIB')
    LIBS.append('z')

if compile_test('bzlib.h', 'bz2'):
    ARGS.append('-DHAVE_BZLIB')
    LIBS.append('bz2')

if compile_test('lzma.h', 'lzma'):
    ARGS.append('-DHAVE_XZLIB')
    LIBS.append('lzma')

# Generate the wrapper source. A failure used to be ignored silently and only
# surfaced later as a missing-file error; report it here but carry on, because
# a previously generated wrapper may still be present.
if os.system('swig -python -c++ ./scorer.i') != 0:
    import sys
    sys.stderr.write(
        "warning: swig failed on ./scorer.i; the build needs "
        "scorer_wrap.cxx to exist already\n")

ext_modules = [
    Extension(
        name='_swig_scorer',
        sources=FILES + ['scorer_wrap.cxx', 'scorer.cpp'],
        language='C++',
        include_dirs=['.'],
        libraries=LIBS,
        extra_compile_args=ARGS)
]

setup(
    name='swig_scorer',
    version='0.1',
    ext_modules=ext_modules,
    include_package_data=True,
    py_modules=['swig_scorer'], )
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册