ctc_decoders.cpp 10.3 KB
Newer Older
Y
Yibing Liu 已提交
1 2 3 4 5
#include "ctc_decoders.h"

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdlib>
#include <iostream>
#include <limits>
#include <map>
#include <string>
#include <utility>
#include <vector>
Y
Yibing Liu 已提交
8

9
// Floating-point type used for every log-probability accumulator in the
// beam search decoder below.
typedef double log_prob_type;
10

Y
Yibing Liu 已提交
11
// Strict "greater-than" ordering on pair::first, so std::sort arranges
// pairs in descending order of their first element.
// Arguments are taken by const reference: the pairs may hold std::string,
// and a by-value comparator would copy them on every comparison.
template <typename T1, typename T2>
bool pair_comp_first_rev(const std::pair<T1, T2>& a,
                         const std::pair<T1, T2>& b)
{
    return a.first > b.first;
}

// Strict "greater-than" ordering on pair::second, so std::sort arranges
// pairs in descending order of their second element (e.g. by probability
// or by log score).
// Arguments are taken by const reference to avoid copying pairs that hold
// std::string keys on every comparison.
template <typename T1, typename T2>
bool pair_comp_second_rev(const std::pair<T1, T2>& a,
                          const std::pair<T1, T2>& b)
{
    return a.second > b.second;
}

23 24 25 26
// Computes log(exp(x) + exp(y)) without overflowing exp().
// A value at or below -numeric_limits<T>::max() stands for log(0): if x is
// such a value, y is returned unchanged; otherwise, if y is, x is returned.
template <typename T>
T log_sum_exp(T x, T y)
{
    const T log_zero = -std::numeric_limits<T>::max();
    if (x <= log_zero) {
        return y;
    }
    if (y <= log_zero) {
        return x;
    }
    // Shift both exponents by the larger one so each exp() is at most 1.
    const T pivot = (x < y) ? y : x;
    const T shifted_sum = std::exp(x - pivot) + std::exp(y - pivot);
    return pivot + std::log(shifted_sum);
}

/* CTC best-path (greedy) decoding.
 *
 * probs_seq:  one row of scores per time step; every row must hold
 *             vocabulary.size() + 1 entries, the last column being the blank.
 * vocabulary: output tokens, indexed like the columns of probs_seq.
 *
 * Takes the per-step argmax (ties go to the lowest index), merges
 * consecutive repeats, drops blanks and concatenates the remaining tokens.
 * The argmax does not assume non-negative scores, so log-probabilities work
 * as well as probabilities.
 *
 * If any row has the wrong length, prints a message to stdout and calls
 * exit(1).
 */
std::string ctc_best_path_decoder(std::vector<std::vector<double> > probs_seq,
                                  std::vector<std::string> vocabulary) {
    const std::size_t num_classes = vocabulary.size() + 1;

    // dimension check
    for (std::size_t t = 0; t < probs_seq.size(); ++t) {
        if (probs_seq[t].size() != num_classes) {
            std::cout<<"The shape of probs_seq does not match"
                     <<" with the shape of the vocabulary!"<<std::endl;
            exit(1);
        }
    }

    // The blank occupies the extra (last) column of every row.
    const int blank_id = static_cast<int>(vocabulary.size());

    std::string best_path_result;
    int prev_idx = -1;  // no previous time step yet
    for (std::size_t t = 0; t < probs_seq.size(); ++t) {
        const std::vector<double>& row = probs_seq[t];
        // Rows are non-empty (num_classes >= 1), so max_element is valid.
        const int idx = static_cast<int>(
            std::max_element(row.begin(), row.end()) - row.begin());
        // Collapse consecutive repeats first, then skip blanks.
        if (idx != prev_idx && idx != blank_id) {
            best_path_result += vocabulary[idx];
        }
        prev_idx = idx;
    }
    return best_path_result;
}

Y
Yibing Liu 已提交
78
std::vector<std::pair<double, std::string> >
79 80 81 82 83
    ctc_beam_search_decoder(std::vector<std::vector<double> > probs_seq,
                            int beam_size,
                            std::vector<std::string> vocabulary,
                            int blank_id,
                            double cutoff_prob,
84
                            LmScorer *ext_scorer,
85 86
                            bool nproc) {
    // dimension check
Y
Yibing Liu 已提交
87
    int num_time_steps = probs_seq.size();
88 89
    for (int i=0; i<num_time_steps; i++) {
        if (probs_seq[i].size() != vocabulary.size()+1) {
90 91
            std::cout << " The shape of probs_seq does not match"
                      << " with the shape of the vocabulary!" << std::endl;
92 93 94 95 96 97
            exit(1);
        }
    }

    // blank_id check
    if (blank_id > vocabulary.size()) {
98
        std::cout << " Invalid blank_id! " << std::endl;
99 100
        exit(1);
    }
Y
Yibing Liu 已提交
101 102

    // assign space ID
103 104 105
    std::vector<std::string>::iterator it = std::find(vocabulary.begin(),
                                                  vocabulary.end(), " ");
    int space_id = it - vocabulary.begin();
Y
Yibing Liu 已提交
106
    if(space_id >= vocabulary.size()) {
107
        std::cout << " The character space is not in the vocabulary!"<<std::endl;
Y
Yibing Liu 已提交
108
        exit(1);
Y
Yibing Liu 已提交
109
    }
Y
Yibing Liu 已提交
110

Y
Yibing Liu 已提交
111 112
    // initialize
    // two sets containing selected and candidate prefixes respectively
113
    std::map<std::string, log_prob_type> prefix_set_prev, prefix_set_next;
Y
Yibing Liu 已提交
114
    // probability of prefixes ending with blank and non-blank
115 116 117 118 119 120 121
    std::map<std::string, log_prob_type> log_probs_b_prev, log_probs_nb_prev;
    std::map<std::string, log_prob_type> log_probs_b_cur, log_probs_nb_cur;

    static log_prob_type NUM_MAX = std::numeric_limits<log_prob_type>::max();
    prefix_set_prev["\t"] = 0.0;
    log_probs_b_prev["\t"] = 0.0;
    log_probs_nb_prev["\t"] = -NUM_MAX;
Y
Yibing Liu 已提交
122

Y
Yibing Liu 已提交
123 124
    for (int time_step=0; time_step<num_time_steps; time_step++) {
        prefix_set_next.clear();
125 126
        log_probs_b_cur.clear();
        log_probs_nb_cur.clear();
Y
Yibing Liu 已提交
127 128 129 130 131 132
        std::vector<double> prob = probs_seq[time_step];

        std::vector<std::pair<int, double> > prob_idx;
        for (int i=0; i<prob.size(); i++) {
            prob_idx.push_back(std::pair<int, double>(i, prob[i]));
        }
133

Y
Yibing Liu 已提交
134
        // pruning of vacobulary
135
        int cutoff_len = prob.size();
Y
Yibing Liu 已提交
136
        if (cutoff_prob < 1.0) {
137 138
            std::sort(prob_idx.begin(),
                      prob_idx.end(),
139
                      pair_comp_second_rev<int, double>);
140 141
            double cum_prob = 0.0;
            cutoff_len = 0;
Y
Yibing Liu 已提交
142 143 144 145 146
            for (int i=0; i<prob_idx.size(); i++) {
                cum_prob += prob_idx[i].second;
                cutoff_len += 1;
                if (cum_prob >= cutoff_prob) break;
            }
147
            prob_idx = std::vector<std::pair<int, double> >( prob_idx.begin(),
148
                            prob_idx.begin() + cutoff_len);
Y
Yibing Liu 已提交
149
        }
150 151 152 153 154 155 156

        std::vector<std::pair<int, log_prob_type> > log_prob_idx;
        for (int i=0; i<cutoff_len; i++) {
            log_prob_idx.push_back(std::pair<int, log_prob_type>
                        (prob_idx[i].first, log(prob_idx[i].second)));
        }

Y
Yibing Liu 已提交
157
        // extend prefix
158 159
        for (std::map<std::string, log_prob_type>::iterator
             it = prefix_set_prev.begin();
Y
Yibing Liu 已提交
160 161 162
            it != prefix_set_prev.end(); it++) {
            std::string l = it->first;
            if( prefix_set_next.find(l) == prefix_set_next.end()) {
163
                log_probs_b_cur[l] = log_probs_nb_cur[l] = -NUM_MAX;
Y
Yibing Liu 已提交
164 165
            }

166 167 168 169
            for (int index=0; index<log_prob_idx.size(); index++) {
                int c = log_prob_idx[index].first;
                log_prob_type log_prob_c = log_prob_idx[index].second;
                log_prob_type log_probs_prev;
Y
Yibing Liu 已提交
170
                if (c == blank_id) {
171 172 173 174
                    log_probs_prev = log_sum_exp(log_probs_b_prev[l],
                                                 log_probs_nb_prev[l]);
                    log_probs_b_cur[l] = log_sum_exp(log_probs_b_cur[l],
                                                     log_prob_c+log_probs_prev);
Y
Yibing Liu 已提交
175 176 177
                } else {
                    std::string last_char = l.substr(l.size()-1, 1);
                    std::string new_char = vocabulary[c];
178
                    std::string l_plus = l + new_char;
Y
Yibing Liu 已提交
179 180

                    if( prefix_set_next.find(l_plus) == prefix_set_next.end()) {
181 182
                        log_probs_b_cur[l_plus] = -NUM_MAX;
                        log_probs_nb_cur[l_plus] = -NUM_MAX;
Y
Yibing Liu 已提交
183 184
                    }
                    if (last_char == new_char) {
185 186 187 188 189 190 191 192
                        log_probs_nb_cur[l_plus] = log_sum_exp(
                                                log_probs_nb_cur[l_plus],
                                                log_prob_c+log_probs_b_prev[l]
                                            );
                        log_probs_nb_cur[l] = log_sum_exp(
                                                log_probs_nb_cur[l],
                                                log_prob_c+log_probs_nb_prev[l]
                                            );
Y
Yibing Liu 已提交
193
                    } else if (new_char == " ") {
194
                        float score = 0.0;
Y
Yibing Liu 已提交
195
                        if (ext_scorer != NULL && l.size() > 1) {
196
                            score = ext_scorer->get_score(l.substr(1), true);
Y
Yibing Liu 已提交
197
                        }
198 199 200 201 202 203
                        log_probs_prev = log_sum_exp(log_probs_b_prev[l],
                                                     log_probs_nb_prev[l]);
                        log_probs_nb_cur[l_plus] = log_sum_exp(
                                            log_probs_nb_cur[l_plus],
                                            score + log_prob_c + log_probs_prev
                                        );
Y
Yibing Liu 已提交
204
                    } else {
205 206 207 208 209 210
                        log_probs_prev = log_sum_exp(log_probs_b_prev[l],
                                                     log_probs_nb_prev[l]);
                        log_probs_nb_cur[l_plus] = log_sum_exp(
                                                    log_probs_nb_cur[l_plus],
                                                    log_prob_c+log_probs_prev
                                                );
Y
Yibing Liu 已提交
211
                    }
212 213 214 215
                    prefix_set_next[l_plus] = log_sum_exp(
                                                log_probs_nb_cur[l_plus],
                                                log_probs_b_cur[l_plus]
                                            );
Y
Yibing Liu 已提交
216 217 218
                }
            }

219 220
            prefix_set_next[l] = log_sum_exp(log_probs_b_cur[l],
                                             log_probs_nb_cur[l]);
Y
Yibing Liu 已提交
221 222
        }

223 224 225
        log_probs_b_prev = log_probs_b_cur;
        log_probs_nb_prev = log_probs_nb_cur;
        std::vector<std::pair<std::string, log_prob_type> >
226 227 228 229
                  prefix_vec_next(prefix_set_next.begin(),
                                  prefix_set_next.end());
        std::sort(prefix_vec_next.begin(),
                  prefix_vec_next.end(),
230 231 232 233 234 235 236
                  pair_comp_second_rev<std::string, log_prob_type>);
        int num_prefixes_next = prefix_vec_next.size();
        int k = beam_size<num_prefixes_next ? beam_size : num_prefixes_next;
        prefix_set_prev = std::map<std::string, log_prob_type> (
                                                   prefix_vec_next.begin(),
                                                   prefix_vec_next.begin() + k
                                                );
Y
Yibing Liu 已提交
237 238 239 240
    }

    // post processing
    std::vector<std::pair<double, std::string> > beam_result;
241 242 243 244
    for (std::map<std::string, log_prob_type>::iterator
         it = prefix_set_prev.begin(); it != prefix_set_prev.end(); it++) {
        if (it->second > -NUM_MAX && it->first.size() > 1) {
            log_prob_type log_prob = it->second;
Y
Yibing Liu 已提交
245 246 247
            std::string sentence = it->first.substr(1);
            // scoring the last word
            if (ext_scorer != NULL && sentence[sentence.size()-1] != ' ') {
248 249 250 251 252
                log_prob = log_prob + ext_scorer->get_score(sentence, true);
            }
            if (log_prob > -NUM_MAX) {
                std::pair<double, std::string> cur_result(log_prob, sentence);
                beam_result.push_back(cur_result);
Y
Yibing Liu 已提交
253 254 255 256
            }
        }
    }
    // sort the result and return
257 258
    std::sort(beam_result.begin(), beam_result.end(),
              pair_comp_first_rev<double, std::string>);
Y
Yibing Liu 已提交
259 260
    return beam_result;
}