// ctc_decoders.cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdlib>
#include <iostream>
#include <limits>
#include <map>
#include <string>
#include <utility>
#include <vector>

#include "ctc_decoders.h"
#include "decoder_utils.h"
// Floating-point type used for the log-domain probabilities in this file.
typedef double log_prob_type;
// Comparator that orders pairs by their first element, largest first.
// Returns true when a.first is strictly greater than b.first.
// The pairs are taken by const reference so that using this as a std::sort
// predicate does not copy the pair (and any string it holds) per comparison.
template <typename T1, typename T2>
bool pair_comp_first_rev(const std::pair<T1, T2>& a,
                         const std::pair<T1, T2>& b)
{
    return a.first > b.first;
}

// Comparator that orders pairs by their second element, largest first.
// Returns true when a.second is strictly greater than b.second.
// The pairs are taken by const reference so that using this as a std::sort
// predicate does not copy the pair (and any string it holds) per comparison.
template <typename T1, typename T2>
bool pair_comp_second_rev(const std::pair<T1, T2>& a,
                          const std::pair<T1, T2>& b)
{
    return a.second > b.second;
}

// Computes log(exp(x) + exp(y)) for two log-domain values.
// Any argument at or below -numeric_limits<T>::max() is treated as log(0),
// in which case the other argument is returned unchanged. Otherwise the larger
// argument is factored out before exponentiating so exp() cannot overflow.
template <typename T>
T log_sum_exp(T x, T y)
{
    static const T num_min = -std::numeric_limits<T>::max();
    if (x <= num_min) return y;
    if (y <= num_min) return x;
    const T xmax = std::max(x, y);
    return std::log(std::exp(x - xmax) + std::exp(y - xmax)) + xmax;
}

// Best-path decoding: takes the highest-scoring class at every time step,
// merges consecutive repeats of the same class, then drops blanks.
//
// probs_seq   one row per time step; every row must have exactly
//             vocabulary.size() + 1 entries. The extra last entry (index
//             vocabulary.size()) is the blank.
// vocabulary  string emitted for each non-blank class index.
//
// Returns the concatenation of the vocabulary entries of the surviving
// classes. If any row has the wrong length, a message is printed to stdout
// and the process exits with status 1.
std::string ctc_best_path_decoder(std::vector<std::vector<double> > probs_seq,
                                  std::vector<std::string> vocabulary) {
    // dimension check
    const std::size_t num_time_steps = probs_seq.size();
    const std::size_t num_classes = vocabulary.size() + 1;
    for (std::size_t i = 0; i < num_time_steps; ++i) {
        if (probs_seq[i].size() != num_classes) {
            std::cout<<"The shape of probs_seq does not match"
                     <<" with the shape of the vocabulary!"<<std::endl;
            exit(1);
        }
    }

    const std::size_t blank_id = vocabulary.size();

    std::string best_path_result;
    bool has_prev = false;
    std::size_t prev_idx = 0;
    for (std::size_t i = 0; i < num_time_steps; ++i) {
        const std::vector<double>& row = probs_seq[i];
        // First position of the largest entry. Unlike a running maximum
        // seeded with 0.0, this also works when no entry is positive.
        const std::size_t max_idx = static_cast<std::size_t>(
            std::max_element(row.begin(), row.end()) - row.begin());

        // A class repeated on consecutive time steps is emitted only once.
        if (has_prev && max_idx == prev_idx) {
            continue;
        }
        has_prev = true;
        prev_idx = max_idx;

        if (max_idx != blank_id) {
            best_path_result += vocabulary[max_idx];
        }
    }
    return best_path_result;
}

std::vector<std::pair<double, std::string> >
81 82 83 84 85
    ctc_beam_search_decoder(std::vector<std::vector<double> > probs_seq,
                            int beam_size,
                            std::vector<std::string> vocabulary,
                            int blank_id,
                            double cutoff_prob,
Y
Yibing Liu 已提交
86
                            Scorer *ext_scorer,
87 88
                            bool nproc) {
    // dimension check
Y
Yibing Liu 已提交
89
    int num_time_steps = probs_seq.size();
90 91
    for (int i=0; i<num_time_steps; i++) {
        if (probs_seq[i].size() != vocabulary.size()+1) {
92 93
            std::cout << " The shape of probs_seq does not match"
                      << " with the shape of the vocabulary!" << std::endl;
94 95 96 97 98 99
            exit(1);
        }
    }

    // blank_id check
    if (blank_id > vocabulary.size()) {
100
        std::cout << " Invalid blank_id! " << std::endl;
101 102
        exit(1);
    }
Y
Yibing Liu 已提交
103 104

    // assign space ID
105 106 107
    std::vector<std::string>::iterator it = std::find(vocabulary.begin(),
                                                  vocabulary.end(), " ");
    int space_id = it - vocabulary.begin();
Y
Yibing Liu 已提交
108
    if(space_id >= vocabulary.size()) {
109
        std::cout << " The character space is not in the vocabulary!"<<std::endl;
Y
Yibing Liu 已提交
110
        exit(1);
Y
Yibing Liu 已提交
111
    }
Y
Yibing Liu 已提交
112

Y
Yibing Liu 已提交
113 114
    // initialize
    // two sets containing selected and candidate prefixes respectively
115
    std::map<std::string, log_prob_type> prefix_set_prev, prefix_set_next;
Y
Yibing Liu 已提交
116
    // probability of prefixes ending with blank and non-blank
117 118 119 120 121 122 123
    std::map<std::string, log_prob_type> log_probs_b_prev, log_probs_nb_prev;
    std::map<std::string, log_prob_type> log_probs_b_cur, log_probs_nb_cur;

    static log_prob_type NUM_MAX = std::numeric_limits<log_prob_type>::max();
    prefix_set_prev["\t"] = 0.0;
    log_probs_b_prev["\t"] = 0.0;
    log_probs_nb_prev["\t"] = -NUM_MAX;
Y
Yibing Liu 已提交
124

Y
Yibing Liu 已提交
125 126
    for (int time_step=0; time_step<num_time_steps; time_step++) {
        prefix_set_next.clear();
127 128
        log_probs_b_cur.clear();
        log_probs_nb_cur.clear();
Y
Yibing Liu 已提交
129 130 131 132 133 134
        std::vector<double> prob = probs_seq[time_step];

        std::vector<std::pair<int, double> > prob_idx;
        for (int i=0; i<prob.size(); i++) {
            prob_idx.push_back(std::pair<int, double>(i, prob[i]));
        }
135

Y
Yibing Liu 已提交
136
        // pruning of vacobulary
137
        int cutoff_len = prob.size();
Y
Yibing Liu 已提交
138
        if (cutoff_prob < 1.0) {
139 140
            std::sort(prob_idx.begin(),
                      prob_idx.end(),
141
                      pair_comp_second_rev<int, double>);
142 143
            double cum_prob = 0.0;
            cutoff_len = 0;
Y
Yibing Liu 已提交
144 145 146 147 148
            for (int i=0; i<prob_idx.size(); i++) {
                cum_prob += prob_idx[i].second;
                cutoff_len += 1;
                if (cum_prob >= cutoff_prob) break;
            }
149
            prob_idx = std::vector<std::pair<int, double> >( prob_idx.begin(),
150
                            prob_idx.begin() + cutoff_len);
Y
Yibing Liu 已提交
151
        }
152 153 154 155 156 157 158

        std::vector<std::pair<int, log_prob_type> > log_prob_idx;
        for (int i=0; i<cutoff_len; i++) {
            log_prob_idx.push_back(std::pair<int, log_prob_type>
                        (prob_idx[i].first, log(prob_idx[i].second)));
        }

Y
Yibing Liu 已提交
159
        // extend prefix
160 161
        for (std::map<std::string, log_prob_type>::iterator
             it = prefix_set_prev.begin();
Y
Yibing Liu 已提交
162 163 164
            it != prefix_set_prev.end(); it++) {
            std::string l = it->first;
            if( prefix_set_next.find(l) == prefix_set_next.end()) {
165
                log_probs_b_cur[l] = log_probs_nb_cur[l] = -NUM_MAX;
Y
Yibing Liu 已提交
166 167
            }

168 169 170 171
            for (int index=0; index<log_prob_idx.size(); index++) {
                int c = log_prob_idx[index].first;
                log_prob_type log_prob_c = log_prob_idx[index].second;
                log_prob_type log_probs_prev;
Y
Yibing Liu 已提交
172
                if (c == blank_id) {
173 174 175 176
                    log_probs_prev = log_sum_exp(log_probs_b_prev[l],
                                                 log_probs_nb_prev[l]);
                    log_probs_b_cur[l] = log_sum_exp(log_probs_b_cur[l],
                                                     log_prob_c+log_probs_prev);
Y
Yibing Liu 已提交
177 178 179
                } else {
                    std::string last_char = l.substr(l.size()-1, 1);
                    std::string new_char = vocabulary[c];
180
                    std::string l_plus = l + new_char;
Y
Yibing Liu 已提交
181 182

                    if( prefix_set_next.find(l_plus) == prefix_set_next.end()) {
183 184
                        log_probs_b_cur[l_plus] = -NUM_MAX;
                        log_probs_nb_cur[l_plus] = -NUM_MAX;
Y
Yibing Liu 已提交
185 186
                    }
                    if (last_char == new_char) {
187 188 189 190 191 192 193 194
                        log_probs_nb_cur[l_plus] = log_sum_exp(
                                                log_probs_nb_cur[l_plus],
                                                log_prob_c+log_probs_b_prev[l]
                                            );
                        log_probs_nb_cur[l] = log_sum_exp(
                                                log_probs_nb_cur[l],
                                                log_prob_c+log_probs_nb_prev[l]
                                            );
Y
Yibing Liu 已提交
195
                    } else if (new_char == " ") {
196
                        float score = 0.0;
Y
Yibing Liu 已提交
197
                        if (ext_scorer != NULL && l.size() > 1) {
198
                            score = ext_scorer->get_score(l.substr(1), true);
Y
Yibing Liu 已提交
199
                        }
200 201 202 203 204 205
                        log_probs_prev = log_sum_exp(log_probs_b_prev[l],
                                                     log_probs_nb_prev[l]);
                        log_probs_nb_cur[l_plus] = log_sum_exp(
                                            log_probs_nb_cur[l_plus],
                                            score + log_prob_c + log_probs_prev
                                        );
Y
Yibing Liu 已提交
206
                    } else {
207 208 209 210 211 212
                        log_probs_prev = log_sum_exp(log_probs_b_prev[l],
                                                     log_probs_nb_prev[l]);
                        log_probs_nb_cur[l_plus] = log_sum_exp(
                                                    log_probs_nb_cur[l_plus],
                                                    log_prob_c+log_probs_prev
                                                );
Y
Yibing Liu 已提交
213
                    }
214 215 216 217
                    prefix_set_next[l_plus] = log_sum_exp(
                                                log_probs_nb_cur[l_plus],
                                                log_probs_b_cur[l_plus]
                                            );
Y
Yibing Liu 已提交
218 219 220
                }
            }

221 222
            prefix_set_next[l] = log_sum_exp(log_probs_b_cur[l],
                                             log_probs_nb_cur[l]);
Y
Yibing Liu 已提交
223 224
        }

225 226 227
        log_probs_b_prev = log_probs_b_cur;
        log_probs_nb_prev = log_probs_nb_cur;
        std::vector<std::pair<std::string, log_prob_type> >
228 229 230 231
                  prefix_vec_next(prefix_set_next.begin(),
                                  prefix_set_next.end());
        std::sort(prefix_vec_next.begin(),
                  prefix_vec_next.end(),
232 233 234 235 236 237 238
                  pair_comp_second_rev<std::string, log_prob_type>);
        int num_prefixes_next = prefix_vec_next.size();
        int k = beam_size<num_prefixes_next ? beam_size : num_prefixes_next;
        prefix_set_prev = std::map<std::string, log_prob_type> (
                                                   prefix_vec_next.begin(),
                                                   prefix_vec_next.begin() + k
                                                );
Y
Yibing Liu 已提交
239 240 241 242
    }

    // post processing
    std::vector<std::pair<double, std::string> > beam_result;
243 244 245 246
    for (std::map<std::string, log_prob_type>::iterator
         it = prefix_set_prev.begin(); it != prefix_set_prev.end(); it++) {
        if (it->second > -NUM_MAX && it->first.size() > 1) {
            log_prob_type log_prob = it->second;
Y
Yibing Liu 已提交
247 248 249
            std::string sentence = it->first.substr(1);
            // scoring the last word
            if (ext_scorer != NULL && sentence[sentence.size()-1] != ' ') {
250 251 252 253 254
                log_prob = log_prob + ext_scorer->get_score(sentence, true);
            }
            if (log_prob > -NUM_MAX) {
                std::pair<double, std::string> cur_result(log_prob, sentence);
                beam_result.push_back(cur_result);
Y
Yibing Liu 已提交
255 256 257 258
            }
        }
    }
    // sort the result and return
259 260
    std::sort(beam_result.begin(), beam_result.end(),
              pair_comp_first_rev<double, std::string>);
Y
Yibing Liu 已提交
261 262
    return beam_result;
}