#ifndef DECODER_UTILS_H_
#define DECODER_UTILS_H_

#include <algorithm>
#include <cmath>
#include <limits>
#include <sstream>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "path_trie.h"

// Named "INF" but holds the largest *finite* float
// (std::numeric_limits<float>::max()), not infinity().
constexpr float NUM_FLT_INF = std::numeric_limits<float>::max();
// Smallest positive normalized float (std::numeric_limits<float>::min()).
// Note: this is NOT the most negative float; that is -NUM_FLT_INF.
constexpr float NUM_FLT_MIN = std::numeric_limits<float>::min();

// VALID_CHECK_EQ(A, B, ERR): unless A == B, throws std::runtime_error with
// the message "<A> != <B>, <ERR>".
// VALID_CHECK_GT(A, B, ERR): unless A > B, throws std::runtime_error with
// the message "<A> <= <B>, <ERR>".
//
// Notes common to both macros:
//  - A and B are each evaluated exactly once, and their values are reused
//    for the message.
//  - The expansion is wrapped in do { ... } while (0) so a call followed by
//    ';' acts as a single statement (safe as an unbraced if-branch with an
//    else).
//  - Internal locals carry a valid_check_ prefix so they cannot shadow names
//    used inside A or B.
//  - ERR is appended to a std::string with operator+, so it may be a string
//    literal or a std::string.
#define VALID_CHECK_EQ(A, B, ERR)                              \
  do {                                                         \
    const auto &valid_check_lhs_ = (A);                        \
    const auto &valid_check_rhs_ = (B);                        \
    if (valid_check_lhs_ != valid_check_rhs_) {                \
      std::ostringstream valid_check_msg_;                     \
      valid_check_msg_ << valid_check_lhs_ << " != "           \
                       << valid_check_rhs_ << ", ";            \
      throw std::runtime_error(valid_check_msg_.str() + ERR);  \
    }                                                          \
  } while (0)

#define VALID_CHECK_GT(A, B, ERR)                              \
  do {                                                         \
    const auto &valid_check_lhs_ = (A);                        \
    const auto &valid_check_rhs_ = (B);                        \
    if (valid_check_lhs_ <= valid_check_rhs_) {                \
      std::ostringstream valid_check_msg_;                     \
      valid_check_msg_ << valid_check_lhs_ << " <= "           \
                       << valid_check_rhs_ << ", ";            \
      throw std::runtime_error(valid_check_msg_.str() + ERR);  \
    }                                                          \
  } while (0)

// Comparator for std::pair: returns true when a.first > b.first.
// Used as a sort predicate, it orders pairs by their first element in
// descending order. Only operator> on T1 is required; the second elements
// are ignored.
template <typename T1, typename T2>
bool pair_comp_first_rev(const std::pair<T1, T2> &a,
                         const std::pair<T1, T2> &b) {
  return a.first > b.first;
}

// Comparator for std::pair: returns true when a.second > b.second.
// Used as a sort predicate, it orders pairs by their second element in
// descending order. Only operator> on T2 is required; the first elements
// are ignored.
template <typename T1, typename T2>
bool pair_comp_second_rev(const std::pair<T1, T2> &a,
                          const std::pair<T1, T2> &b) {
  return a.second > b.second;
}

// Return log(exp(x) + exp(y)): the sum of two probabilities given in log
// scale, computed without leaving log space.
//
// Any operand at or below -std::numeric_limits<T>::max() (including -inf) is
// treated as log(0), and the other operand is returned unchanged. Otherwise
// the result is computed as max(x, y) + log1p(exp(-|x - y|)). This form
// cannot overflow, and it keeps the small term's contribution when the two
// operands differ greatly (log(1 + tiny) would round to 0). Two +inf operands
// give +inf. A NaN operand gives NaN.
template <typename T>
T log_sum_exp(const T &x, const T &y) {
  // Threshold for "log(0)". Declared const rather than as a mutable
  // function-local static, so there is no per-call init guard and nothing
  // can modify it.
  const T num_min = -std::numeric_limits<T>::max();
  if (x <= num_min) return y;
  if (y <= num_min) return x;
  // After the guards above, an infinite x here can only be +inf. In that
  // case x - y would be inf - inf = NaN, although the true sum is +inf.
  if (x == y && std::isinf(x)) return x;
  return std::max(x, y) + std::log1p(std::exp(-std::fabs(x - y)));
}

Y
Yibing Liu 已提交
50
// Functor for prefix comparsion
Y
Yibing Liu 已提交
51
bool prefix_compare(const PathTrie *x, const PathTrie *y);
52

/* Get the length of a UTF-8 encoded string.
 * NOTE(review): following the referenced answer, this presumably counts
 * code points rather than bytes. Confirm against the implementation.
 * See: http://stackoverflow.com/a/4063229
 */
size_t get_utf8_str_len(const std::string &str);

/* Split a string into a list of strings on a given string delimiter.
 * NB: delimiters at the beginning / end of the string are trimmed.
 * E.g., "FooBarFoo" split on "Foo" returns ["Bar"].
 */
std::vector<std::string> split_str(const std::string &s,
                                   const std::string &delim);

/* Split a string into a vector of strings, each holding one UTF-8
 * character. A UTF-8 character may span several bytes, so this is not
 * the same as splitting the string into chars.
 */
std::vector<std::string> split_utf8_str(const std::string &str);

70
// Add a word in index to the dicionary of fst
Y
Yibing Liu 已提交
71 72
void add_word_to_fst(const std::vector<int> &word,
                     fst::StdVectorFst *dictionary);
73

74
// Add a word in string to dictionary
Y
Yibing Liu 已提交
75 76 77 78 79 80 81
bool add_word_to_dictionary(
    const std::string &word,
    const std::unordered_map<std::string, int> &char_map,
    bool add_space,
    int SPACE_ID,
    fst::StdVectorFst *dictionary);
#endif  // DECODER_UTILS_H