提交 ccea7c01 编写于 作者: Y Yibing Liu

Enable loading language models in multiple formats

上级 5bfa0669
......@@ -14,6 +14,7 @@ from swig_ctc_beam_search_decoder import *
from swig_scorer import Scorer
from error_rate import wer
import utils
import time
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
......@@ -74,7 +75,7 @@ parser.add_argument(
)
parser.add_argument(
"--beam_size",
default=500,
default=200,
type=int,
help="Width for beam search decoding. (default: %(default)d)")
parser.add_argument(
......@@ -166,6 +167,7 @@ def infer():
ext_scorer = Scorer(args.alpha, args.beta, args.language_model_path)
## decode and print
time_begin = time.time()
wer_sum, wer_counter = 0, 0
for i, probs in enumerate(probs_split):
beam_result = ctc_beam_search_decoder(
......@@ -183,6 +185,8 @@ def infer():
wer_counter += 1
print("cur wer = %f , average wer = %f" %
(wer_cur, wer_sum / wer_counter))
time_end = time.time()
print("total time = %f" % (time_end - time_begin))
def main():
......
#include <iostream>
#include <unistd.h>
#include "scorer.h"
#include "lm/model.hh"
#include "util/tokenize_piece.hh"
......@@ -9,11 +10,16 @@ using namespace lm::ngram;
/* Construct a scorer and load its language model.
 *
 * Parameters:
 *   alpha          value stored in _alpha
 *   beta           value stored in _beta
 *   lm_model_path  path of the language model file to load
 *
 * If lm_model_path does not exist, "Invalid language model path!" is
 * printed to stdout and the process exits with status 1.
 */
Scorer::Scorer(float alpha, float beta, std::string lm_model_path) {
    this->_alpha = alpha;
    this->_beta = beta;

    // Validate the path before any load is attempted, so a missing file
    // produces the message below instead of a failure inside the loader.
    if (access(lm_model_path.c_str(), F_OK) != 0) {
        std::cout<<"Invalid language model path!"<<std::endl;
        exit(1);
    }

    // Load the model exactly once. A previous `new Model(...)` assignment
    // to _language_model was removed: its result was overwritten by this
    // call without being freed.
    // NOTE(review): LoadVirtual presumably picks the model type from the
    // file itself (the stated goal of this change) -- confirm against the
    // KenLM headers.
    this->_language_model = LoadVirtual(lm_model_path.c_str());
}
/* Release the language model held in _language_model.
 *
 * The pointer is deleted exactly once, through lm::base::Model. Deleting
 * it a second time through another pointer type would be a double delete.
 * NOTE(review): this assumes _language_model is stored as an untyped
 * (void *) pointer and that lm::base::Model has a virtual destructor --
 * confirm against scorer.h and the KenLM headers.
 */
Scorer::~Scorer(){
    delete static_cast<lm::base::Model *>(this->_language_model);
}
/* Strip an input sentence
......@@ -63,14 +69,14 @@ int Scorer::word_count(std::string sentence) {
}
double Scorer::language_model_score(std::string sentence) {
Model *model = (Model *)this->_language_model;
lm::base::Model *model = (lm::base::Model *)this->_language_model;
State state, out_state;
lm::FullScoreReturn ret;
state = model->BeginSentenceState();
model->BeginSentenceWrite(&state);
for (util::TokenIter<util::SingleCharacter, true> it(sentence, ' '); it; ++it){
lm::WordIndex vocab = model->GetVocabulary().Index(*it);
ret = model->FullScore(state, vocab, out_state);
lm::WordIndex wid = model->BaseVocabulary().Index(*it);
ret = model->BaseFullScore(&state, wid, &out_state);
state = out_state;
}
//log10 prob
......
......@@ -3,9 +3,9 @@ echo "Run decoder setup ..."
python decoder_setup.py install
rm -r ./build
echo "\nRun scorer setup ..."
echo "Run scorer setup ..."
python scorer_setup.py install
rm -r ./build
echo "\nFinish the installation of decoder and scorer."
echo "Finish the installation of decoder and scorer."
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册