path_trie.h 2.3 KB
Newer Older
H
Hui Zhang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15 16
#ifndef PATH_TRIE_H
#define PATH_TRIE_H
17

18 19 20 21 22 23
#include <algorithm>
#include <limits>
#include <memory>
#include <utility>
#include <vector>

24
#include "fst/fstlib.h"
25

Y
Yibing Liu 已提交
26 27 28
/* Trie for storing and manipulating prefixes, with a dictionary held in a
 * finite-state transducer (FST) for spelling correction.
 */
29
class PathTrie {
30 31 32
  public:
    PathTrie();
    ~PathTrie();
33

34 35
    // get new prefix after appending new char
    PathTrie* get_path_trie(int new_char, bool reset = true);
36

37 38
    // get the prefix in index from root to current node
    PathTrie* get_path_vec(std::vector<int>& output);
39

40 41 42 43 44
    // get the prefix in index from some stop node to current nodel
    PathTrie* get_path_vec(
        std::vector<int>& output,
        int stop,
        size_t max_steps = std::numeric_limits<size_t>::max());
45

46 47
    // update log probs
    void iterate_to_vec(std::vector<PathTrie*>& output);
48

49 50
    // set dictionary for FST
    void set_dictionary(fst::StdVectorFst* dictionary);
51

52
    void set_matcher(std::shared_ptr<fst::SortedMatcher<fst::StdVectorFst>>);
53

54
    bool is_empty() { return ROOT_ == character; }
55

56 57
    // remove current path from root
    void remove();
58

59 60 61 62 63 64 65 66
    float log_prob_b_prev;
    float log_prob_nb_prev;
    float log_prob_b_cur;
    float log_prob_nb_cur;
    float score;
    float approx_ctc;
    int character;
    PathTrie* parent;
67

68 69 70 71
  private:
    int ROOT_;
    bool exists_;
    bool has_dictionary_;
72

73
    std::vector<std::pair<int, PathTrie*>> children_;
74

75 76 77 78 79
    // pointer to dictionary of FST
    fst::StdVectorFst* dictionary_;
    fst::StdVectorFst::StateId dictionary_state_;
    // true if finding ars in FST
    std::shared_ptr<fst::SortedMatcher<fst::StdVectorFst>> matcher_;
80 81
};

Y
Yibing Liu 已提交
82
#endif  // PATH_TRIE_H