scorer.h 2.3 KB
Newer Older
Y
Yibing Liu 已提交
1 2 3
#ifndef SCORER_H_
#define SCORER_H_

4
#include <memory>
Y
Yibing Liu 已提交
5
#include <string>
6
#include <unordered_map>
Y
Yibing Liu 已提交
7
#include <vector>
8 9
#include "lm/enumerate_vocab.hh"
#include "lm/virtual_interface.hh"
Y
Yibing Liu 已提交
10
#include "lm/word_index.hh"
11
#include "path_trie.h"
Y
Yibing Liu 已提交
12
#include "util/string_piece.hh"
Y
Yibing Liu 已提交
13

Y
Yibing Liu 已提交
14
// Sentinel score for out-of-vocabulary (OOV) words.
// NOTE(review): presumably used as a log-probability floor by the scorer
// implementation — confirm in scorer.cpp.
const double OOV_SCORE = -1000.0;
// Special tokens, spelled the way ARPA / KenLM models spell them:
// sentence start, unknown word, sentence end.
const std::string START_TOKEN = "<s>";
const std::string UNK_TOKEN = "<unk>";
const std::string END_TOKEN = "</s>";
Y
Yibing Liu 已提交
18

Y
Yibing Liu 已提交
19
// Implement a callback to retrive string vocabulary.
20 21
class RetriveStrEnumerateVocab : public lm::EnumerateVocab {
public:
Y
Yibing Liu 已提交
22
  RetriveStrEnumerateVocab() {}
Y
Yibing Liu 已提交
23

Y
Yibing Liu 已提交
24 25 26
  void Add(lm::WordIndex index, const StringPiece& str) {
    vocabulary.push_back(std::string(str.data(), str.length()));
  }
27

Y
Yibing Liu 已提交
28
  std::vector<std::string> vocabulary;
29
};
30

31 32 33 34 35
// External scorer to query languange score for n-gram or sentence.
// Example:
//     Scorer scorer(alpha, beta, "path_of_language_model");
//     scorer.get_log_cond_prob({ "WORD1", "WORD2", "WORD3" });
//     scorer.get_sent_log_prob({ "WORD1", "WORD2", "WORD3" });
Y
Yibing Liu 已提交
36
class Scorer {
Y
Yibing Liu 已提交
37
public:
Y
Yibing Liu 已提交
38 39
  Scorer(double alpha, double beta, const std::string& lm_path);
  ~Scorer();
40

Y
Yibing Liu 已提交
41
  double get_log_cond_prob(const std::vector<std::string>& words);
42

Y
Yibing Liu 已提交
43
  double get_sent_log_prob(const std::vector<std::string>& words);
44

Y
Yibing Liu 已提交
45
  size_t get_max_order() { return _max_order; }
46

Y
Yibing Liu 已提交
47
  bool is_char_map_empty() { return _char_map.size() == 0; }
48

Y
Yibing Liu 已提交
49
  bool is_character_based() { return _is_character_based; }
50

Y
Yibing Liu 已提交
51 52
  // reset params alpha & beta
  void reset_params(float alpha, float beta);
53

Y
Yibing Liu 已提交
54 55
  // make ngram
  std::vector<std::string> make_ngram(PathTrie* prefix);
56

Y
Yibing Liu 已提交
57 58
  // fill dictionary for fst
  void fill_dictionary(bool add_space);
59

Y
Yibing Liu 已提交
60 61
  // set char map
  void set_char_map(std::vector<std::string> char_list);
62

Y
Yibing Liu 已提交
63
  std::vector<std::string> split_labels(const std::vector<int>& labels);
64

Y
Yibing Liu 已提交
65 66 67
  // expose to decoder
  double alpha;
  double beta;
68

Y
Yibing Liu 已提交
69 70
  // fst dictionary
  void* dictionary;
71

72
protected:
Y
Yibing Liu 已提交
73
  void load_LM(const char* filename);
74

Y
Yibing Liu 已提交
75
  double get_log_prob(const std::vector<std::string>& words);
76

Y
Yibing Liu 已提交
77
  std::string vec2str(const std::vector<int>& input);
78 79

private:
Y
Yibing Liu 已提交
80 81 82
  void* _language_model;
  bool _is_character_based;
  size_t _max_order;
83

Y
Yibing Liu 已提交
84 85 86
  int _SPACE_ID;
  std::vector<std::string> _char_list;
  std::unordered_map<char, int> _char_map;
87

Y
Yibing Liu 已提交
88
  std::vector<std::string> _vocabulary;
Y
Yibing Liu 已提交
89 90
};

Y
Yibing Liu 已提交
91
#endif  // SCORER_H_