// ctc_decoders.cpp: CTC best-path and prefix beam-search decoders.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdlib>
#include <future>
#include <iostream>
#include <limits>
#include <map>
#include <string>
#include <utility>
#include <vector>
#include "ctc_decoders.h"
#include "decoder_utils.h"
#include "ThreadPool.h"

// Floating-point type used for every log-domain probability in this file.
using log_prob_type = double;

// Comparator that orders pairs by DESCENDING `first` (the `second` member is
// ignored). Returns true when `a` must be placed before `b`.
// Arguments are taken by const reference so that sorting pairs that hold
// strings does not copy both operands on every comparison.
template <typename T1, typename T2>
bool pair_comp_first_rev(const std::pair<T1, T2> &a, const std::pair<T1, T2> &b)
{
    return a.first > b.first;
}

// Comparator that orders pairs by DESCENDING `second` (the `first` member is
// ignored). Returns true when `a` must be placed before `b`.
// Arguments are taken by const reference so that sorting pairs that hold
// strings does not copy both operands on every comparison.
template <typename T1, typename T2>
bool pair_comp_second_rev(const std::pair<T1, T2> &a, const std::pair<T1, T2> &b)
{
    return a.second > b.second;
}

// Returns log(exp(x) + exp(y)) for two log-domain values.
//
// Any argument at or below -numeric_limits<T>::max() is treated as log(0):
// the other argument is returned unchanged. Otherwise the larger argument is
// factored out before exponentiating so that exp() cannot overflow.
template <typename T>
T log_sum_exp(T x, T y)
{
    static const T num_min = -std::numeric_limits<T>::max();
    if (x <= num_min) return y;
    if (y <= num_min) return x;
    const T xmax = std::max(x, y);
    return std::log(std::exp(x - xmax) + std::exp(y - xmax)) + xmax;
}

// Greedy ("best path") CTC decoding.
//
// probs_seq  : one row per time step; every row must have exactly
//              vocabulary.size() + 1 entries. The extra, last index is the
//              CTC blank.
// vocabulary : output string for each non-blank index.
//
// For every time step the index with the largest value is selected, runs of
// identical consecutive indices are collapsed to one, blanks are dropped and
// the remaining vocabulary entries are concatenated.
//
// On a row-size mismatch a message is printed to stdout and the process
// terminates with exit(1).
std::string ctc_best_path_decoder(std::vector<std::vector<double> > probs_seq,
                                  std::vector<std::string> vocabulary)
{
    // dimension check
    const std::size_t num_time_steps = probs_seq.size();
    for (std::size_t t = 0; t < num_time_steps; ++t) {
        if (probs_seq[t].size() != vocabulary.size() + 1) {
            std::cout<<"The shape of probs_seq does not match"
                     <<" with the shape of the vocabulary!"<<std::endl;
            exit(1);
        }
    }

    const std::size_t blank_id = vocabulary.size();

    std::string best_path_result;
    std::size_t prev_idx = 0;
    for (std::size_t t = 0; t < num_time_steps; ++t) {
        const std::vector<double> &row = probs_seq[t];

        // Arg-max of the row. The search starts from 0.0 with a strict '<',
        // so the first largest positive entry wins and index 0 is chosen when
        // no entry is positive.
        std::size_t max_idx = 0;
        double max_prob = 0.0;
        for (std::size_t j = 0; j < row.size(); ++j) {
            if (max_prob < row[j]) {
                max_idx = j;
                max_prob = row[j];
            }
        }

        // Emit only when the index differs from the previous time step
        // (collapse repeats) and is not the blank.
        if ((t == 0 || max_idx != prev_idx) && max_idx != blank_id) {
            best_path_result += vocabulary[max_idx];
        }
        prev_idx = max_idx;
    }
    return best_path_result;
}

Y
Yibing Liu 已提交
82
std::vector<std::pair<double, std::string> >
83 84 85 86 87
    ctc_beam_search_decoder(std::vector<std::vector<double> > probs_seq,
                            int beam_size,
                            std::vector<std::string> vocabulary,
                            int blank_id,
                            double cutoff_prob,
88 89
                            Scorer *ext_scorer)
{
90
    // dimension check
Y
Yibing Liu 已提交
91
    int num_time_steps = probs_seq.size();
92 93
    for (int i=0; i<num_time_steps; i++) {
        if (probs_seq[i].size() != vocabulary.size()+1) {
94 95
            std::cout << " The shape of probs_seq does not match"
                      << " with the shape of the vocabulary!" << std::endl;
96 97 98 99 100 101
            exit(1);
        }
    }

    // blank_id check
    if (blank_id > vocabulary.size()) {
102
        std::cout << " Invalid blank_id! " << std::endl;
103 104
        exit(1);
    }
Y
Yibing Liu 已提交
105 106

    // assign space ID
107 108 109
    std::vector<std::string>::iterator it = std::find(vocabulary.begin(),
                                                  vocabulary.end(), " ");
    int space_id = it - vocabulary.begin();
Y
Yibing Liu 已提交
110
    if(space_id >= vocabulary.size()) {
111
        std::cout << " The character space is not in the vocabulary!"<<std::endl;
Y
Yibing Liu 已提交
112
        exit(1);
Y
Yibing Liu 已提交
113
    }
Y
Yibing Liu 已提交
114

Y
Yibing Liu 已提交
115 116
    // initialize
    // two sets containing selected and candidate prefixes respectively
117
    std::map<std::string, log_prob_type> prefix_set_prev, prefix_set_next;
Y
Yibing Liu 已提交
118
    // probability of prefixes ending with blank and non-blank
119 120 121 122 123 124 125
    std::map<std::string, log_prob_type> log_probs_b_prev, log_probs_nb_prev;
    std::map<std::string, log_prob_type> log_probs_b_cur, log_probs_nb_cur;

    static log_prob_type NUM_MAX = std::numeric_limits<log_prob_type>::max();
    prefix_set_prev["\t"] = 0.0;
    log_probs_b_prev["\t"] = 0.0;
    log_probs_nb_prev["\t"] = -NUM_MAX;
Y
Yibing Liu 已提交
126

Y
Yibing Liu 已提交
127 128
    for (int time_step=0; time_step<num_time_steps; time_step++) {
        prefix_set_next.clear();
129 130
        log_probs_b_cur.clear();
        log_probs_nb_cur.clear();
Y
Yibing Liu 已提交
131 132 133 134 135 136
        std::vector<double> prob = probs_seq[time_step];

        std::vector<std::pair<int, double> > prob_idx;
        for (int i=0; i<prob.size(); i++) {
            prob_idx.push_back(std::pair<int, double>(i, prob[i]));
        }
137

Y
Yibing Liu 已提交
138
        // pruning of vacobulary
139
        int cutoff_len = prob.size();
Y
Yibing Liu 已提交
140
        if (cutoff_prob < 1.0) {
141 142
            std::sort(prob_idx.begin(),
                      prob_idx.end(),
143
                      pair_comp_second_rev<int, double>);
144 145
            double cum_prob = 0.0;
            cutoff_len = 0;
Y
Yibing Liu 已提交
146 147 148 149 150
            for (int i=0; i<prob_idx.size(); i++) {
                cum_prob += prob_idx[i].second;
                cutoff_len += 1;
                if (cum_prob >= cutoff_prob) break;
            }
151
            prob_idx = std::vector<std::pair<int, double> >( prob_idx.begin(),
152
                            prob_idx.begin() + cutoff_len);
Y
Yibing Liu 已提交
153
        }
154 155 156 157 158 159 160

        std::vector<std::pair<int, log_prob_type> > log_prob_idx;
        for (int i=0; i<cutoff_len; i++) {
            log_prob_idx.push_back(std::pair<int, log_prob_type>
                        (prob_idx[i].first, log(prob_idx[i].second)));
        }

Y
Yibing Liu 已提交
161
        // extend prefix
162 163
        for (std::map<std::string, log_prob_type>::iterator
             it = prefix_set_prev.begin();
Y
Yibing Liu 已提交
164 165 166
            it != prefix_set_prev.end(); it++) {
            std::string l = it->first;
            if( prefix_set_next.find(l) == prefix_set_next.end()) {
167
                log_probs_b_cur[l] = log_probs_nb_cur[l] = -NUM_MAX;
Y
Yibing Liu 已提交
168 169
            }

170 171 172 173
            for (int index=0; index<log_prob_idx.size(); index++) {
                int c = log_prob_idx[index].first;
                log_prob_type log_prob_c = log_prob_idx[index].second;
                log_prob_type log_probs_prev;
Y
Yibing Liu 已提交
174
                if (c == blank_id) {
175 176 177 178
                    log_probs_prev = log_sum_exp(log_probs_b_prev[l],
                                                 log_probs_nb_prev[l]);
                    log_probs_b_cur[l] = log_sum_exp(log_probs_b_cur[l],
                                                     log_prob_c+log_probs_prev);
Y
Yibing Liu 已提交
179 180 181
                } else {
                    std::string last_char = l.substr(l.size()-1, 1);
                    std::string new_char = vocabulary[c];
182
                    std::string l_plus = l + new_char;
Y
Yibing Liu 已提交
183 184

                    if( prefix_set_next.find(l_plus) == prefix_set_next.end()) {
185 186
                        log_probs_b_cur[l_plus] = -NUM_MAX;
                        log_probs_nb_cur[l_plus] = -NUM_MAX;
Y
Yibing Liu 已提交
187 188
                    }
                    if (last_char == new_char) {
189 190 191 192 193 194 195 196
                        log_probs_nb_cur[l_plus] = log_sum_exp(
                                                log_probs_nb_cur[l_plus],
                                                log_prob_c+log_probs_b_prev[l]
                                            );
                        log_probs_nb_cur[l] = log_sum_exp(
                                                log_probs_nb_cur[l],
                                                log_prob_c+log_probs_nb_prev[l]
                                            );
Y
Yibing Liu 已提交
197
                    } else if (new_char == " ") {
198
                        float score = 0.0;
Y
Yibing Liu 已提交
199
                        if (ext_scorer != NULL && l.size() > 1) {
200
                            score = ext_scorer->get_score(l.substr(1), true);
Y
Yibing Liu 已提交
201
                        }
202 203 204 205 206 207
                        log_probs_prev = log_sum_exp(log_probs_b_prev[l],
                                                     log_probs_nb_prev[l]);
                        log_probs_nb_cur[l_plus] = log_sum_exp(
                                            log_probs_nb_cur[l_plus],
                                            score + log_prob_c + log_probs_prev
                                        );
Y
Yibing Liu 已提交
208
                    } else {
209 210 211 212 213 214
                        log_probs_prev = log_sum_exp(log_probs_b_prev[l],
                                                     log_probs_nb_prev[l]);
                        log_probs_nb_cur[l_plus] = log_sum_exp(
                                                    log_probs_nb_cur[l_plus],
                                                    log_prob_c+log_probs_prev
                                                );
Y
Yibing Liu 已提交
215
                    }
216 217 218 219
                    prefix_set_next[l_plus] = log_sum_exp(
                                                log_probs_nb_cur[l_plus],
                                                log_probs_b_cur[l_plus]
                                            );
Y
Yibing Liu 已提交
220 221 222
                }
            }

223 224
            prefix_set_next[l] = log_sum_exp(log_probs_b_cur[l],
                                             log_probs_nb_cur[l]);
Y
Yibing Liu 已提交
225 226
        }

227 228 229
        log_probs_b_prev = log_probs_b_cur;
        log_probs_nb_prev = log_probs_nb_cur;
        std::vector<std::pair<std::string, log_prob_type> >
230 231 232 233
                  prefix_vec_next(prefix_set_next.begin(),
                                  prefix_set_next.end());
        std::sort(prefix_vec_next.begin(),
                  prefix_vec_next.end(),
234 235 236 237 238 239 240
                  pair_comp_second_rev<std::string, log_prob_type>);
        int num_prefixes_next = prefix_vec_next.size();
        int k = beam_size<num_prefixes_next ? beam_size : num_prefixes_next;
        prefix_set_prev = std::map<std::string, log_prob_type> (
                                                   prefix_vec_next.begin(),
                                                   prefix_vec_next.begin() + k
                                                );
Y
Yibing Liu 已提交
241 242 243 244
    }

    // post processing
    std::vector<std::pair<double, std::string> > beam_result;
245 246 247 248
    for (std::map<std::string, log_prob_type>::iterator
         it = prefix_set_prev.begin(); it != prefix_set_prev.end(); it++) {
        if (it->second > -NUM_MAX && it->first.size() > 1) {
            log_prob_type log_prob = it->second;
Y
Yibing Liu 已提交
249 250 251
            std::string sentence = it->first.substr(1);
            // scoring the last word
            if (ext_scorer != NULL && sentence[sentence.size()-1] != ' ') {
252 253 254 255 256
                log_prob = log_prob + ext_scorer->get_score(sentence, true);
            }
            if (log_prob > -NUM_MAX) {
                std::pair<double, std::string> cur_result(log_prob, sentence);
                beam_result.push_back(cur_result);
Y
Yibing Liu 已提交
257 258 259 260
            }
        }
    }
    // sort the result and return
261 262
    std::sort(beam_result.begin(), beam_result.end(),
              pair_comp_first_rev<double, std::string>);
Y
Yibing Liu 已提交
263 264
    return beam_result;
}
265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300


// Runs ctc_beam_search_decoder on every sample of a batch using a thread
// pool.
//
// probs_split   : one probability sequence per sample.
// num_processes : number of workers handed to ThreadPool; must be > 0,
//                 otherwise a message is printed to stdout and the process
//                 terminates with exit(1).
// beam_size, vocabulary, blank_id, cutoff_prob, ext_scorer are forwarded
// unchanged to ctc_beam_search_decoder for every sample.
//
// Returns one result list per sample, in the same order as probs_split.
//
// NOTE(review): the same ext_scorer pointer is handed to every task;
// presumably Scorer::get_score must tolerate concurrent calls -- confirm.
std::vector<std::vector<std::pair<double, std::string>>>
    ctc_beam_search_decoder_batch(
                std::vector<std::vector<std::vector<double>>> probs_split,
                int beam_size,
                std::vector<std::string> vocabulary,
                int blank_id,
                int num_processes,
                double cutoff_prob,
                Scorer *ext_scorer
                )
{
    if (num_processes <= 0) {
        // zero is rejected too, so the requirement is "positive"
        std::cout << "num_processes must be positive!" << std::endl;
        exit(1);
    }
    // thread pool
    ThreadPool pool(num_processes);
    // number of samples
    const std::size_t batch_size = probs_split.size();

    // enqueue the tasks of decoding
    std::vector<std::future<std::vector<std::pair<double, std::string>>>> res;
    res.reserve(batch_size);
    for (std::size_t i = 0; i < batch_size; ++i) {
        res.emplace_back(
                pool.enqueue(ctc_beam_search_decoder, probs_split[i],
                    beam_size, vocabulary, blank_id, cutoff_prob, ext_scorer)
            );
    }

    // get decoding results; get() waits for the corresponding task
    std::vector<std::vector<std::pair<double, std::string>>> batch_results;
    batch_results.reserve(batch_size);
    for (std::size_t i = 0; i < batch_size; ++i) {
        batch_results.emplace_back(res[i].get());
    }
    return batch_results;
}