scorer.cpp 2.8 KB
Newer Older
Y
Yibing Liu 已提交
1
#include <iostream>
2
#include <unistd.h>
Y
Yibing Liu 已提交
3
#include "scorer.h"
4
#include "decoder_utils.h"
Y
Yibing Liu 已提交
5

6 7 8 9 10 11 12 13
// Builds a scorer with the given weights and loads the language model.
//
// alpha, beta: stored verbatim in the members of the same name.
// lm_path:     path of the language model file handed to load_LM().
//
// The flags start in their "nothing loaded yet" state (character-based
// assumed, no model, order 0); load_LM() then fills them in.
Scorer::Scorer(double alpha, double beta, const std::string& lm_path)
    : alpha(alpha),
      beta(beta),
      _is_character_based(true),
      _language_model(nullptr),
      _max_order(0) {
    const char* model_file = lm_path.c_str();
    load_LM(model_file);
}

16 17 18
// Releases the language model owned by this scorer.
//
// _language_model is stored type-erased, so it is cast back to the model
// base type before deletion. No explicit null test is needed: deleting a
// null pointer has no effect.
Scorer::~Scorer() {
    lm::base::Model* owned_model =
        static_cast<lm::base::Model*>(_language_model);
    delete owned_model;
}

21 22 23 24
void Scorer::load_LM(const char* filename) {
    if (access(filename, F_OK) != 0) {
        std::cerr << "Invalid language model file !!!" << std::endl;
        exit(1);
Y
Yibing Liu 已提交
25
    }
26 27 28 29 30 31 32 33 34 35 36 37 38
    RetriveStrEnumerateVocab enumerate;
    Config config;
    config.enumerate_vocab = &enumerate;
    _language_model = lm::ngram::LoadVirtual(filename, config);
    _max_order = static_cast<lm::base::Model*>(_language_model)->Order();
    _vocabulary = enumerate.vocabulary;
    for (size_t i = 0; i < _vocabulary.size(); ++i) {
        if (_is_character_based
            && _vocabulary[i] != UNK_TOKEN
            && _vocabulary[i] != START_TOKEN
            && _vocabulary[i] != END_TOKEN
            && get_utf8_str_len(enumerate.vocabulary[i]) > 1) {
                _is_character_based = false;
Y
Yibing Liu 已提交
39 40 41 42
        }
    }
}

43 44 45 46 47 48 49 50 51 52 53
double Scorer::get_log_cond_prob(const std::vector<std::string>& words) {
    lm::base::Model* model = static_cast<lm::base::Model*>(_language_model);
    double cond_prob;
    State state, tmp_state, out_state;
    // avoid to inserting <s> in begin
    model->NullContextWrite(&state);
    for (size_t i = 0; i < words.size(); ++i) {
        lm::WordIndex word_index = model->BaseVocabulary().Index(words[i]);
        // encounter OOV
        if (word_index == 0) {
            return OOV_SCOER;
Y
Yibing Liu 已提交
54
        }
55 56
        cond_prob = model->BaseScore(&state, word_index, &out_state);
        tmp_state = state;
Y
Yibing Liu 已提交
57
        state = out_state;
58
        out_state = tmp_state;
Y
Yibing Liu 已提交
59
    }
60 61
    // log10 prob
    return cond_prob;
Y
Yibing Liu 已提交
62 63
}

64 65 66 67 68 69 70 71 72 73 74 75 76 77
// Scores a whole sentence by padding it with boundary tokens and handing
// the result to get_log_prob().
//
// words: the sentence tokens, without boundary markers.
//
// The padded sequence is:
//   - empty input:     _max_order START_TOKENs, then END_TOKEN
//   - non-empty input: (_max_order - 1) START_TOKENs, the words, END_TOKEN
//
// Returns whatever get_log_prob() returns for the padded sequence.
double Scorer::get_sent_log_prob(const std::vector<std::string>& words) {
    // An empty sentence gets one extra start token.
    const size_t num_start_tokens =
        words.empty() ? _max_order : _max_order - 1;

    std::vector<std::string> padded(num_start_tokens, START_TOKEN);
    padded.insert(padded.end(), words.begin(), words.end());
    padded.push_back(END_TOKEN);

    return get_log_prob(padded);
}

80 81 82 83 84 85 86
// Sums the conditional scores of every window of _max_order consecutive
// words in `words`.
//
// words: the token sequence to score; it must contain more than _max_order
//        entries (checked with assert).
//
// Returns the sum of get_log_cond_prob() over all
// words.size() - _max_order + 1 windows, taken left to right.
double Scorer::get_log_prob(const std::vector<std::string>& words) {
    assert(words.size() > _max_order);

    const size_t num_windows = words.size() - _max_order + 1;
    double total = 0.0;
    std::vector<std::string>::const_iterator window_start = words.begin();
    for (size_t k = 0; k < num_windows; ++k, ++window_start) {
        // Copy the current window; get_log_cond_prob takes a vector.
        const std::vector<std::string> window(window_start,
                                              window_start + _max_order);
        total += get_log_cond_prob(window);
    }
    return total;
}