ctc_decoders.cpp 11.0 KB
Newer Older
Y
Yibing Liu 已提交
1 2 3 4 5
#include <iostream>
#include <map>
#include <algorithm>
#include <utility>
#include <cmath>
6
#include <limits>
7
#include "ctc_decoders.h"
Y
Yibing Liu 已提交
8
#include "decoder_utils.h"
9
#include "ThreadPool.h"
Y
Yibing Liu 已提交
10

11
// Floating-point type used for all log-domain scores in the beam search.
using log_prob_type = double;
12 13

/**
 * Greedy (best-path) CTC decoding.
 *
 * Takes the arg-max class at every time step, merges consecutive repeats of
 * the same class, drops the blank class (index vocabulary.size()) and
 * concatenates the remaining tokens.
 *
 * @param probs_seq  One row per time step; every row must have
 *                   vocabulary.size() + 1 entries, the last being blank.
 * @param vocabulary Output tokens indexed by class id.
 * @return The decoded string (empty when probs_seq is empty).
 *
 * A row of the wrong width prints a message and terminates via exit(1).
 */
std::string ctc_best_path_decoder(std::vector<std::vector<double> > probs_seq,
                                  std::vector<std::string> vocabulary)
{
    // dimension check
    const size_t num_classes = vocabulary.size() + 1;
    for (size_t t = 0; t < probs_seq.size(); t++) {
        if (probs_seq[t].size() != num_classes) {
            std::cout<<"The shape of probs_seq does not match"
                     <<" with the shape of the vocabulary!"<<std::endl;
            exit(1);
        }
    }

    const size_t blank_id = vocabulary.size();

    std::string best_path_result;
    // Sentinel outside [0, num_classes) so the first step is never merged.
    size_t prev_idx = num_classes;
    for (size_t t = 0; t < probs_seq.size(); t++) {
        const std::vector<double> &row = probs_seq[t];
        // std::max_element returns the first maximum, so ties resolve to the
        // lowest index. Unlike a running maximum seeded with 0.0, this also
        // finds the arg-max when every score is <= 0 (e.g. log-probs).
        const size_t max_idx =
            std::max_element(row.begin(), row.end()) - row.begin();
        // Emit only when the class changes and it is not blank: this merges
        // repeats first and removes blanks afterwards, as CTC requires.
        if (max_idx != prev_idx && max_idx != blank_id) {
            best_path_result += vocabulary[max_idx];
        }
        prev_idx = max_idx;
    }
    return best_path_result;
}

Y
Yibing Liu 已提交
59
std::vector<std::pair<double, std::string> >
60 61 62 63 64
    ctc_beam_search_decoder(std::vector<std::vector<double> > probs_seq,
                            int beam_size,
                            std::vector<std::string> vocabulary,
                            int blank_id,
                            double cutoff_prob,
65 66
                            Scorer *ext_scorer)
{
67
    // dimension check
Y
Yibing Liu 已提交
68
    int num_time_steps = probs_seq.size();
69 70
    for (int i=0; i<num_time_steps; i++) {
        if (probs_seq[i].size() != vocabulary.size()+1) {
71 72
            std::cout << " The shape of probs_seq does not match"
                      << " with the shape of the vocabulary!" << std::endl;
73 74 75 76 77 78
            exit(1);
        }
    }

    // blank_id check
    if (blank_id > vocabulary.size()) {
79
        std::cout << " Invalid blank_id! " << std::endl;
80 81
        exit(1);
    }
Y
Yibing Liu 已提交
82 83

    // assign space ID
84 85 86
    std::vector<std::string>::iterator it = std::find(vocabulary.begin(),
                                                  vocabulary.end(), " ");
    int space_id = it - vocabulary.begin();
Y
Yibing Liu 已提交
87
    if(space_id >= vocabulary.size()) {
88
        std::cout << " The character space is not in the vocabulary!"<<std::endl;
Y
Yibing Liu 已提交
89
        exit(1);
Y
Yibing Liu 已提交
90
    }
Y
Yibing Liu 已提交
91

Y
Yibing Liu 已提交
92 93
    // initialize
    // two sets containing selected and candidate prefixes respectively
94
    std::map<std::string, log_prob_type> prefix_set_prev, prefix_set_next;
Y
Yibing Liu 已提交
95
    // probability of prefixes ending with blank and non-blank
96 97 98 99 100 101 102
    std::map<std::string, log_prob_type> log_probs_b_prev, log_probs_nb_prev;
    std::map<std::string, log_prob_type> log_probs_b_cur, log_probs_nb_cur;

    static log_prob_type NUM_MAX = std::numeric_limits<log_prob_type>::max();
    prefix_set_prev["\t"] = 0.0;
    log_probs_b_prev["\t"] = 0.0;
    log_probs_nb_prev["\t"] = -NUM_MAX;
Y
Yibing Liu 已提交
103

Y
Yibing Liu 已提交
104 105
    for (int time_step=0; time_step<num_time_steps; time_step++) {
        prefix_set_next.clear();
106 107
        log_probs_b_cur.clear();
        log_probs_nb_cur.clear();
Y
Yibing Liu 已提交
108 109 110 111 112 113
        std::vector<double> prob = probs_seq[time_step];

        std::vector<std::pair<int, double> > prob_idx;
        for (int i=0; i<prob.size(); i++) {
            prob_idx.push_back(std::pair<int, double>(i, prob[i]));
        }
114

Y
Yibing Liu 已提交
115
        // pruning of vacobulary
116
        int cutoff_len = prob.size();
Y
Yibing Liu 已提交
117
        if (cutoff_prob < 1.0) {
118 119
            std::sort(prob_idx.begin(),
                      prob_idx.end(),
120
                      pair_comp_second_rev<int, double>);
121 122
            double cum_prob = 0.0;
            cutoff_len = 0;
Y
Yibing Liu 已提交
123 124 125 126 127
            for (int i=0; i<prob_idx.size(); i++) {
                cum_prob += prob_idx[i].second;
                cutoff_len += 1;
                if (cum_prob >= cutoff_prob) break;
            }
128
            prob_idx = std::vector<std::pair<int, double> >( prob_idx.begin(),
129
                            prob_idx.begin() + cutoff_len);
Y
Yibing Liu 已提交
130
        }
131 132 133 134 135 136 137

        std::vector<std::pair<int, log_prob_type> > log_prob_idx;
        for (int i=0; i<cutoff_len; i++) {
            log_prob_idx.push_back(std::pair<int, log_prob_type>
                        (prob_idx[i].first, log(prob_idx[i].second)));
        }

Y
Yibing Liu 已提交
138
        // extend prefix
139 140
        for (std::map<std::string, log_prob_type>::iterator
             it = prefix_set_prev.begin();
Y
Yibing Liu 已提交
141 142 143
            it != prefix_set_prev.end(); it++) {
            std::string l = it->first;
            if( prefix_set_next.find(l) == prefix_set_next.end()) {
144
                log_probs_b_cur[l] = log_probs_nb_cur[l] = -NUM_MAX;
Y
Yibing Liu 已提交
145 146
            }

147 148 149 150
            for (int index=0; index<log_prob_idx.size(); index++) {
                int c = log_prob_idx[index].first;
                log_prob_type log_prob_c = log_prob_idx[index].second;
                log_prob_type log_probs_prev;
Y
Yibing Liu 已提交
151
                if (c == blank_id) {
152 153 154 155
                    log_probs_prev = log_sum_exp(log_probs_b_prev[l],
                                                 log_probs_nb_prev[l]);
                    log_probs_b_cur[l] = log_sum_exp(log_probs_b_cur[l],
                                                     log_prob_c+log_probs_prev);
Y
Yibing Liu 已提交
156 157 158
                } else {
                    std::string last_char = l.substr(l.size()-1, 1);
                    std::string new_char = vocabulary[c];
159
                    std::string l_plus = l + new_char;
Y
Yibing Liu 已提交
160 161

                    if( prefix_set_next.find(l_plus) == prefix_set_next.end()) {
162 163
                        log_probs_b_cur[l_plus] = -NUM_MAX;
                        log_probs_nb_cur[l_plus] = -NUM_MAX;
Y
Yibing Liu 已提交
164 165
                    }
                    if (last_char == new_char) {
166 167 168 169 170 171 172 173
                        log_probs_nb_cur[l_plus] = log_sum_exp(
                                                log_probs_nb_cur[l_plus],
                                                log_prob_c+log_probs_b_prev[l]
                                            );
                        log_probs_nb_cur[l] = log_sum_exp(
                                                log_probs_nb_cur[l],
                                                log_prob_c+log_probs_nb_prev[l]
                                            );
Y
Yibing Liu 已提交
174
                    } else if (new_char == " ") {
175
                        float score = 0.0;
Y
Yibing Liu 已提交
176
                        if (ext_scorer != NULL && l.size() > 1) {
177
                            score = ext_scorer->get_score(l.substr(1), true);
Y
Yibing Liu 已提交
178
                        }
179 180 181 182 183 184
                        log_probs_prev = log_sum_exp(log_probs_b_prev[l],
                                                     log_probs_nb_prev[l]);
                        log_probs_nb_cur[l_plus] = log_sum_exp(
                                            log_probs_nb_cur[l_plus],
                                            score + log_prob_c + log_probs_prev
                                        );
Y
Yibing Liu 已提交
185
                    } else {
186 187 188 189 190 191
                        log_probs_prev = log_sum_exp(log_probs_b_prev[l],
                                                     log_probs_nb_prev[l]);
                        log_probs_nb_cur[l_plus] = log_sum_exp(
                                                    log_probs_nb_cur[l_plus],
                                                    log_prob_c+log_probs_prev
                                                );
Y
Yibing Liu 已提交
192
                    }
193 194 195 196
                    prefix_set_next[l_plus] = log_sum_exp(
                                                log_probs_nb_cur[l_plus],
                                                log_probs_b_cur[l_plus]
                                            );
Y
Yibing Liu 已提交
197 198 199
                }
            }

200 201
            prefix_set_next[l] = log_sum_exp(log_probs_b_cur[l],
                                             log_probs_nb_cur[l]);
Y
Yibing Liu 已提交
202 203
        }

204 205 206
        log_probs_b_prev = log_probs_b_cur;
        log_probs_nb_prev = log_probs_nb_cur;
        std::vector<std::pair<std::string, log_prob_type> >
207 208 209 210
                  prefix_vec_next(prefix_set_next.begin(),
                                  prefix_set_next.end());
        std::sort(prefix_vec_next.begin(),
                  prefix_vec_next.end(),
211 212 213 214 215 216 217
                  pair_comp_second_rev<std::string, log_prob_type>);
        int num_prefixes_next = prefix_vec_next.size();
        int k = beam_size<num_prefixes_next ? beam_size : num_prefixes_next;
        prefix_set_prev = std::map<std::string, log_prob_type> (
                                                   prefix_vec_next.begin(),
                                                   prefix_vec_next.begin() + k
                                                );
Y
Yibing Liu 已提交
218 219 220 221
    }

    // post processing
    std::vector<std::pair<double, std::string> > beam_result;
222 223 224 225
    for (std::map<std::string, log_prob_type>::iterator
         it = prefix_set_prev.begin(); it != prefix_set_prev.end(); it++) {
        if (it->second > -NUM_MAX && it->first.size() > 1) {
            log_prob_type log_prob = it->second;
Y
Yibing Liu 已提交
226 227 228
            std::string sentence = it->first.substr(1);
            // scoring the last word
            if (ext_scorer != NULL && sentence[sentence.size()-1] != ' ') {
229 230 231 232 233
                log_prob = log_prob + ext_scorer->get_score(sentence, true);
            }
            if (log_prob > -NUM_MAX) {
                std::pair<double, std::string> cur_result(log_prob, sentence);
                beam_result.push_back(cur_result);
Y
Yibing Liu 已提交
234 235 236 237
            }
        }
    }
    // sort the result and return
238 239
    std::sort(beam_result.begin(), beam_result.end(),
              pair_comp_first_rev<double, std::string>);
Y
Yibing Liu 已提交
240 241
    return beam_result;
}
242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277


/**
 * Runs ctc_beam_search_decoder on every sample of a batch using a thread pool.
 *
 * @param probs_split   One probs_seq per sample.
 * @param beam_size     Forwarded to ctc_beam_search_decoder.
 * @param vocabulary    Forwarded to ctc_beam_search_decoder.
 * @param blank_id      Forwarded to ctc_beam_search_decoder.
 * @param num_processes Number of pool workers; must be positive.
 * @param cutoff_prob   Forwarded to ctc_beam_search_decoder.
 * @param ext_scorer    Forwarded to ctc_beam_search_decoder (may be NULL).
 * @return Per-sample results, in the same order as probs_split.
 *
 * NOTE(review): the same ext_scorer is handed to every task, so this assumes
 * Scorer::get_score is safe to call concurrently — confirm.
 * A non-positive num_processes prints a message and terminates via exit(1).
 */
std::vector<std::vector<std::pair<double, std::string>>>
    ctc_beam_search_decoder_batch(
                std::vector<std::vector<std::vector<double>>> probs_split,
                int beam_size,
                std::vector<std::string> vocabulary,
                int blank_id,
                int num_processes,
                double cutoff_prob,
                Scorer *ext_scorer
                )
{
    if (num_processes <= 0) {
        std::cout << "num_processes must be positive!" << std::endl;
        exit(1);
    }
    // thread pool
    ThreadPool pool(num_processes);
    // number of samples
    const size_t batch_size = probs_split.size();

    // enqueue the tasks of decoding
    std::vector<std::future<std::vector<std::pair<double, std::string>>>> res;
    res.reserve(batch_size);
    for (size_t i = 0; i < batch_size; i++) {
        // probs_split is owned by this function, so each sample can be moved
        // into its task instead of copied.
        res.emplace_back(
                pool.enqueue(ctc_beam_search_decoder, std::move(probs_split[i]),
                    beam_size, vocabulary, blank_id, cutoff_prob, ext_scorer)
            );
    }

    // get decoding results (blocks until each task finishes, in input order)
    std::vector<std::vector<std::pair<double, std::string>>> batch_results;
    batch_results.reserve(batch_size);
    for (size_t i = 0; i < batch_size; i++) {
        batch_results.emplace_back(res[i].get());
    }
    return batch_results;
}