// path_trie.cpp
#include "path_trie.h"

#include <algorithm>
#include <limits>
#include <memory>
#include <utility>
#include <vector>

#include "decoder_utils.h"

// Constructs a detached node. A default-constructed node looks like a root:
// its character is the ROOT_ sentinel (-1), it has no parent, it is marked as
// existing, and no dictionary constraint is attached. Every log-probability
// accumulator and the score start at -NUM_FLT_INF (presumably log(0); see
// decoder_utils.h).
PathTrie::PathTrie() {
  // Tree linkage. ROOT_ must be assigned before `character` copies it.
  ROOT_ = -1;
  character = ROOT_;
  parent = nullptr;
  exists_ = true;

  // Dictionary (FST) constraint: disabled until set_dictionary() enables it.
  has_dictionary_ = false;
  dictionary_ = nullptr;
  dictionary_state_ = 0;
  matcher_ = nullptr;

  // Previous-step and current-step log-probabilities (the b / nb suffixes
  // presumably mean blank / non-blank), plus the combined score.
  log_prob_b_prev = -NUM_FLT_INF;
  log_prob_nb_prev = -NUM_FLT_INF;
  log_prob_b_cur = -NUM_FLT_INF;
  log_prob_nb_cur = -NUM_FLT_INF;
  score = -NUM_FLT_INF;
}

// Destroys this node together with its whole subtree: every child pointer
// stored in children_ is owned by this node and deleted here, and each
// child's destructor in turn deletes its own children.
PathTrie::~PathTrie() {
  for (auto it = children_.begin(); it != children_.end(); ++it) {
    delete it->second;
  }
}

// Returns the child of this node reached by appending `new_char`, creating it
// when necessary.
//
// - If a child for `new_char` already exists it is returned. If that child
//   had been logically removed (exists_ == false) it is revived first, with
//   its four per-step log-probabilities reset to -NUM_FLT_INF.
// - Otherwise, when a dictionary FST is attached, the matcher must find an
//   arc labelled `new_char` leaving this node's dictionary state. If it does
//   not, nullptr is returned (the extension would leave the dictionary); in
//   that case, when `reset` is true and this node's dictionary state is
//   final, this node's state is rewound to the FST start state.
// - Otherwise a new child is allocated, linked under this node, and returned.
//   With a dictionary attached, the child inherits the dictionary and matcher
//   and starts at the matched arc's next state.
//
// Newly created children are owned by this node (stored in children_).
PathTrie* PathTrie::get_path_trie(int new_char, bool reset) {
  // Reuse an existing child for this character (first match wins).
  for (auto& entry : children_) {
    if (entry.first != new_char) {
      continue;
    }
    PathTrie* existing = entry.second;
    if (!existing->exists_) {
      // The child was logically removed earlier; revive it with cleared
      // per-step probabilities.
      existing->exists_ = true;
      existing->log_prob_b_prev = -NUM_FLT_INF;
      existing->log_prob_nb_prev = -NUM_FLT_INF;
      existing->log_prob_b_cur = -NUM_FLT_INF;
      existing->log_prob_nb_cur = -NUM_FLT_INF;
    }
    return existing;
  }

  if (has_dictionary_) {
    matcher_->SetState(dictionary_state_);
    if (!matcher_->Find(new_char)) {
      // Adding this character causes a word outside the dictionary.
      const bool is_final =
          dictionary_->Final(dictionary_state_) != fst::TropicalWeight::Zero();
      if (is_final && reset) {
        dictionary_state_ = dictionary_->Start();
      }
      return nullptr;
    }
  }

  // Hold the new node in a unique_ptr until it is linked into children_, so
  // it is not leaked if push_back throws.
  std::unique_ptr<PathTrie> new_path(new PathTrie);
  new_path->character = new_char;
  new_path->parent = this;
  if (has_dictionary_) {
    // Advance along the arc matched above; the child shares the same
    // dictionary and matcher.
    new_path->dictionary_ = dictionary_;
    new_path->dictionary_state_ = matcher_->Value().nextstate;
    new_path->has_dictionary_ = true;
    new_path->matcher_ = matcher_;
  }
  children_.push_back(std::make_pair(new_char, new_path.get()));
  return new_path.release();
}

// Convenience overload: walks from this node toward the root, collecting
// characters into `output` and stopping at the ROOT_ sentinel. It uses
// whatever default max_steps the header declares for the three-argument
// overload. Returns the node at which the walk stopped.
PathTrie* PathTrie::get_path_vec(std::vector<int>& output) {
  const int stop_at_root = ROOT_;
  return get_path_vec(output, stop_at_root);
}

// Walks from this node toward the root, appending each visited node's
// character to `output`. The walk stops at the first node that satisfies
// any of:
//   - its character equals `stop`;
//   - its character is the ROOT_ sentinel;
//   - `output` already holds `max_steps` entries.
// The stopping node's character is not appended. `output` is then reversed
// and the stopping node is returned.
//
// Because the whole vector is reversed, pass an empty `output` to obtain the
// characters in root-to-leaf order. The walk is iterative, so stack usage
// does not grow with path length.
PathTrie* PathTrie::get_path_vec(std::vector<int>& output,
                                 int stop,
                                 size_t max_steps) {
  PathTrie* node = this;
  while (node->character != stop && node->character != node->ROOT_ &&
         output.size() != max_steps) {
    output.push_back(node->character);
    node = node->parent;
  }
  std::reverse(output.begin(), output.end());
  return node;
}

// Advances every live node in the subtree rooted here by one time step and
// collects those nodes into `output` in pre-order. This node comes before
// its children, and children are visited in children_ order.
//
// For each node with exists_ set:
//   - the current-step log-probabilities become the previous-step ones;
//   - the current-step values are cleared to -NUM_FLT_INF;
//   - score is recomputed via log_sum_exp of the two previous-step values
//     (presumably log(exp(a) + exp(b)); see decoder_utils.h).
// Logically removed nodes are not collected, but their descendants are
// still visited.
void PathTrie::iterate_to_vec(std::vector<PathTrie*>& output) {
  if (exists_) {
    // Roll "current" into "previous" before clearing "current".
    log_prob_b_prev = log_prob_b_cur;
    log_prob_nb_prev = log_prob_nb_cur;
    log_prob_nb_cur = -NUM_FLT_INF;
    log_prob_b_cur = -NUM_FLT_INF;

    score = log_sum_exp(log_prob_b_prev, log_prob_nb_prev);
    output.push_back(this);
  }

  for (auto it = children_.begin(); it != children_.end(); ++it) {
    it->second->iterate_to_vec(output);
  }
}

void PathTrie::remove() {
119
  exists_ = false;
Y
Yibing Liu 已提交
120

121 122 123
  if (children_.size() == 0) {
    auto child = parent->children_.begin();
    for (child = parent->children_.begin(); child != parent->children_.end();
Y
Yibing Liu 已提交
124 125
         ++child) {
      if (child->first == character) {
126
        parent->children_.erase(child);
Y
Yibing Liu 已提交
127 128 129
        break;
      }
    }
130

131
    if (parent->children_.size() == 0 && !parent->exists_) {
Y
Yibing Liu 已提交
132
      parent->remove();
133
    }
Y
Yibing Liu 已提交
134 135 136

    delete this;
  }
137 138 139
}

// Attaches a dictionary FST to this node and enables the dictionary
// constraint in get_path_trie(). The node's dictionary state is set to the
// FST's start state. `dictionary` is stored as a non-owning pointer and must
// be non-null; Start() is called on it unconditionally.
void PathTrie::set_dictionary(fst::StdVectorFst* dictionary) {
  const auto start_state = dictionary->Start();
  dictionary_ = dictionary;
  dictionary_state_ = start_state;
  has_dictionary_ = true;
}

using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;
void PathTrie::set_matcher(std::shared_ptr<FSTMATCH> matcher) {
147
  matcher_ = matcher;
148
}