ctc_decoders.cpp 10.2 KB
Newer Older
Y
Yibing Liu 已提交
1
#include "ctc_decoders.h"
2

Y
Yibing Liu 已提交
3 4
#include <algorithm>
#include <cmath>
Y
Yibing Liu 已提交
5
#include <iostream>
6
#include <limits>
Y
Yibing Liu 已提交
7 8
#include <map>
#include <utility>
9

Y
Yibing Liu 已提交
10
#include "ThreadPool.h"
Y
Yibing Liu 已提交
11
#include "fst/fstlib.h"
12

Y
Yibing Liu 已提交
13
#include "decoder_utils.h"
14
#include "path_trie.h"
Y
Yibing Liu 已提交
15

16
std::string ctc_greedy_decoder(
Y
Yibing Liu 已提交
17 18
    const std::vector<std::vector<double>> &probs_seq,
    const std::vector<std::string> &vocabulary) {
Y
Yibing Liu 已提交
19
  // dimension check
Y
Yibing Liu 已提交
20 21 22 23 24 25
  size_t num_time_steps = probs_seq.size();
  for (size_t i = 0; i < num_time_steps; i++) {
    VALID_CHECK_EQ(probs_seq[i].size(),
                   vocabulary.size() + 1,
                   "The shape of probs_seq does not match with "
                   "the shape of the vocabulary");
Y
Yibing Liu 已提交
26 27
  }

Y
Yibing Liu 已提交
28
  size_t blank_id = vocabulary.size();
Y
Yibing Liu 已提交
29

Y
Yibing Liu 已提交
30 31 32 33 34
  std::vector<size_t> max_idx_vec;
  for (size_t i = 0; i < num_time_steps; i++) {
    double max_prob = 0.0;
    size_t max_idx = 0;
    for (size_t j = 0; j < probs_seq[i].size(); j++) {
Y
Yibing Liu 已提交
35 36 37 38
      if (max_prob < probs_seq[i][j]) {
        max_idx = j;
        max_prob = probs_seq[i][j];
      }
39
    }
Y
Yibing Liu 已提交
40 41 42
    max_idx_vec.push_back(max_idx);
  }

Y
Yibing Liu 已提交
43 44
  std::vector<size_t> idx_vec;
  for (size_t i = 0; i < max_idx_vec.size(); i++) {
Y
Yibing Liu 已提交
45 46
    if ((i == 0) || ((i > 0) && max_idx_vec[i] != max_idx_vec[i - 1])) {
      idx_vec.push_back(max_idx_vec[i]);
47
    }
Y
Yibing Liu 已提交
48
  }
49

Y
Yibing Liu 已提交
50
  std::string best_path_result;
Y
Yibing Liu 已提交
51
  for (size_t i = 0; i < idx_vec.size(); i++) {
Y
Yibing Liu 已提交
52 53
    if (idx_vec[i] != blank_id) {
      best_path_result += vocabulary[idx_vec[i]];
54
    }
Y
Yibing Liu 已提交
55 56
  }
  return best_path_result;
57 58
}

Y
Yibing Liu 已提交
59
std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
Y
Yibing Liu 已提交
60
    const std::vector<std::vector<double>> &probs_seq,
Y
Yibing Liu 已提交
61
    const size_t beam_size,
Y
Yibing Liu 已提交
62
    std::vector<std::string> vocabulary,
Y
Yibing Liu 已提交
63 64 65
    const double cutoff_prob,
    const size_t cutoff_top_n,
    Scorer *ext_scorer) {
Y
Yibing Liu 已提交
66
  // dimension check
67
  size_t num_time_steps = probs_seq.size();
Y
Yibing Liu 已提交
68 69 70 71 72
  for (size_t i = 0; i < num_time_steps; i++) {
    VALID_CHECK_EQ(probs_seq[i].size(),
                   vocabulary.size() + 1,
                   "The shape of probs_seq does not match with "
                   "the shape of the vocabulary");
Y
Yibing Liu 已提交
73 74
  }

Y
Yibing Liu 已提交
75 76
  // assign blank id
  size_t blank_id = vocabulary.size();
Y
Yibing Liu 已提交
77

Y
Yibing Liu 已提交
78
  // assign space id
Y
Yibing Liu 已提交
79 80 81 82 83 84 85 86 87 88 89 90 91 92
  std::vector<std::string>::iterator it =
      std::find(vocabulary.begin(), vocabulary.end(), " ");
  int space_id = it - vocabulary.begin();
  // if no space in vocabulary
  if (space_id >= vocabulary.size()) {
    space_id = -2;
  }

  // init prefixes' root
  PathTrie root;
  root.score = root.log_prob_b_prev = 0.0;
  std::vector<PathTrie *> prefixes;
  prefixes.push_back(&root);

Y
Yibing Liu 已提交
93 94 95
  if (ext_scorer != nullptr) {
    if (ext_scorer->is_char_map_empty()) {
      ext_scorer->set_char_map(vocabulary);
96
    }
Y
Yibing Liu 已提交
97 98
    if (!ext_scorer->is_character_based()) {
      if (ext_scorer->dictionary == nullptr) {
Y
Yibing Liu 已提交
99
        // fill dictionary for fst with space
Y
Yibing Liu 已提交
100
        ext_scorer->fill_dictionary(true);
Y
Yibing Liu 已提交
101
      }
Y
Yibing Liu 已提交
102
      auto fst_dict = static_cast<fst::StdVectorFst *>(ext_scorer->dictionary);
Y
Yibing Liu 已提交
103 104 105 106 107 108 109 110
      fst::StdVectorFst *dict_ptr = fst_dict->Copy(true);
      root.set_dictionary(dict_ptr);
      auto matcher = std::make_shared<FSTMATCH>(*dict_ptr, fst::MATCH_INPUT);
      root.set_matcher(matcher);
    }
  }

  // prefix search over time
Y
Yibing Liu 已提交
111
  for (size_t time_step = 0; time_step < num_time_steps; time_step++) {
Y
Yibing Liu 已提交
112 113
    std::vector<double> prob = probs_seq[time_step];
    std::vector<std::pair<int, double>> prob_idx;
Y
Yibing Liu 已提交
114
    for (size_t i = 0; i < prob.size(); i++) {
Y
Yibing Liu 已提交
115
      prob_idx.push_back(std::pair<int, double>(i, prob[i]));
Y
Yibing Liu 已提交
116
    }
Y
Yibing Liu 已提交
117

Y
Yibing Liu 已提交
118 119
    float min_cutoff = -NUM_FLT_INF;
    bool full_beam = false;
Y
Yibing Liu 已提交
120 121
    if (ext_scorer != nullptr) {
      size_t num_prefixes = std::min(prefixes.size(), beam_size);
Y
Yibing Liu 已提交
122 123 124
      std::sort(
          prefixes.begin(), prefixes.begin() + num_prefixes, prefix_compare);
      min_cutoff = prefixes[num_prefixes - 1]->score + log(prob[blank_id]) -
Y
Yibing Liu 已提交
125
                   std::max(0.0, ext_scorer->beta);
Y
Yibing Liu 已提交
126 127
      full_beam = (num_prefixes == beam_size);
    }
128

Y
Yibing Liu 已提交
129
    // pruning of vacobulary
Y
Yibing Liu 已提交
130
    size_t cutoff_len = prob.size();
Y
Yibing Liu 已提交
131 132 133 134 135 136
    if (cutoff_prob < 1.0 || cutoff_top_n < prob.size()) {
      std::sort(
          prob_idx.begin(), prob_idx.end(), pair_comp_second_rev<int, double>);
      if (cutoff_prob < 1.0) {
        double cum_prob = 0.0;
        cutoff_len = 0;
Y
Yibing Liu 已提交
137
        for (size_t i = 0; i < prob_idx.size(); i++) {
Y
Yibing Liu 已提交
138 139 140
          cum_prob += prob_idx[i].second;
          cutoff_len += 1;
          if (cum_prob >= cutoff_prob) break;
141
        }
Y
Yibing Liu 已提交
142 143 144 145 146
      }
      cutoff_len = std::min(cutoff_len, cutoff_top_n);
      prob_idx = std::vector<std::pair<int, double>>(
          prob_idx.begin(), prob_idx.begin() + cutoff_len);
    }
Y
Yibing Liu 已提交
147 148
    std::vector<std::pair<size_t, float>> log_prob_idx;
    for (size_t i = 0; i < cutoff_len; i++) {
Y
Yibing Liu 已提交
149 150
      log_prob_idx.push_back(std::pair<int, float>(
          prob_idx[i].first, log(prob_idx[i].second + NUM_FLT_MIN)));
151
    }
Y
Yibing Liu 已提交
152

Y
Yibing Liu 已提交
153
    // loop over chars
Y
Yibing Liu 已提交
154
    for (size_t index = 0; index < log_prob_idx.size(); index++) {
Y
Yibing Liu 已提交
155 156
      auto c = log_prob_idx[index].first;
      float log_prob_c = log_prob_idx[index].second;
157

Y
Yibing Liu 已提交
158
      for (size_t i = 0; i < prefixes.size() && i < beam_size; i++) {
Y
Yibing Liu 已提交
159
        auto prefix = prefixes[i];
160

Y
Yibing Liu 已提交
161 162
        if (full_beam && log_prob_c + prefix->score < min_cutoff) {
          break;
163
        }
Y
Yibing Liu 已提交
164 165 166 167 168 169 170 171 172 173
        // blank
        if (c == blank_id) {
          prefix->log_prob_b_cur =
              log_sum_exp(prefix->log_prob_b_cur, log_prob_c + prefix->score);
          continue;
        }
        // repeated character
        if (c == prefix->character) {
          prefix->log_prob_nb_cur = log_sum_exp(
              prefix->log_prob_nb_cur, log_prob_c + prefix->log_prob_nb_prev);
Y
Yibing Liu 已提交
174
        }
Y
Yibing Liu 已提交
175 176 177 178 179 180 181 182 183 184 185 186 187 188
        // get new prefix
        auto prefix_new = prefix->get_path_trie(c);

        if (prefix_new != nullptr) {
          float log_p = -NUM_FLT_INF;

          if (c == prefix->character &&
              prefix->log_prob_b_prev > -NUM_FLT_INF) {
            log_p = log_prob_c + prefix->log_prob_b_prev;
          } else if (c != prefix->character) {
            log_p = log_prob_c + prefix->score;
          }

          // language model scoring
Y
Yibing Liu 已提交
189 190
          if (ext_scorer != nullptr &&
              (c == space_id || ext_scorer->is_character_based())) {
Y
Yibing Liu 已提交
191 192 193
            PathTrie *prefix_toscore = nullptr;

            // skip scoring the space
Y
Yibing Liu 已提交
194
            if (ext_scorer->is_character_based()) {
Y
Yibing Liu 已提交
195 196 197 198
              prefix_toscore = prefix_new;
            } else {
              prefix_toscore = prefix;
            }
Y
Yibing Liu 已提交
199

Y
Yibing Liu 已提交
200 201
            double score = 0.0;
            std::vector<std::string> ngram;
Y
Yibing Liu 已提交
202 203
            ngram = ext_scorer->make_ngram(prefix_toscore);
            score = ext_scorer->get_log_cond_prob(ngram) * ext_scorer->alpha;
Y
Yibing Liu 已提交
204

Y
Yibing Liu 已提交
205
            log_p += score;
Y
Yibing Liu 已提交
206
            log_p += ext_scorer->beta;
Y
Yibing Liu 已提交
207 208 209
          }
          prefix_new->log_prob_nb_cur =
              log_sum_exp(prefix_new->log_prob_nb_cur, log_p);
Y
Yibing Liu 已提交
210
        }
Y
Yibing Liu 已提交
211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227
      }  // end of loop over prefix
    }    // end of loop over chars

    prefixes.clear();
    // update log probs
    root.iterate_to_vec(prefixes);

    // only preserve top beam_size prefixes
    if (prefixes.size() >= beam_size) {
      std::nth_element(prefixes.begin(),
                       prefixes.begin() + beam_size,
                       prefixes.end(),
                       prefix_compare);

      for (size_t i = beam_size; i < prefixes.size(); i++) {
        prefixes[i]->remove();
      }
Y
Yibing Liu 已提交
228
    }
Y
Yibing Liu 已提交
229 230 231 232 233 234
  }  // end of loop over time

  // compute aproximate ctc score as the return score
  for (size_t i = 0; i < beam_size && i < prefixes.size(); i++) {
    double approx_ctc = prefixes[i]->score;

Y
Yibing Liu 已提交
235
    if (ext_scorer != nullptr) {
Y
Yibing Liu 已提交
236 237 238
      std::vector<int> output;
      prefixes[i]->get_path_vec(output);
      size_t prefix_length = output.size();
Y
Yibing Liu 已提交
239
      auto words = ext_scorer->split_labels(output);
Y
Yibing Liu 已提交
240
      // remove word insert
Y
Yibing Liu 已提交
241
      approx_ctc = approx_ctc - prefix_length * ext_scorer->beta;
Y
Yibing Liu 已提交
242
      // remove language model weight:
Y
Yibing Liu 已提交
243
      approx_ctc -= (ext_scorer->get_sent_log_prob(words)) * ext_scorer->alpha;
244 245
    }

Y
Yibing Liu 已提交
246 247 248 249 250 251 252 253
    prefixes[i]->approx_ctc = approx_ctc;
  }

  // allow for the post processing
  std::vector<PathTrie *> space_prefixes;
  if (space_prefixes.empty()) {
    for (size_t i = 0; i < beam_size && i < prefixes.size(); i++) {
      space_prefixes.push_back(prefixes[i]);
254
    }
Y
Yibing Liu 已提交
255 256 257 258 259 260 261 262 263
  }

  std::sort(space_prefixes.begin(), space_prefixes.end(), prefix_compare);
  std::vector<std::pair<double, std::string>> output_vecs;
  for (size_t i = 0; i < beam_size && i < space_prefixes.size(); i++) {
    std::vector<int> output;
    space_prefixes[i]->get_path_vec(output);
    // convert index to string
    std::string output_str;
Y
Yibing Liu 已提交
264
    for (size_t j = 0; j < output.size(); j++) {
Y
Yibing Liu 已提交
265
      output_str += vocabulary[output[j]];
266
    }
Y
Yibing Liu 已提交
267 268 269 270
    std::pair<double, std::string> output_pair(-space_prefixes[i]->approx_ctc,
                                               output_str);
    output_vecs.emplace_back(output_pair);
  }
271

Y
Yibing Liu 已提交
272 273
  return output_vecs;
}
274

Y
Yibing Liu 已提交
275 276
std::vector<std::vector<std::pair<double, std::string>>>
ctc_beam_search_decoder_batch(
Y
Yibing Liu 已提交
277
    const std::vector<std::vector<std::vector<double>>> &probs_split,
Y
Yibing Liu 已提交
278
    const size_t beam_size,
Y
Yibing Liu 已提交
279
    const std::vector<std::string> &vocabulary,
Y
Yibing Liu 已提交
280 281 282 283 284
    const size_t num_processes,
    const double cutoff_prob,
    const size_t cutoff_top_n,
    Scorer *ext_scorer) {
  VALID_CHECK_GT(num_processes, 0, "num_processes must be nonnegative!");
Y
Yibing Liu 已提交
285 286 287
  // thread pool
  ThreadPool pool(num_processes);
  // number of samples
Y
Yibing Liu 已提交
288
  size_t batch_size = probs_split.size();
Y
Yibing Liu 已提交
289 290

  // scorer filling up
Y
Yibing Liu 已提交
291 292 293
  if (ext_scorer != nullptr) {
    if (ext_scorer->is_char_map_empty()) {
      ext_scorer->set_char_map(vocabulary);
Y
Yibing Liu 已提交
294
    }
Y
Yibing Liu 已提交
295 296
    if (!ext_scorer->is_character_based() &&
        ext_scorer->dictionary == nullptr) {
Y
Yibing Liu 已提交
297
      // init dictionary
Y
Yibing Liu 已提交
298
      ext_scorer->fill_dictionary(true);
299
    }
Y
Yibing Liu 已提交
300 301 302 303
  }

  // enqueue the tasks of decoding
  std::vector<std::future<std::vector<std::pair<double, std::string>>>> res;
Y
Yibing Liu 已提交
304
  for (size_t i = 0; i < batch_size; i++) {
Y
Yibing Liu 已提交
305 306 307 308 309 310
    res.emplace_back(pool.enqueue(ctc_beam_search_decoder,
                                  probs_split[i],
                                  beam_size,
                                  vocabulary,
                                  cutoff_prob,
                                  cutoff_top_n,
Y
Yibing Liu 已提交
311
                                  ext_scorer));
Y
Yibing Liu 已提交
312 313 314 315
  }

  // get decoding results
  std::vector<std::vector<std::pair<double, std::string>>> batch_results;
Y
Yibing Liu 已提交
316
  for (size_t i = 0; i < batch_size; i++) {
Y
Yibing Liu 已提交
317 318 319
    batch_results.emplace_back(res[i].get());
  }
  return batch_results;
320
}