You need to sign in or sign up before continuing.
decoder_utils.cpp 5.9 KB
Newer Older
H
Hui Zhang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

Y
Yibing Liu 已提交
15
#include "decoder_utils.h"
16

Y
Yibing Liu 已提交
17 18
#include <algorithm>
#include <cmath>
Y
Yibing Liu 已提交
19
#include <limits>
Y
Yibing Liu 已提交
20

21 22 23 24
// Prune the per-step values in `prob_step` and return the surviving entries
// as (vocabulary index, log value) pairs.
//
// prob_step    : one value per vocabulary index for a single step.
// cutoff_prob  : cumulative threshold; when < 1.0, entries are kept (in the
//                order produced by pair_comp_second_rev) until their running
//                sum reaches this value.
// cutoff_top_n : upper bound on the number of entries kept.
//
// No pruning and no sorting happens when cutoff_prob >= 1.0 and
// cutoff_top_n >= prob_step.size(); the result is then in index order.
// Note that the cumulative branch always keeps at least one entry when
// prob_step is non-empty, because the counter is incremented before the
// limits are tested.
std::vector<std::pair<size_t, float>> get_pruned_log_probs(
    const std::vector<double> &prob_step,
    double cutoff_prob,
    size_t cutoff_top_n) {
    std::vector<std::pair<int, double>> prob_idx;
    prob_idx.reserve(prob_step.size());
    for (size_t i = 0; i < prob_step.size(); ++i) {
        prob_idx.emplace_back(static_cast<int>(i), prob_step[i]);
    }

    // pruning of vocabulary
    size_t cutoff_len = prob_step.size();
    if (cutoff_prob < 1.0 || cutoff_top_n < cutoff_len) {
        // Ordering is delegated to pair_comp_second_rev (declared outside
        // this file; presumably descending by the pair's second member).
        std::sort(prob_idx.begin(),
                  prob_idx.end(),
                  pair_comp_second_rev<int, double>);
        if (cutoff_prob < 1.0) {
            double cum_prob = 0.0;
            cutoff_len = 0;
            for (size_t i = 0; i < prob_idx.size(); ++i) {
                cum_prob += prob_idx[i].second;
                cutoff_len += 1;
                if (cum_prob >= cutoff_prob || cutoff_len >= cutoff_top_n)
                    break;
            }
        } else {
            // Only the top-n limit is active. This branch is reached only
            // when cutoff_top_n < prob_step.size(), so the value is in range.
            // Previously cutoff_len was left at the full size here and the
            // top-n limit was silently ignored.
            cutoff_len = cutoff_top_n;
        }
        prob_idx.erase(prob_idx.begin() + cutoff_len, prob_idx.end());
    }

    std::vector<std::pair<size_t, float>> log_prob_idx;
    log_prob_idx.reserve(cutoff_len);
    for (size_t i = 0; i < cutoff_len; ++i) {
        // NUM_FLT_MIN is declared outside this file; presumably a tiny
        // positive offset that keeps log() away from zero.
        log_prob_idx.emplace_back(
            static_cast<size_t>(prob_idx[i].first),
            static_cast<float>(std::log(prob_idx[i].second + NUM_FLT_MIN)));
    }
    return log_prob_idx;
}


// Build the final decoding output from the candidate prefixes.
//
// prefixes   : candidate prefix nodes; only the first `beam_size` entries
//              (or fewer, if the vector is shorter) are considered.
// vocabulary : maps each index produced by PathTrie::get_path_vec to the
//              string appended to the output. Indices are not range-checked.
// beam_size  : maximum number of results returned.
//
// Returns one (-approx_ctc, decoded string) pair per considered prefix,
// ordered by prefix_compare.
std::vector<std::pair<double, std::string>> get_beam_search_result(
    const std::vector<PathTrie *> &prefixes,
    const std::vector<std::string> &vocabulary,
    size_t beam_size) {
    // allow for the post processing: work on a private copy of the leading
    // candidates so the caller's vector is left untouched by the sort.
    const size_t num_kept = std::min(beam_size, prefixes.size());
    std::vector<PathTrie *> space_prefixes(prefixes.begin(),
                                           prefixes.begin() + num_kept);

    std::sort(space_prefixes.begin(), space_prefixes.end(), prefix_compare);

    std::vector<std::pair<double, std::string>> output_vecs;
    output_vecs.reserve(space_prefixes.size());
    for (PathTrie *prefix : space_prefixes) {
        std::vector<int> output;
        prefix->get_path_vec(output);
        // convert index to string
        std::string output_str;
        for (int index : output) {
            output_str += vocabulary[index];
        }
        output_vecs.emplace_back(-prefix->approx_ctc, output_str);
    }

    return output_vecs;
}

Y
Yibing Liu 已提交
87
// Count the bytes of `str` that are not of the form 10xxxxxx, i.e. every
// byte that is not a UTF-8 continuation byte. For well-formed UTF-8 this is
// the number of encoded characters.
size_t get_utf8_str_len(const std::string &str) {
    const auto starts_character = [](char byte) {
        return (byte & 0xc0) != 0x80;
    };
    return static_cast<size_t>(
        std::count_if(str.begin(), str.end(), starts_character));
}
94

Y
Yibing Liu 已提交
95
// Split `str` into pieces, starting a new piece at every byte that is not a
// UTF-8 continuation byte (10xxxxxx). For well-formed UTF-8 each piece is one
// encoded character.
//
// The trailing piece is always appended, so an empty input yields a vector
// holding a single empty string rather than an empty vector.
std::vector<std::string> split_utf8_str(const std::string &str) {
    std::vector<std::string> pieces;
    std::string current;

    for (std::size_t pos = 0; pos < str.size(); ++pos) {
        const char byte = str[pos];
        const bool is_continuation = (byte & 0xc0) == 0x80;
        // A lead byte closes the piece collected so far (if any).
        if (!is_continuation && !current.empty()) {
            pieces.push_back(current);
            current.clear();
        }
        current.push_back(byte);
    }

    pieces.push_back(current);
    return pieces;
}

Y
Yibing Liu 已提交
114 115
// Split `s` on every occurrence of `delim` and return the non-empty tokens
// in order. Leading, trailing and consecutive delimiters therefore produce
// no empty strings.
//
// An empty `delim` cannot separate anything: the whole of `s` is returned as
// a single token (or nothing if `s` is empty).
std::vector<std::string> split_str(const std::string &s,
                                   const std::string &delim) {
    std::vector<std::string> result;

    // std::string::find with an empty pattern matches at `start` itself, so
    // without this guard `start` would never advance and the loop below
    // would spin forever.
    if (delim.empty()) {
        if (!s.empty()) {
            result.push_back(s);
        }
        return result;
    }

    const std::size_t delim_len = delim.size();
    std::size_t start = 0;
    while (true) {
        const std::size_t end = s.find(delim, start);
        if (end == std::string::npos) {
            // Remainder after the last delimiter, if any.
            if (start < s.size()) {
                result.push_back(s.substr(start));
            }
            break;
        }
        if (end > start) {
            result.push_back(s.substr(start, end - start));
        }
        start = end + delim_len;
    }
    return result;
}

Y
Yibing Liu 已提交
134
bool prefix_compare(const PathTrie *x, const PathTrie *y) {
135 136 137 138 139 140
    if (x->score == y->score) {
        if (x->character == y->character) {
            return false;
        } else {
            return (x->character < y->character);
        }
141
    } else {
142
        return x->score > y->score;
143
    }
144
}
145

Y
Yibing Liu 已提交
146 147
void add_word_to_fst(const std::vector<int> &word,
                     fst::StdVectorFst *dictionary) {
148 149 150 151 152 153 154 155 156 157 158 159 160
    if (dictionary->NumStates() == 0) {
        fst::StdVectorFst::StateId start = dictionary->AddState();
        assert(start == 0);
        dictionary->SetStart(start);
    }
    fst::StdVectorFst::StateId src = dictionary->Start();
    fst::StdVectorFst::StateId dst;
    for (auto c : word) {
        dst = dictionary->AddState();
        dictionary->AddArc(src, fst::StdArc(c, c, 0, dst));
        src = dst;
    }
    dictionary->SetFinal(dst, fst::StdArc::Weight::One());
161
}
162

Y
Yibing Liu 已提交
163
bool add_word_to_dictionary(
Y
Yibing Liu 已提交
164 165
    const std::string &word,
    const std::unordered_map<std::string, int> &char_map,
Y
Yibing Liu 已提交
166 167
    bool add_space,
    int SPACE_ID,
Y
Yibing Liu 已提交
168
    fst::StdVectorFst *dictionary) {
169 170 171 172 173 174 175 176 177 178 179 180 181 182 183
    auto characters = split_utf8_str(word);

    std::vector<int> int_word;

    for (auto &c : characters) {
        if (c == " ") {
            int_word.push_back(SPACE_ID);
        } else {
            auto int_c = char_map.find(c);
            if (int_c != char_map.end()) {
                int_word.push_back(int_c->second);
            } else {
                return false;  // return without adding
            }
        }
184 185
    }

186 187 188
    if (add_space) {
        int_word.push_back(SPACE_ID);
    }
189

190 191
    add_word_to_fst(int_word, dictionary);
    return true;  // return with successful adding
192
}