// decoder_utils.cpp (author: Yibing Liu)
#include <limits>
#include <algorithm>
#include <cmath>
#include "decoder_utils.h"

// Returns the number of UTF-8 characters (code points) in `str`.
// Every byte that is not a continuation byte (bit pattern 10xxxxxx) begins a
// new character, so the length is the count of non-continuation bytes.
// The input is not validated; malformed bytes are counted by the same rule.
size_t get_utf8_str_len(const std::string& str) {
    const auto is_lead_byte = [](char byte) {
        return (byte & 0xc0) != 0x80;
    };
    return static_cast<size_t>(
        std::count_if(str.begin(), str.end(), is_lead_byte));
}
//------------------------------------------------------
// Splits a string into a vector of strings, one per UTF-8 character
// (code point), which is not the same as one per `char`/byte.
//
// A new element is started at every byte that is not a UTF-8 continuation
// byte (10xxxxxx); continuation bytes are appended to the current element.
// The input is not validated: continuation bytes at the very start of `str`
// are grouped into the first element.
//
// Note: an empty input yields a vector holding a single empty string, not an
// empty vector.
//------------------------------------------------------
std::vector<std::string> split_utf8_str(const std::string& str) {
    std::vector<std::string> result;
    std::string out_str;

    for (char c : str) {
        const bool starts_new_char = (c & 0xc0) != 0x80;
        if (starts_new_char && !out_str.empty()) {
            // Flush the previous, now complete, character.
            result.push_back(out_str);
            out_str.clear();
        }
        out_str.push_back(c);
    }
    // Flush the last character (an empty string when `str` is empty).
    result.push_back(out_str);
    return result;
}

// Splits `s` on every occurrence of the (possibly multi-character) string
// `delim` and returns the non-empty pieces in order.
// Empty pieces are dropped wherever they occur, so delimiters at the
// beginning / end of `s` are trimmed and runs of consecutive delimiters act
// as one. Eg, "FooBarFoo" split on "Foo" returns ["Bar"], and "a,,b" split
// on "," returns ["a", "b"].
// An empty `delim` performs no splitting: the result is [s], or [] when `s`
// is empty.
std::vector<std::string> split_str(const std::string &s,
                                   const std::string &delim) {
    std::vector<std::string> result;
    if (delim.empty()) {
        // find("") matches at `start` without ever advancing, which would
        // loop forever below; treat the whole string as a single piece.
        if (!s.empty()) {
            result.push_back(s);
        }
        return result;
    }

    const std::size_t delim_len = delim.size();
    std::size_t start = 0;
    while (true) {
        const std::size_t end = s.find(delim, start);
        if (end == std::string::npos) {
            // Trailing piece after the last delimiter, if any.
            if (start < s.size()) {
                result.push_back(s.substr(start));
            }
            break;
        }
        if (end > start) {  // skip empty pieces
            result.push_back(s.substr(start, end - start));
        }
        start = end + delim_len;
    }
    return result;
}

//-------------------------------------------------------
//  Ordering predicate for sorting PathTrie prefixes: returns true when `x`
//  should be placed before `y`.
//  Prefixes are ordered by descending `_score`; prefixes with equal scores
//  are ordered by ascending `_character`. Returns false when both the score
//  and the character are equal.
//-------------------------------------------------------
bool prefix_compare(const PathTrie* x,  const PathTrie* y) {
    if (x->_score != y->_score) {
        return x->_score > y->_score;  // higher score sorts first
    }
    return x->_character < y->_character;
}  //---------- End prefix_compare ---------------------------

// --------------------------------------------------------------
// Adds word to fst without copying entire dictionary
// --------------------------------------------------------------
void add_word_to_fst(const std::vector<int>& word,
                     fst::StdVectorFst* dictionary) {
    if (dictionary->NumStates() == 0) {
        fst::StdVectorFst::StateId start = dictionary->AddState();
        assert(start == 0);
        dictionary->SetStart(start);
    }
    fst::StdVectorFst::StateId src = dictionary->Start();
    fst::StdVectorFst::StateId dst;
    for (auto c : word) {
        dst = dictionary->AddState();
        dictionary->AddArc(src, fst::StdArc(c, c, 0, dst));
        src = dst;
    }
    dictionary->SetFinal(dst, fst::StdArc::Weight::One());
}  // ------------ End of add_word_to_fst -----------------------

// ---------------------------------------------------------
// Adds a word to the dictionary FST based on char_map
// ---------------------------------------------------------
101
bool add_word_to_dictionary(const std::string& word,
102 103 104 105
                         const std::unordered_map<std::string, int>& char_map,
                         bool add_space,
                         int SPACE,
                         fst::StdVectorFst* dictionary) {
Y
Yibing Liu 已提交
106
    auto characters = split_utf8_str(word);
107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129

    std::vector<int> int_word;

    for (auto& c : characters) {
        if (c == " ") {
            int_word.push_back(SPACE);
        } else {
            auto int_c = char_map.find(c);
            if (int_c != char_map.end()) {
                int_word.push_back(int_c->second);
            } else {
                return false;  // return without adding
            }
        }
    }

    if (add_space) {
        int_word.push_back(SPACE);
    }

    add_word_to_fst(int_word, dictionary);
    return true;
}  // -------------- End of addWordToDictionary ------------