decoder_utils.cpp 4.9 KB
Newer Older
Y
Yibing Liu 已提交
1
#include "decoder_utils.h"
2

Y
Yibing Liu 已提交
3 4
#include <algorithm>
#include <cmath>
Y
Yibing Liu 已提交
5
#include <limits>
Y
Yibing Liu 已提交
6

7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71
// Prunes the probability distribution of one time step and converts the
// surviving entries to log space.
//
// prob_step     probabilities for a single time step, indexed by vocabulary id
// cutoff_prob   cumulative-probability threshold; when it is below 1.0 the
//               sorted entries are kept only until their running sum reaches
//               this value
// cutoff_top_n  upper bound on the number of entries kept
//
// Returns (index, log(prob + NUM_FLT_MIN)) pairs. When pruning applies, the
// entries are in the order produced by sorting with pair_comp_second_rev
// (presumably descending probability -- confirm against its definition);
// otherwise they keep the input order.
std::vector<std::pair<size_t, float>> get_pruned_log_probs(
    const std::vector<double> &prob_step,
    double cutoff_prob,
    size_t cutoff_top_n) {
  std::vector<std::pair<int, double>> prob_idx;
  prob_idx.reserve(prob_step.size());
  for (size_t i = 0; i < prob_step.size(); ++i) {
    prob_idx.emplace_back(static_cast<int>(i), prob_step[i]);
  }

  // pruning of vocabulary
  size_t cutoff_len = prob_step.size();
  if (cutoff_prob < 1.0 || cutoff_top_n < cutoff_len) {
    std::sort(
        prob_idx.begin(), prob_idx.end(), pair_comp_second_rev<int, double>);
    if (cutoff_prob < 1.0) {
      // Count how many leading entries are needed to reach cutoff_prob.
      double cum_prob = 0.0;
      cutoff_len = 0;
      for (const auto &entry : prob_idx) {
        cum_prob += entry.second;
        ++cutoff_len;
        if (cum_prob >= cutoff_prob) break;
      }
    }
    cutoff_len = std::min(cutoff_len, cutoff_top_n);
    // Truncate in place rather than building a second vector.
    prob_idx.erase(prob_idx.begin() + cutoff_len, prob_idx.end());
  }

  std::vector<std::pair<size_t, float>> log_prob_idx;
  log_prob_idx.reserve(prob_idx.size());
  for (const auto &entry : prob_idx) {
    // NUM_FLT_MIN is added before taking the log so a zero probability does
    // not produce log(0).
    log_prob_idx.emplace_back(
        entry.first,
        static_cast<float>(std::log(entry.second + NUM_FLT_MIN)));
  }
  return log_prob_idx;
}


// Builds the final result list from the beam prefixes.
//
// prefixes    candidate prefix nodes; only the first beam_size entries are
//             considered
// vocabulary  maps each label index produced by PathTrie::get_path_vec to its
//             string; indices are not range-checked here
// beam_size   maximum number of results returned
//
// Returns (score, text) pairs, where score is the negated approx_ctc of the
// prefix and the pairs follow the order given by prefix_compare.
std::vector<std::pair<double, std::string>> get_beam_search_result(
    const std::vector<PathTrie *> &prefixes,
    const std::vector<std::string> &vocabulary,
    size_t beam_size) {
  // Select at most beam_size candidates, then order them.
  const size_t num_candidates = std::min(beam_size, prefixes.size());
  std::vector<PathTrie *> space_prefixes(prefixes.begin(),
                                         prefixes.begin() + num_candidates);
  std::sort(space_prefixes.begin(), space_prefixes.end(), prefix_compare);

  std::vector<std::pair<double, std::string>> output_vecs;
  output_vecs.reserve(space_prefixes.size());
  for (PathTrie *prefix : space_prefixes) {
    std::vector<int> output;
    prefix->get_path_vec(output);
    // convert index to string
    std::string output_str;
    for (int index : output) {
      output_str += vocabulary[index];
    }
    output_vecs.emplace_back(-prefix->approx_ctc, output_str);
  }

  return output_vecs;
}

Y
Yibing Liu 已提交
72
// Returns the number of UTF-8 characters in str.
//
// Every byte that is not a continuation byte (bit pattern 10xxxxxx) is
// counted as the start of a character. The input is not validated, so
// malformed sequences are counted by the same rule.
size_t get_utf8_str_len(const std::string &str) {
  size_t str_len = 0;
  for (char c : str) {
    if ((c & 0xc0) != 0x80) {
      ++str_len;
    }
  }
  return str_len;
}
79

Y
Yibing Liu 已提交
80
// Splits a UTF-8 string into its individual characters, each returned as its
// own (possibly multi-byte) std::string.
//
// A new character starts at every byte that is not a continuation byte
// (bit pattern 10xxxxxx). The pending buffer is always appended at the end,
// so an empty input yields a vector holding one empty string.
std::vector<std::string> split_utf8_str(const std::string &str) {
  std::vector<std::string> result;
  std::string out_str;

  for (char c : str) {
    const bool starts_new_char = (c & 0xc0) != 0x80;
    if (starts_new_char && !out_str.empty()) {
      // Flush the bytes collected for the previous character.
      result.push_back(out_str);
      out_str.clear();
    }
    out_str.append(1, c);
  }
  result.push_back(out_str);
  return result;
}

Y
Yibing Liu 已提交
99 100
// Splits s on every occurrence of delim and returns the non-empty pieces.
//
// Empty tokens (from leading, trailing or consecutive delimiters) are
// dropped. An empty delim cannot split anything: the whole string is returned
// as a single token (or nothing if s is empty). Without this guard the search
// would never advance and the loop would not terminate.
std::vector<std::string> split_str(const std::string &s,
                                   const std::string &delim) {
  std::vector<std::string> result;
  if (delim.empty()) {
    if (!s.empty()) {
      result.push_back(s);
    }
    return result;
  }

  const std::size_t delim_len = delim.size();
  std::size_t start = 0;
  while (true) {
    const std::size_t end = s.find(delim, start);
    if (end == std::string::npos) {
      break;
    }
    if (end > start) {
      result.push_back(s.substr(start, end - start));
    }
    start = end + delim_len;
  }
  // Whatever follows the last delimiter is the final token.
  if (start < s.size()) {
    result.push_back(s.substr(start));
  }
  return result;
}

Y
Yibing Liu 已提交
119
bool prefix_compare(const PathTrie *x, const PathTrie *y) {
Y
Yibing Liu 已提交
120 121 122
  if (x->score == y->score) {
    if (x->character == y->character) {
      return false;
123
    } else {
Y
Yibing Liu 已提交
124
      return (x->character < y->character);
125
    }
Y
Yibing Liu 已提交
126 127 128
  } else {
    return x->score > y->score;
  }
129
}
130

Y
Yibing Liu 已提交
131 132
void add_word_to_fst(const std::vector<int> &word,
                     fst::StdVectorFst *dictionary) {
Y
Yibing Liu 已提交
133 134 135 136 137 138 139 140 141 142 143 144 145
  if (dictionary->NumStates() == 0) {
    fst::StdVectorFst::StateId start = dictionary->AddState();
    assert(start == 0);
    dictionary->SetStart(start);
  }
  fst::StdVectorFst::StateId src = dictionary->Start();
  fst::StdVectorFst::StateId dst;
  for (auto c : word) {
    dst = dictionary->AddState();
    dictionary->AddArc(src, fst::StdArc(c, c, 0, dst));
    src = dst;
  }
  dictionary->SetFinal(dst, fst::StdArc::Weight::One());
146
}
147

Y
Yibing Liu 已提交
148
bool add_word_to_dictionary(
Y
Yibing Liu 已提交
149 150
    const std::string &word,
    const std::unordered_map<std::string, int> &char_map,
Y
Yibing Liu 已提交
151 152
    bool add_space,
    int SPACE_ID,
Y
Yibing Liu 已提交
153
    fst::StdVectorFst *dictionary) {
Y
Yibing Liu 已提交
154
  auto characters = split_utf8_str(word);
155

Y
Yibing Liu 已提交
156
  std::vector<int> int_word;
157

Y
Yibing Liu 已提交
158
  for (auto &c : characters) {
Y
Yibing Liu 已提交
159 160 161 162 163 164 165 166 167
    if (c == " ") {
      int_word.push_back(SPACE_ID);
    } else {
      auto int_c = char_map.find(c);
      if (int_c != char_map.end()) {
        int_word.push_back(int_c->second);
      } else {
        return false;  // return without adding
      }
168
    }
Y
Yibing Liu 已提交
169
  }
170

Y
Yibing Liu 已提交
171 172 173
  if (add_space) {
    int_word.push_back(SPACE_ID);
  }
174

Y
Yibing Liu 已提交
175
  add_word_to_fst(int_word, dictionary);
Y
Yibing Liu 已提交
176
  return true;  // return with successful adding
177
}