path_trie.cpp 3.9 KB
Newer Older
1 2
#include "path_trie.h"

3 4 5 6 7 8 9 10 11
#include <algorithm>
#include <limits>
#include <memory>
#include <utility>
#include <vector>

#include "decoder_utils.h"

// Builds a detached node in its initial state: every log-probability and the
// score start at -NUM_FLT_INF, the node is flagged as existing, and no parent,
// dictionary or matcher is attached.
PathTrie::PathTrie() {
  // Sentinel label used for the root. A freshly built node carries it until
  // get_path_trie() overwrites `character` with a real label.
  _ROOT = -1;
  character = _ROOT;
  parent = nullptr;
  _exists = true;

  // All probability slots share the same "log(0)" starting value.
  const auto lowest = -NUM_FLT_INF;
  log_prob_b_prev = lowest;
  log_prob_nb_prev = lowest;
  log_prob_b_cur = lowest;
  log_prob_nb_cur = lowest;
  score = lowest;

  // Dictionary constraints are opt-in via set_dictionary() / set_matcher().
  _has_dictionary = false;
  _dictionary = nullptr;
  _dictionary_state = 0;
  _matcher = nullptr;
}

// Frees the whole subtree below this node: deleting a child runs that child's
// destructor, which in turn deletes its own children.
PathTrie::~PathTrie() {
  for (auto it = _children.begin(); it != _children.end(); ++it) {
    delete it->second;
  }
}

// Returns the child of this node labelled `new_char`, creating it if needed.
//
// - If such a child already exists it is returned; a child that had been
//   flagged as non-existing is revived first (flag set, all four
//   log-probabilities reset to -NUM_FLT_INF).
// - Without a dictionary, a new child is always created and returned.
// - With a dictionary, the matcher is positioned at this node's dictionary
//   state and asked for `new_char`. If no arc is found, nullptr is returned;
//   additionally, when `reset` is true and the current dictionary state is
//   final, this node's dictionary state is rewound to the FST start state.
//   If an arc is found, the new child inherits the dictionary and matcher and
//   takes the arc's destination as its dictionary state.
//
// The returned node (when non-null) is stored in this node's child list.
PathTrie* PathTrie::get_path_trie(int new_char, bool reset) {
  const auto hit = std::find_if(
      _children.begin(),
      _children.end(),
      [new_char](const std::pair<int, PathTrie*>& edge) {
        return edge.first == new_char;
      });

  // Case 1: the edge already exists.
  if (hit != _children.end()) {
    PathTrie* known = hit->second;
    if (!known->_exists) {
      known->_exists = true;
      known->log_prob_b_prev = -NUM_FLT_INF;
      known->log_prob_nb_prev = -NUM_FLT_INF;
      known->log_prob_b_cur = -NUM_FLT_INF;
      known->log_prob_nb_cur = -NUM_FLT_INF;
    }
    return known;
  }

  // Case 2: unconstrained trie, any character may extend the path.
  if (!_has_dictionary) {
    PathTrie* fresh = new PathTrie;
    fresh->character = new_char;
    fresh->parent = this;
    _children.push_back(std::make_pair(new_char, fresh));
    return fresh;
  }

  // Case 3: dictionary-constrained trie.
  _matcher->SetState(_dictionary_state);
  const bool arc_found = _matcher->Find(new_char);
  if (!arc_found) {
    // Adding this character would leave the dictionary.
    const auto zero_weight = fst::TropicalWeight::Zero();
    const auto state_weight = _dictionary->Final(_dictionary_state);
    const bool at_final_state = (state_weight != zero_weight);
    if (at_final_state && reset) {
      _dictionary_state = _dictionary->Start();
    }
    return nullptr;
  }

  PathTrie* fresh = new PathTrie;
  fresh->character = new_char;
  fresh->parent = this;
  fresh->_dictionary = _dictionary;
  fresh->_dictionary_state = _matcher->Value().nextstate;
  fresh->_has_dictionary = true;
  fresh->_matcher = _matcher;
  _children.push_back(std::make_pair(new_char, fresh));
  return fresh;
}

// Convenience overload: collects the labels on the path from this node up to
// the root into `output`, using _ROOT as the stop label and leaving the step
// limit at whatever default the three-argument overload declares.
// Returns the node at which the walk ended.
PathTrie* PathTrie::get_path_vec(std::vector<int>& output) {
  PathTrie* const walk_end = get_path_vec(output, _ROOT);
  return walk_end;
}

// Walks from this node towards the root, appending each visited node's label
// to `output`, and stops at the first node that
//   - carries the label `stop`, or
//   - carries the root sentinel label, or
//   - is reached once `output` already holds `max_steps` entries.
// The stopping node's label is not appended. `output` is then reversed in
// place (the whole vector, including anything it held on entry), and the
// stopping node is returned.
PathTrie* PathTrie::get_path_vec(std::vector<int>& output,
                                 int stop,
                                 size_t max_steps) {
  PathTrie* node = this;
  while (node->character != stop && node->character != node->_ROOT &&
         output.size() != max_steps) {
    output.push_back(node->character);
    node = node->parent;
  }
  std::reverse(output.begin(), output.end());
  return node;
}

Y
Yibing Liu 已提交
100 101 102 103
// Pre-order traversal of the subtree rooted at this node. Every node flagged
// as existing rolls its "current" log-probabilities into the "previous" slots,
// resets the current ones to -NUM_FLT_INF, recomputes its score from the
// rolled values and is appended to `output`. Children are visited regardless
// of whether this node itself exists.
void PathTrie::iterate_to_vec(std::vector<PathTrie*>& output) {
  if (_exists) {
    // Blank-ending probability: shift, then clear for the next time step.
    log_prob_b_prev = log_prob_b_cur;
    log_prob_b_cur = -NUM_FLT_INF;

    // Non-blank-ending probability: same treatment.
    log_prob_nb_prev = log_prob_nb_cur;
    log_prob_nb_cur = -NUM_FLT_INF;

    score = log_sum_exp(log_prob_b_prev, log_prob_nb_prev);
    output.push_back(this);
  }

  for (const auto& edge : _children) {
    edge.second->iterate_to_vec(output);
  }
}

void PathTrie::remove() {
Y
Yibing Liu 已提交
117 118 119 120 121 122 123 124 125 126 127
  _exists = false;

  if (_children.size() == 0) {
    auto child = parent->_children.begin();
    for (child = parent->_children.begin(); child != parent->_children.end();
         ++child) {
      if (child->first == character) {
        parent->_children.erase(child);
        break;
      }
    }
128

Y
Yibing Liu 已提交
129 130
    if (parent->_children.size() == 0 && !parent->_exists) {
      parent->remove();
131
    }
Y
Yibing Liu 已提交
132 133 134

    delete this;
  }
135 136 137
}

// Attaches a dictionary FST to this node: stores the pointer, positions the
// node's dictionary state at the FST's start state and switches on
// dictionary-constrained lookups in get_path_trie().
// `dictionary` is dereferenced immediately, so it must not be null. The
// pointer is only stored; nothing in this file frees it.
void PathTrie::set_dictionary(fst::StdVectorFst* dictionary) {
  _dictionary = dictionary;
  _dictionary_state = _dictionary->Start();
  _has_dictionary = true;
}

using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;
void PathTrie::set_matcher(std::shared_ptr<FSTMATCH> matcher) {
Y
Yibing Liu 已提交
145
  _matcher = matcher;
146
}