decoder_utils.cpp 4.8 KB
Newer Older
Y
Yibing Liu 已提交
1
#include "decoder_utils.h"
2

Y
Yibing Liu 已提交
3 4
#include <algorithm>
#include <cmath>
Y
Yibing Liu 已提交
5
#include <limits>
Y
Yibing Liu 已提交
6

7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25
/**
 * Prune one time step of the output distribution and convert the surviving
 * probabilities to log space.
 *
 * @param prob_step     Probability of every vocabulary entry at this step.
 * @param cutoff_prob   Cumulative-probability threshold. When < 1.0, entries
 *                      are taken in descending probability until their sum
 *                      reaches this value (or cutoff_top_n entries are taken).
 * @param cutoff_top_n  Upper bound on the number of entries kept.
 * @return (index, log(prob + NUM_FLT_MIN)) pairs for the kept entries. They
 *         are ordered by descending probability when pruning was applied, and
 *         in original index order otherwise.
 */
std::vector<std::pair<size_t, float>> get_pruned_log_probs(
    const std::vector<double> &prob_step,
    double cutoff_prob,
    size_t cutoff_top_n) {
  std::vector<std::pair<int, double>> prob_idx;
  prob_idx.reserve(prob_step.size());
  for (size_t i = 0; i < prob_step.size(); ++i) {
    prob_idx.emplace_back(static_cast<int>(i), prob_step[i]);
  }

  // Pruning of the vocabulary.
  size_t cutoff_len = prob_step.size();
  if (cutoff_prob < 1.0 || cutoff_top_n < cutoff_len) {
    std::sort(
        prob_idx.begin(), prob_idx.end(), pair_comp_second_rev<int, double>);
    if (cutoff_prob < 1.0) {
      double cum_prob = 0.0;
      cutoff_len = 0;
      for (size_t i = 0; i < prob_idx.size(); ++i) {
        cum_prob += prob_idx[i].second;
        cutoff_len += 1;
        if (cum_prob >= cutoff_prob || cutoff_len >= cutoff_top_n) break;
      }
    } else {
      // No probability cutoff: still honour the top-n limit that caused the
      // sort above.
      cutoff_len = std::min(cutoff_len, cutoff_top_n);
    }
    prob_idx.resize(cutoff_len);
  }

  std::vector<std::pair<size_t, float>> log_prob_idx;
  log_prob_idx.reserve(cutoff_len);
  for (size_t i = 0; i < cutoff_len; ++i) {
    // NUM_FLT_MIN keeps log() finite for zero-probability entries.
    log_prob_idx.emplace_back(
        static_cast<size_t>(prob_idx[i].first),
        static_cast<float>(std::log(prob_idx[i].second + NUM_FLT_MIN)));
  }
  return log_prob_idx;
}


/**
 * Turn the best decoding prefixes into (score, transcript) pairs.
 *
 * @param prefixes    Candidate prefixes; only the first beam_size entries are
 *                    considered.
 * @param vocabulary  Text for each label index on a prefix path (indices are
 *                    assumed to be in range; they are not checked).
 * @param beam_size   Maximum number of results returned.
 * @return Up to beam_size pairs of (-approx_ctc, transcript), ordered by
 *         prefix_compare.
 */
std::vector<std::pair<double, std::string>> get_beam_search_result(
    const std::vector<PathTrie *> &prefixes,
    const std::vector<std::string> &vocabulary,
    size_t beam_size) {
  // Copy the leading candidates so they can be reordered for post processing
  // without touching the caller's vector.
  const size_t num_prefixes = std::min(beam_size, prefixes.size());
  std::vector<PathTrie *> space_prefixes(prefixes.begin(),
                                         prefixes.begin() + num_prefixes);
  std::sort(space_prefixes.begin(), space_prefixes.end(), prefix_compare);

  std::vector<std::pair<double, std::string>> output_vecs;
  output_vecs.reserve(space_prefixes.size());
  for (PathTrie *prefix : space_prefixes) {
    std::vector<int> output;
    prefix->get_path_vec(output);
    // Convert label indices to text.
    std::string output_str;
    for (int label : output) {
      output_str += vocabulary[label];
    }
    output_vecs.emplace_back(-prefix->approx_ctc, std::move(output_str));
  }
  return output_vecs;
}

Y
Yibing Liu 已提交
71
/**
 * Count the UTF-8 characters in str.
 *
 * Every byte that is not a continuation byte (bit pattern 10xxxxxx) starts a
 * new character, so the count of such bytes is the character count.
 */
size_t get_utf8_str_len(const std::string &str) {
  const auto leading_bytes =
      std::count_if(str.begin(), str.end(), [](char byte) {
        return (byte & 0xc0) != 0x80;
      });
  return static_cast<size_t>(leading_bytes);
}
78

Y
Yibing Liu 已提交
79
/**
 * Split str into UTF-8 characters, one string per character.
 *
 * A new piece starts at every byte that is not a continuation byte
 * (10xxxxxx). Continuation bytes at the very start are grouped into the first
 * piece. Note that the final piece is always emitted, so an empty input
 * yields a single empty string rather than an empty vector.
 */
std::vector<std::string> split_utf8_str(const std::string &str) {
  std::vector<std::string> pieces;
  std::string current;
  for (std::string::const_iterator it = str.begin(); it != str.end(); ++it) {
    const bool starts_new_char = (*it & 0xc0) != 0x80;
    if (starts_new_char && !current.empty()) {
      pieces.push_back(current);
      current.clear();
    }
    current += *it;
  }
  pieces.push_back(current);
  return pieces;
}

Y
Yibing Liu 已提交
98 99
/**
 * Split s on every occurrence of delim and drop empty tokens.
 *
 * @param s      String to split.
 * @param delim  Separator; may be longer than one character.
 * @return The non-empty substrings of s between delimiters, in order. An empty
 *         delim cannot split anything, so s itself is returned (if non-empty).
 */
std::vector<std::string> split_str(const std::string &s,
                                   const std::string &delim) {
  std::vector<std::string> result;
  if (delim.empty()) {
    // find("", pos) always returns pos, so the scan below would never
    // advance; treat the whole string as a single token instead.
    if (!s.empty()) {
      result.push_back(s);
    }
    return result;
  }

  const std::size_t delim_len = delim.size();
  std::size_t start = 0;
  while (start < s.size()) {
    std::size_t end = s.find(delim, start);
    if (end == std::string::npos) {
      end = s.size();
    }
    if (end > start) {
      result.push_back(s.substr(start, end - start));
    }
    start = end + delim_len;
  }
  return result;
}

Y
Yibing Liu 已提交
118
bool prefix_compare(const PathTrie *x, const PathTrie *y) {
Y
Yibing Liu 已提交
119 120 121
  if (x->score == y->score) {
    if (x->character == y->character) {
      return false;
122
    } else {
Y
Yibing Liu 已提交
123
      return (x->character < y->character);
124
    }
Y
Yibing Liu 已提交
125 126 127
  } else {
    return x->score > y->score;
  }
128
}
129

Y
Yibing Liu 已提交
130 131
void add_word_to_fst(const std::vector<int> &word,
                     fst::StdVectorFst *dictionary) {
Y
Yibing Liu 已提交
132 133 134 135 136 137 138 139 140 141 142 143 144
  if (dictionary->NumStates() == 0) {
    fst::StdVectorFst::StateId start = dictionary->AddState();
    assert(start == 0);
    dictionary->SetStart(start);
  }
  fst::StdVectorFst::StateId src = dictionary->Start();
  fst::StdVectorFst::StateId dst;
  for (auto c : word) {
    dst = dictionary->AddState();
    dictionary->AddArc(src, fst::StdArc(c, c, 0, dst));
    src = dst;
  }
  dictionary->SetFinal(dst, fst::StdArc::Weight::One());
145
}
146

Y
Yibing Liu 已提交
147
bool add_word_to_dictionary(
Y
Yibing Liu 已提交
148 149
    const std::string &word,
    const std::unordered_map<std::string, int> &char_map,
Y
Yibing Liu 已提交
150 151
    bool add_space,
    int SPACE_ID,
Y
Yibing Liu 已提交
152
    fst::StdVectorFst *dictionary) {
Y
Yibing Liu 已提交
153
  auto characters = split_utf8_str(word);
154

Y
Yibing Liu 已提交
155
  std::vector<int> int_word;
156

Y
Yibing Liu 已提交
157
  for (auto &c : characters) {
Y
Yibing Liu 已提交
158 159 160 161 162 163 164 165 166
    if (c == " ") {
      int_word.push_back(SPACE_ID);
    } else {
      auto int_c = char_map.find(c);
      if (int_c != char_map.end()) {
        int_word.push_back(int_c->second);
      } else {
        return false;  // return without adding
      }
167
    }
Y
Yibing Liu 已提交
168
  }
169

Y
Yibing Liu 已提交
170 171 172
  if (add_space) {
    int_word.push_back(SPACE_ID);
  }
173

Y
Yibing Liu 已提交
174
  add_word_to_fst(int_word, dictionary);
Y
Yibing Liu 已提交
175
  return true;  // return with successful adding
176
}