diff --git a/deploy.py b/deploy.py index d8a7e5b277154deec715aa1fd89a62cb5898f693..02152b499fc322f9bc8e82315142117da77fec3a 100644 --- a/deploy.py +++ b/deploy.py @@ -14,6 +14,7 @@ from swig_ctc_beam_search_decoder import * from swig_scorer import Scorer from error_rate import wer import utils +import time parser = argparse.ArgumentParser(description=__doc__) parser.add_argument( @@ -74,7 +75,7 @@ parser.add_argument( ) parser.add_argument( "--beam_size", - default=500, + default=200, type=int, help="Width for beam search decoding. (default: %(default)d)") parser.add_argument( @@ -166,6 +167,7 @@ def infer(): ext_scorer = Scorer(args.alpha, args.beta, args.language_model_path) ## decode and print + time_begin = time.time() wer_sum, wer_counter = 0, 0 for i, probs in enumerate(probs_split): beam_result = ctc_beam_search_decoder( @@ -183,6 +185,8 @@ def infer(): wer_counter += 1 print("cur wer = %f , average wer = %f" % (wer_cur, wer_sum / wer_counter)) + time_end = time.time() + print("total time = %f" % (time_end - time_begin)) def main(): diff --git a/deploy/scorer.cpp b/deploy/scorer.cpp index 1b843402bf6f59b1cadf21f2fdfa1f8b1e2d1b04..d438ec1bda7ebccafc15ba21d156b67c2fa21e57 100644 --- a/deploy/scorer.cpp +++ b/deploy/scorer.cpp @@ -1,4 +1,5 @@ #include +#include #include "scorer.h" #include "lm/model.hh" #include "util/tokenize_piece.hh" @@ -9,11 +10,16 @@ using namespace lm::ngram; Scorer::Scorer(float alpha, float beta, std::string lm_model_path) { this->_alpha = alpha; this->_beta = beta; - this->_language_model = new Model(lm_model_path.c_str()); + + if (access(lm_model_path.c_str(), F_OK) != 0) { + std::cout<<"Invalid language model path!"<_language_model = LoadVirtual(lm_model_path.c_str()); } Scorer::~Scorer(){ - delete (Model *)this->_language_model; + delete (lm::base::Model *)this->_language_model; } /* Strip a input sentence @@ -63,14 +69,14 @@ int Scorer::word_count(std::string sentence) { } double Scorer::language_model_score(std::string sentence) { - Model *model = (Model *)this->_language_model; + lm::base::Model *model = (lm::base::Model *)this->_language_model; State state, out_state; lm::FullScoreReturn ret; - state = model->BeginSentenceState(); + model->BeginSentenceWrite(&state); for (util::TokenIter it(sentence, ' '); it; ++it){ - lm::WordIndex vocab = model->GetVocabulary().Index(*it); - ret = model->FullScore(state, vocab, out_state); + lm::WordIndex wid = model->BaseVocabulary().Index(*it); + ret = model->BaseFullScore(&state, wid, &out_state); state = out_state; } //log10 prob diff --git a/deploy/setup.sh b/deploy/setup.sh index e84cd9235d5bda615016b1322a65fc037f2271f4..423f5b8922c89e785231bd84557de0b59a089f53 100644 --- a/deploy/setup.sh +++ b/deploy/setup.sh @@ -3,9 +3,9 @@ echo "Run decoder setup ..." python decoder_setup.py install rm -r ./build -echo "\nRun scorer setup ..." +echo "Run scorer setup ..." python scorer_setup.py install rm -r ./build -echo "\nFinish the installation of decoder and scorer." +echo "Finish the installation of decoder and scorer."