decoder.py 11.1 KB
Newer Older
1 2 3 4
"""
    CTC-like decoder utilities.
"""

5
import os
6 7
from itertools import groupby
import numpy as np
Y
Yibing Liu 已提交
8
import kenlm
9
import multiprocessing
10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41


def ctc_best_path_decode(probs_seq, vocabulary):
    """
    Best path decoding, also called argmax decoding or greedy decoding.
    The path made of the most probable token at each time step is
    post-processed to remove consecutive repetitions and all blanks.

    The blank token is expected at index ``len(vocabulary)`` of every
    probability list, i.e. directly after the vocabulary entries.

    :param probs_seq: 2-D list of probabilities over the vocabulary (plus
                      blank) for each time step. Each element is a list of
                      float probabilities for one time step.
    :type probs_seq: list
    :param vocabulary: Vocabulary list.
    :type vocabulary: list
    :return: Decoding result string; empty if probs_seq is empty.
    :rtype: basestring
    :raises ValueError: If any probability list does not have exactly
                        ``len(vocabulary) + 1`` entries.
    """
    blank_index = len(vocabulary)
    # dimension verification: vocabulary entries plus one blank per step
    for probs in probs_seq:
        if len(probs) != blank_index + 1:
            raise ValueError("probs_seq dimension mismatched with vocabulary")
    # np.asarray([]) is 1-D, so argmax(axis=1) would fail on empty input.
    if len(probs_seq) == 0:
        return ''
    # argmax to get the best index for each time step
    max_index_list = np.asarray(probs_seq).argmax(axis=1)
    # groupby collapses runs of identical indices (merge repeats) before
    # blanks are dropped, so "a <blank> a" still yields two characters.
    return ''.join(vocabulary[index]
                   for index, _ in groupby(max_index_list)
                   if index != blank_index)


Y
Yibing Liu 已提交
42
class Scorer(object):
    """
    External defined scorer to evaluate a sentence in beam search
    decoding, consisting of language model and word count.

    :param alpha: Parameter associated with language model.
    :type alpha: float
    :param beta: Parameter associated with word count.
    :type beta: float
    :param model_path: Path to load language model.
    :type model_path: basestring
    :raises IOError: If model_path is not an existing file.
    """

    def __init__(self, alpha, beta, model_path):
        self._alpha = alpha
        self._beta = beta
        # Check the path first so a bad path fails with a clear IOError
        # before the language model loader is invoked.
        if not os.path.isfile(model_path):
            raise IOError("Invalid language model path: %s" % model_path)
        self._language_model = kenlm.LanguageModel(model_path)

    def language_model_score(self, sentence):
        """
        n-gram language model probability of the last word of ``sentence``
        given the words before it.

        :param sentence: Sentence to score.
        :type sentence: basestring
        :return: Conditional probability in decimal (not log) form.
        :rtype: float
        """
        # full_scores yields one entry per word; the first field of the last
        # entry is treated as a base-10 log probability (hence 10 ** x below).
        # eos=False: no end-of-sentence term is scored.
        log_cond_prob = list(
            self._language_model.full_scores(sentence, eos=False))[-1][0]
        return np.power(10, log_cond_prob)

    def word_count(self, sentence):
        """
        Word insertion term: number of space-separated tokens in ``sentence``.

        Splitting is on single spaces after stripping, so consecutive inner
        spaces produce empty tokens that are counted, and an empty sentence
        counts as one word.

        :param sentence: Sentence to count.
        :type sentence: basestring
        :rtype: int
        """
        words = sentence.strip().split(' ')
        return len(words)

    def __call__(self, sentence, log=False):
        """
        Evaluation function combining language model and word count:
        ``lm ** alpha * word_count ** beta``, or its natural log.

        :param sentence: The input sentence for evaluation
        :type sentence: basestring
        :param log: Whether return the score in log representation.
        :type log: bool
        :return: Evaluation score, in the decimal or log.
        :rtype: float
        """
        lm = self.language_model_score(sentence)
        word_cnt = self.word_count(sentence)
        if log:
            score = self._alpha * np.log(lm) + self._beta * np.log(word_cnt)
        else:
            score = np.power(lm, self._alpha) * np.power(word_cnt, self._beta)
        return score


def _prune_by_cumulative_prob(probs, cutoff_prob):
    """
    Select the tokens to expand at a single time step.

    :param probs: Probabilities over the vocabulary and blank for one time
                  step.
    :type probs: list
    :param cutoff_prob: Cumulative probability threshold; pruning is applied
                        only when it is below 1.0.
    :type cutoff_prob: float
    :return: List of ``[token_index, probability]`` pairs. Without pruning,
             every token in index order. With pruning, the most probable
             tokens in descending order, up to and including the first one at
             which the running sum reaches ``cutoff_prob`` (all tokens if it
             never does).
    :rtype: list
    """
    candidates = [[index, prob] for index, prob in enumerate(probs)]
    if not cutoff_prob < 1.0:
        return candidates
    candidates.sort(key=lambda pair: pair[1], reverse=True)
    cum_prob = 0.0
    for keep, (_, prob) in enumerate(candidates, 1):
        cum_prob += prob
        if cum_prob >= cutoff_prob:
            return candidates[:keep]
    return candidates


def ctc_beam_search_decoder(probs_seq,
                            beam_size,
                            vocabulary,
                            blank_id=0,
                            cutoff_prob=1.0,
                            ext_scoring_func=None,
                            nproc=False):
    '''
    Beam search decoder for CTC-trained network, using beam search with width
    beam_size to find many paths to one label, returning up to beam_size
    labels in the descending order of probabilities. The implementation is
    based on Prefix Beam Search (https://arxiv.org/abs/1408.2873), and the
    unclear part is redesigned, need to be verified.

    Each probability list has ``len(vocabulary) + 1`` entries: the vocabulary
    tokens in order with the blank inserted at position ``blank_id``. Token
    indices after ``blank_id`` therefore map to ``vocabulary[index - 1]``.

    :param probs_seq: 2-D list with length num_time_steps, each element
                      is a list of normalized probabilities over vocabulary
                      and blank for one time step.
    :type probs_seq: 2-D list
    :param beam_size: Width for beam search.
    :type beam_size: int
    :param vocabulary: Vocabulary list (without the blank).
    :type vocabulary: list
    :param blank_id: ID of blank, default 0. Must be in
                     [0, len(vocabulary)].
    :type blank_id: int
    :param cutoff_prob: Cutoff probability in pruning,
                        default 1.0, no pruning.
    :type cutoff_prob: float
    :param ext_scoring_func: External defined scoring function for
                            partially decoded sentence, e.g. word count
                            and language model.
    :type ext_scoring_func: function
    :param nproc: Whether the decoder runs inside
                  ctc_beam_search_decoder_nproc(); if True, ext_scoring_func
                  is replaced by the module-level ext_nproc_scorer.
    :type nproc: bool
    :return: List of [log_probability, sentence] pairs in descending order
             of probability; only prefixes with positive probability and at
             least one character are returned (empty for empty probs_seq).
    :rtype: list
    :raises ValueError: On a probability list of the wrong length or an
                        out-of-range blank_id.
    '''
    # dimension check: vocabulary entries plus one blank per time step
    for prob_list in probs_seq:
        if not len(prob_list) == len(vocabulary) + 1:
            raise ValueError("probs dimension mismatched with vocabulary")
    num_time_steps = len(probs_seq)

    # blank_id check; the dimension is derived from the vocabulary so that
    # an empty probs_seq is not indexed.
    probs_dim = len(vocabulary) + 1
    if not 0 <= blank_id < probs_dim:
        raise ValueError("blank_id should be in range [0, %d)" % probs_dim)

    # If the decoder called in the multiprocesses, then use the global scorer
    # instantiated in ctc_beam_search_decoder_nproc().
    if nproc is True:
        global ext_nproc_scorer
        ext_scoring_func = ext_nproc_scorer

    # '\t' is a start sentinel so every prefix has a last character to
    # compare against; it is stripped from the results.
    prefix_set_prev = {'\t': 1.0}
    probs_b_prev, probs_nb_prev = {'\t': 1.0}, {'\t': 0.0}

    for time_step in range(num_time_steps):
        # candidate prefixes and their blank / non-blank ending probabilities
        prefix_set_next = {}
        probs_b_cur, probs_nb_cur = {}, {}
        candidates = _prune_by_cumulative_prob(probs_seq[time_step],
                                               cutoff_prob)

        for prefix in prefix_set_prev:
            # A prefix may already have been created as an extension of an
            # earlier prefix in this step; only initialize it if not.
            if prefix not in prefix_set_next:
                probs_b_cur[prefix], probs_nb_cur[prefix] = 0.0, 0.0
            prefix_total = probs_b_prev[prefix] + probs_nb_prev[prefix]

            for c, prob_c in candidates:
                if c == blank_id:
                    probs_b_cur[prefix] += prob_c * prefix_total
                    continue

                last_char = prefix[-1]
                # vocabulary excludes the blank, so shift indices past it
                new_char = vocabulary[c - 1 if c > blank_id else c]
                extended = prefix + new_char
                if extended not in prefix_set_next:
                    probs_b_cur[extended], probs_nb_cur[extended] = 0.0, 0.0

                if new_char == last_char:
                    # a repeat only starts a new character after a blank;
                    # otherwise it collapses into the current prefix
                    probs_nb_cur[extended] += prob_c * probs_b_prev[prefix]
                    probs_nb_cur[prefix] += prob_c * probs_nb_prev[prefix]
                elif new_char == ' ':
                    # a space completes a word: apply the external scorer
                    if (ext_scoring_func is None) or (len(prefix) == 1):
                        score = 1.0
                    else:
                        score = ext_scoring_func(prefix[1:])
                    probs_nb_cur[extended] += score * prob_c * prefix_total
                else:
                    probs_nb_cur[extended] += prob_c * prefix_total
                prefix_set_next[extended] = (probs_nb_cur[extended] +
                                             probs_b_cur[extended])
            prefix_set_next[prefix] = probs_b_cur[prefix] + probs_nb_cur[prefix]

        probs_b_prev, probs_nb_prev = probs_b_cur, probs_nb_cur

        # keep the top beam_size prefixes
        ranked = sorted(
            prefix_set_next.items(), key=lambda item: item[1], reverse=True)
        prefix_set_prev = dict(ranked[:beam_size])

    beam_result = []
    for seq, prob in prefix_set_prev.items():
        if prob > 0.0 and len(seq) > 1:
            result = seq[1:]
            # score the trailing (unterminated) word by the external scorer
            if (ext_scoring_func is not None) and (result[-1] != ' '):
                prob = prob * ext_scoring_func(result)
            beam_result.append([np.log(prob), result])

    # output decoding results, most probable first
    beam_result.sort(key=lambda item: item[0], reverse=True)
    return beam_result
236 237 238 239 240 241


def ctc_beam_search_decoder_nproc(probs_split,
                                  beam_size,
                                  vocabulary,
                                  blank_id=0,
                                  cutoff_prob=1.0,
                                  ext_scoring_func=None,
                                  num_processes=None):
    '''
    Beam search decoder using multiple processes.

    :param probs_split: 3-D list with length batch_size, each element
                        is a 2-D list of probabilities that can be used by
                        ctc_beam_search_decoder.
    :type probs_split: 3-D list
    :param beam_size: Width for beam search.
    :type beam_size: int
    :param vocabulary: Vocabulary list.
    :type vocabulary: list
    :param blank_id: ID of blank, default 0.
    :type blank_id: int
    :param cutoff_prob: Cutoff probability in pruning,
                        default 1.0, no pruning.
    :type cutoff_prob: float
    :param ext_scoring_func: External defined scoring function for
                            partially decoded sentence, e.g. word count
                            and language model.
    :type ext_scoring_func: function
    :param num_processes: Number of processes, default None, equal to the
                          number of CPUs.
    :type num_processes: int
    :return: One ctc_beam_search_decoder result per element of probs_split,
             in the same order: a list of decoding log probabilities and
             result sentences in descending order.
    :rtype: list
    :raises ValueError: If num_processes is not positive.
    '''
    if num_processes is None:
        num_processes = multiprocessing.cpu_count()
    if not num_processes > 0:
        raise ValueError("Number of processes must be positive!")

    # The external scorer is not sent to the workers as a task argument
    # (None is passed instead); workers with nproc=True read this module
    # global. It must be assigned BEFORE the pool is created so the worker
    # processes see it. NOTE(review): this relies on workers inheriting the
    # parent's module state (the "fork" start method) — confirm for
    # platforms that use "spawn".
    global ext_nproc_scorer
    ext_nproc_scorer = ext_scoring_func
    nproc = True

    pool = multiprocessing.Pool(processes=num_processes)
    try:
        results = [
            pool.apply_async(ctc_beam_search_decoder,
                             (probs_list, beam_size, vocabulary, blank_id,
                              cutoff_prob, None, nproc))
            for probs_list in probs_split
        ]
    finally:
        # always release the pool, even if task submission fails
        pool.close()
        pool.join()
    # get() re-raises any exception raised inside a worker
    return [result.get() for result in results]