scorer.h 2.7 KB
Newer Older
Y
Yibing Liu 已提交
1 2 3
#ifndef SCORER_H_
#define SCORER_H_

4
#include <memory>
Y
Yibing Liu 已提交
5
#include <string>
6
#include <unordered_map>
Y
Yibing Liu 已提交
7
#include <vector>
8

9 10
#include "lm/enumerate_vocab.hh"
#include "lm/virtual_interface.hh"
Y
Yibing Liu 已提交
11 12
#include "lm/word_index.hh"
#include "util/string_piece.hh"
Y
Yibing Liu 已提交
13

14 15
#include "path_trie.h"

Y
Yibing Liu 已提交
16
// Score used for words the language model does not know (out of vocabulary).
// NOTE(review): presumably a log-probability floor; confirm against scorer.cpp.
const double OOV_SCORE = -1.0e3;

// Special tokens: sentence start, unknown word and sentence end. These are
// presumably expected to match the tokens in the loaded model's vocabulary.
const std::string START_TOKEN("<s>");
const std::string UNK_TOKEN("<unk>");
const std::string END_TOKEN("</s>");
Y
Yibing Liu 已提交
20

Y
Yibing Liu 已提交
21
// Implement a callback to retrive string vocabulary.
22 23
class RetriveStrEnumerateVocab : public lm::EnumerateVocab {
public:
Y
Yibing Liu 已提交
24
  RetriveStrEnumerateVocab() {}
Y
Yibing Liu 已提交
25

Y
Yibing Liu 已提交
26
  void Add(lm::WordIndex index, const StringPiece &str) {
Y
Yibing Liu 已提交
27 28
    vocabulary.push_back(std::string(str.data(), str.length()));
  }
29

Y
Yibing Liu 已提交
30
  std::vector<std::string> vocabulary;
31
};
32

Y
Yibing Liu 已提交
33 34
/* External scorer to query score for n-gram or sentence, including language
 * model scoring and word insertion.
35 36 37 38 39 40
 *
 * Example:
 *     Scorer scorer(alpha, beta, "path_of_language_model");
 *     scorer.get_log_cond_prob({ "WORD1", "WORD2", "WORD3" });
 *     scorer.get_sent_log_prob({ "WORD1", "WORD2", "WORD3" });
 */
Y
Yibing Liu 已提交
41
class Scorer {
Y
Yibing Liu 已提交
42
public:
43 44 45 46
  Scorer(double alpha,
         double beta,
         const std::string &lm_path,
         const std::vector<std::string> &vocabulary);
Y
Yibing Liu 已提交
47
  ~Scorer();
48

Y
Yibing Liu 已提交
49
  double get_log_cond_prob(const std::vector<std::string> &words);
50

Y
Yibing Liu 已提交
51
  double get_sent_log_prob(const std::vector<std::string> &words);
52

53
  size_t get_max_order() const { return _max_order; }
54

55
  size_t get_dict_size() const { return _dict_size; }
56

57 58 59
  bool is_char_map_empty() const { return _char_map.size() == 0; }

  bool is_character_based() const { return _is_character_based; }
60

Y
Yibing Liu 已提交
61 62
  // reset params alpha & beta
  void reset_params(float alpha, float beta);
63

64
  // make ngram for a given prefix
Y
Yibing Liu 已提交
65
  std::vector<std::string> make_ngram(PathTrie *prefix);
66

67 68
  // trransform the labels in index to the vector of words (word based lm) or
  // the vector of characters (character based lm)
Y
Yibing Liu 已提交
69
  std::vector<std::string> split_labels(const std::vector<int> &labels);
70

Y
Yibing Liu 已提交
71 72 73
  // expose to decoder
  double alpha;
  double beta;
74

Y
Yibing Liu 已提交
75
  // fst dictionary
Y
Yibing Liu 已提交
76
  void *dictionary;
77

78
protected:
79 80 81 82 83 84 85 86 87 88
  void setup(const std::string &lm_path,
             const std::vector<std::string> &vocab_list);

  void load_lm(const std::string &lm_path);

  // fill dictionary for fst
  void fill_dictionary(bool add_space);

  // set char map
  void set_char_map(const std::vector<std::string> &char_list);
89

Y
Yibing Liu 已提交
90
  double get_log_prob(const std::vector<std::string> &words);
91

Y
Yibing Liu 已提交
92
  std::string vec2str(const std::vector<int> &input);
93 94

private:
Y
Yibing Liu 已提交
95
  void *_language_model;
Y
Yibing Liu 已提交
96 97
  bool _is_character_based;
  size_t _max_order;
98
  size_t _dict_size;
99

Y
Yibing Liu 已提交
100 101 102
  int _SPACE_ID;
  std::vector<std::string> _char_list;
  std::unordered_map<char, int> _char_map;
103

Y
Yibing Liu 已提交
104
  std::vector<std::string> _vocabulary;
Y
Yibing Liu 已提交
105 106
};

Y
Yibing Liu 已提交
107
#endif  // SCORER_H_