ctc_decoders.cpp 10.7 KB
Newer Older
Y
Yibing Liu 已提交
1 2 3 4 5
#include <iostream>
#include <map>
#include <algorithm>
#include <utility>
#include <cmath>
6
#include <limits>
7
#include "fst/fstlib.h"
8
#include "ctc_decoders.h"
Y
Yibing Liu 已提交
9
#include "decoder_utils.h"
10
#include "path_trie.h"
11
#include "ThreadPool.h"
Y
Yibing Liu 已提交
12

13
/**
 * Greedy (best-path) CTC decoding.
 *
 * For every time step the class with the largest probability is selected;
 * consecutive repeats of the same class are collapsed into one, and blanks
 * are dropped. The blank label is taken to be the last class index, i.e.
 * vocabulary.size().
 *
 * @param probs_seq   One probability vector per time step. Every vector must
 *                    contain vocabulary.size() + 1 entries.
 * @param vocabulary  Output symbols, indexed by class id.
 * @return The decoded string (empty when probs_seq is empty).
 *
 * On a shape mismatch a message is printed to stdout and the process
 * terminates with exit(1).
 */
std::string ctc_best_path_decoder(std::vector<std::vector<double> > probs_seq,
                                  std::vector<std::string> vocabulary)
{
    // dimension check: one probability per vocabulary entry plus the blank
    const size_t num_classes = vocabulary.size() + 1;
    for (size_t t = 0; t < probs_seq.size(); t++) {
        if (probs_seq[t].size() != num_classes) {
            std::cout << "The shape of probs_seq does not match"
                      << " with the shape of the vocabulary!" << std::endl;
            exit(1);
        }
    }

    const size_t blank_id = vocabulary.size();

    std::string best_path_result;
    bool has_prev = false;   // false only for the very first time step
    size_t prev_idx = 0;     // arg-max class of the previous time step
    for (size_t t = 0; t < probs_seq.size(); t++) {
        const std::vector<double>& frame = probs_seq[t];

        // arg-max with a strict comparison against a 0.0 starting value:
        // ties and frames without any positive entry resolve to the lowest
        // index
        size_t max_idx = 0;
        double max_prob = 0.0;
        for (size_t j = 0; j < frame.size(); j++) {
            if (max_prob < frame[j]) {
                max_idx = j;
                max_prob = frame[j];
            }
        }

        // collapse consecutive repeats first, then drop blanks
        if (!has_prev || max_idx != prev_idx) {
            if (max_idx != blank_id) {
                best_path_result += vocabulary[max_idx];
            }
        }
        prev_idx = max_idx;
        has_prev = true;
    }
    return best_path_result;
}

Y
Yibing Liu 已提交
59
std::vector<std::pair<double, std::string> >
60 61 62 63 64
    ctc_beam_search_decoder(std::vector<std::vector<double> > probs_seq,
                            int beam_size,
                            std::vector<std::string> vocabulary,
                            int blank_id,
                            double cutoff_prob,
65 66
                            Scorer *ext_scorer)
{
67
    // dimension check
Y
Yibing Liu 已提交
68
    int num_time_steps = probs_seq.size();
Y
Yibing Liu 已提交
69 70
    for (int i = 0; i < num_time_steps; i++) {
        if (probs_seq[i].size() != vocabulary.size() + 1) {
71 72
            std::cout << " The shape of probs_seq does not match"
                      << " with the shape of the vocabulary!" << std::endl;
73 74 75 76 77 78
            exit(1);
        }
    }

    // blank_id check
    if (blank_id > vocabulary.size()) {
79
        std::cout << " Invalid blank_id! " << std::endl;
80 81
        exit(1);
    }
Y
Yibing Liu 已提交
82 83

    // assign space ID
84 85 86
    std::vector<std::string>::iterator it = std::find(vocabulary.begin(),
                                                  vocabulary.end(), " ");
    int space_id = it - vocabulary.begin();
Y
Yibing Liu 已提交
87
    // if no space in vocabulary
Y
Yibing Liu 已提交
88
    if(space_id >= vocabulary.size()) {
Y
Yibing Liu 已提交
89
        space_id = -2;
Y
Yibing Liu 已提交
90
    }
Y
Yibing Liu 已提交
91

92 93
    // init
    PathTrie root;
Y
Yibing Liu 已提交
94
    root._score = root._log_prob_b_prev = 0.0;
95 96 97 98
    std::vector<PathTrie*> prefixes;
    prefixes.push_back(&root);

    if ( ext_scorer != nullptr && !ext_scorer->is_character_based()) {
99
        if (ext_scorer->_dictionary == nullptr) {
100
        // TODO: init dictionary
101 102 103
            ext_scorer->set_char_map(vocabulary);
            // add_space should be true?
            ext_scorer->fill_dictionary(true);
104
        }
105
        auto fst_dict = static_cast<fst::StdVectorFst*>(ext_scorer->_dictionary);
106 107 108 109 110
        fst::StdVectorFst* dict_ptr = fst_dict->Copy(true);
        root.set_dictionary(dict_ptr);
        auto matcher = std::make_shared<FSTMATCH>(*dict_ptr, fst::MATCH_INPUT);
        root.set_matcher(matcher);
    }
Y
Yibing Liu 已提交
111

112 113
    for (int time_step = 0; time_step < num_time_steps; time_step++) {
        std::vector<double> prob = probs_seq[time_step];
Y
Yibing Liu 已提交
114 115 116 117
        std::vector<std::pair<int, double> > prob_idx;
        for (int i=0; i<prob.size(); i++) {
            prob_idx.push_back(std::pair<int, double>(i, prob[i]));
        }
118

Y
Yibing Liu 已提交
119
        // pruning of vacobulary
120
        int cutoff_len = prob.size();
Y
Yibing Liu 已提交
121
        if (cutoff_prob < 1.0) {
122 123
            std::sort(prob_idx.begin(),
                      prob_idx.end(),
124
                      pair_comp_second_rev<int, double>);
125 126
            double cum_prob = 0.0;
            cutoff_len = 0;
Y
Yibing Liu 已提交
127 128 129 130 131
            for (int i=0; i<prob_idx.size(); i++) {
                cum_prob += prob_idx[i].second;
                cutoff_len += 1;
                if (cum_prob >= cutoff_prob) break;
            }
132
            prob_idx = std::vector<std::pair<int, double> >( prob_idx.begin(),
133
                            prob_idx.begin() + cutoff_len);
Y
Yibing Liu 已提交
134
        }
135

Y
Yibing Liu 已提交
136 137 138 139
        std::vector<std::pair<int, float> > log_prob_idx;
        for (int i = 0; i < cutoff_len; i++) {
            log_prob_idx.push_back(std::pair<int, float>
                  (prob_idx[i].first, log(prob_idx[i].second + NUM_FLT_MIN)));
140 141
        }

142 143 144
        // loop over chars
        for (int index = 0; index < log_prob_idx.size(); index++) {
            auto c = log_prob_idx[index].first;
Y
Yibing Liu 已提交
145 146
            float log_prob_c = log_prob_idx[index].second;
            //float log_probs_prev;
Y
Yibing Liu 已提交
147

148 149 150
            for (int i = 0; i < prefixes.size() && i<beam_size; i++) {
                auto prefix = prefixes[i];
                // blank
Y
Yibing Liu 已提交
151
                if (c == blank_id) {
152 153 154 155 156 157 158 159 160
                    prefix->_log_prob_b_cur = log_sum_exp(
                                               prefix->_log_prob_b_cur,
                                               log_prob_c + prefix->_score);
                    continue;
                }
                // repeated character
                if (c == prefix->_character) {
                    prefix->_log_prob_nb_cur = log_sum_exp(
                        prefix->_log_prob_nb_cur,
Y
Yibing Liu 已提交
161
                        log_prob_c + prefix->_log_prob_nb_prev);
162 163 164 165 166
                }
                // get new prefix
                auto prefix_new = prefix->get_path_trie(c);

                if (prefix_new != nullptr) {
Y
Yibing Liu 已提交
167
                    float log_p = -NUM_FLT_INF;
168 169

                    if (c == prefix->_character
Y
Yibing Liu 已提交
170
                        && prefix->_log_prob_b_prev > -NUM_FLT_INF) {
171 172 173
                        log_p = log_prob_c + prefix->_log_prob_b_prev;
                    } else if (c != prefix->_character) {
                        log_p = log_prob_c + prefix->_score;
Y
Yibing Liu 已提交
174
                    }
175 176 177 178 179 180 181 182 183 184 185

                    // language model scoring
                    if (ext_scorer != nullptr &&
                        (c == space_id || ext_scorer->is_character_based()) ) {
                        PathTrie *prefix_to_score = nullptr;

                        // don't score the space
                        if (ext_scorer->is_character_based()) {
                            prefix_to_score = prefix_new;
                        } else {
                            prefix_to_score = prefix;
Y
Yibing Liu 已提交
186
                        }
187 188 189 190 191 192 193 194 195

                        double score = 0.0;
                        std::vector<std::string> ngram;
                        ngram = ext_scorer->make_ngram(prefix_to_score);
                        score = ext_scorer->get_log_cond_prob(ngram) *
                                ext_scorer->alpha;

                        log_p += score;
                        log_p += ext_scorer->beta;
Y
Yibing Liu 已提交
196
                    }
197 198
                    prefix_new->_log_prob_nb_cur = log_sum_exp(
                                        prefix_new->_log_prob_nb_cur, log_p);
Y
Yibing Liu 已提交
199 200 201
                }
            }

202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232
        } // end of loop over chars

        prefixes.clear();
        // update log probabilities
        root.iterate_to_vec(prefixes);

        // sort prefixes by score
        if (prefixes.size() >= beam_size) {
            std::nth_element(prefixes.begin(),
                    prefixes.begin() + beam_size,
                    prefixes.end(),
                    prefix_compare);

            for (size_t i = beam_size; i < prefixes.size(); i++) {
                prefixes[i]->remove();
            }
        }
    }

    for (size_t i = 0; i < beam_size && i < prefixes.size(); i++) {
        double approx_ctc = prefixes[i]->_score;

        // remove word insert:
        std::vector<int> output;
        prefixes[i]->get_path_vec(output);
        size_t prefix_length = output.size();
        // remove language model weight:
        if (ext_scorer != nullptr) {
           // auto words = split_labels(output);
           // approx_ctc = approx_ctc - path_length * ext_scorer->beta;
           // approx_ctc -= (_lm->get_sent_log_prob(words)) * ext_scorer->alpha;
Y
Yibing Liu 已提交
233 234
        }

235
        prefixes[i]->_approx_ctc = approx_ctc;
Y
Yibing Liu 已提交
236 237
    }

238 239 240 241 242
    // allow for the post processing
    std::vector<PathTrie*> space_prefixes;
    if (space_prefixes.empty()) {
        for (size_t i = 0; i < beam_size && i< prefixes.size(); i++) {
            space_prefixes.push_back(prefixes[i]);
Y
Yibing Liu 已提交
243 244
        }
    }
245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264

    std::sort(space_prefixes.begin(), space_prefixes.end(), prefix_compare);
    std::vector<std::pair<double, std::string> > output_vecs;
    for (size_t i = 0; i < beam_size && i < space_prefixes.size(); i++) {
        std::vector<int> output;
        space_prefixes[i]->get_path_vec(output);
        // convert index to string
        std::string output_str;
        for (int j = 0; j < output.size(); j++) {
            output_str += vocabulary[output[j]];
        }
        std::pair<double, std::string> output_pair(space_prefixes[i]->_score,
                                                   output_str);
        output_vecs.emplace_back(
            output_pair
        );
    }

    return output_vecs;
 }
265 266


Y
Yibing Liu 已提交
267
std::vector<std::vector<std::pair<double, std::string> > >
268 269 270 271 272 273 274 275
    ctc_beam_search_decoder_batch(
                std::vector<std::vector<std::vector<double>>> probs_split,
                int beam_size,
                std::vector<std::string> vocabulary,
                int blank_id,
                int num_processes,
                double cutoff_prob,
                Scorer *ext_scorer
276
                ) {
277 278 279 280 281 282 283 284
    if (num_processes <= 0) {
        std::cout << "num_processes must be nonnegative!" << std::endl;
        exit(1);
    }
    // thread pool
    ThreadPool pool(num_processes);
    // number of samples
    int batch_size = probs_split.size();
285
    // dictionary init
Y
Yibing Liu 已提交
286 287 288 289 290 291
    if ( ext_scorer != nullptr
         && !ext_scorer->is_character_based()
         && ext_scorer->_dictionary == nullptr) {
        // init dictionary
        ext_scorer->set_char_map(vocabulary);
        ext_scorer->fill_dictionary(true);
292
    }
293 294 295 296 297 298 299 300 301
    // enqueue the tasks of decoding
    std::vector<std::future<std::vector<std::pair<double, std::string>>>> res;
    for (int i = 0; i < batch_size; i++) {
        res.emplace_back(
                pool.enqueue(ctc_beam_search_decoder, probs_split[i],
                    beam_size, vocabulary, blank_id, cutoff_prob, ext_scorer)
            );
    }
    // get decoding results
Y
Yibing Liu 已提交
302
    std::vector<std::vector<std::pair<double, std::string> > > batch_results;
303 304 305 306 307
    for (int i = 0; i < batch_size; i++) {
        batch_results.emplace_back(res[i].get());
    }
    return batch_results;
}