decoder.py 9.5 KB
Newer Older
1 2 3 4
"""
    CTC-like decoder utilities.
"""

5
import os
6 7
from itertools import groupby
import numpy as np
Y
Yibing Liu 已提交
8 9
import copy
import kenlm
10
import multiprocessing
11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42


def ctc_best_path_decode(probs_seq, vocabulary):
    """
    Best path decoding, also called argmax decoding or greedy decoding.
    Path consisting of the most probable tokens are further post-processed to
    remove consecutive repetitions and all blanks.

    The blank token is expected at the last index of every probability list,
    i.e. at index ``len(vocabulary)``.

    :param probs_seq: 2-D list of probabilities over the vocabulary for each
                      character. Each element is a list of
                      ``len(vocabulary) + 1`` float probabilities for one
                      character (the extra entry being blank).
    :type probs_seq: list
    :param vocabulary: Vocabulary list, not including the blank token.
    :type vocabulary: list
    :return: Decoding result string; '' when ``probs_seq`` is empty.
    :rtype: str
    :raises ValueError: If an element of ``probs_seq`` does not have
                        ``len(vocabulary) + 1`` entries.
    """
    blank_index = len(vocabulary)
    # dimension verification
    for probs in probs_seq:
        if len(probs) != blank_index + 1:
            raise ValueError("probs_seq dimension mismatched with vocabulary")
    # np.array([]) is 1-D, so argmax(axis=1) would raise on empty input.
    if len(probs_seq) == 0:
        return ''
    # argmax to get the best index for each time step
    max_index_list = np.array(probs_seq).argmax(axis=1)
    # groupby collapses consecutive duplicate indexes; blanks are dropped
    # afterwards so that a blank between two equal tokens keeps both.
    index_list = [index for index, _ in groupby(max_index_list)
                  if index != blank_index]
    # convert index list to string
    return ''.join(vocabulary[index] for index in index_list)


class Scorer(object):
    """
    External defined scorer to evaluate a sentence in beam search
    decoding, consisting of language model and word count.

    The score of a sentence is
    ``language_model_score(sentence) ** alpha * word_count(sentence) ** beta``.

    :param alpha: Parameter associated with language model.
    :type alpha: float
    :param beta: Parameter associated with word count.
    :type beta: float
    :param model_path: Path to load language model.
    :type model_path: str
    :raises IOError: If ``model_path`` is not an existing file.
    """

    def __init__(self, alpha, beta, model_path):
        self._alpha = alpha
        self._beta = beta
        if not os.path.isfile(model_path):
            raise IOError("Invalid language model path: %s" % model_path)
        self._language_model = kenlm.LanguageModel(model_path)

    # n-gram language model scoring
    def language_model_score(self, sentence):
        """
        Return the language-model probability (linear, not log) of the last
        word of ``sentence`` given the preceding words, scored without an
        end-of-sentence token (``eos=False``).
        """
        # The first item of the last full_scores entry is treated as a
        # log10 probability, hence the conversion with 10 ** x below.
        log_cond_prob = list(
            self._language_model.full_scores(sentence, eos=False))[-1][0]
        return np.power(10, log_cond_prob)

    # word insertion term
    def word_count(self, sentence):
        """
        Return the number of words in ``sentence``.

        Splitting on arbitrary whitespace means runs of spaces (which the
        beam search decoder can produce) do not create empty "words".
        """
        return len(sentence.split())

    # execute evaluation
    def __call__(self, sentence):
        """Return the combined LM / word-count score of ``sentence``."""
        lm = self.language_model_score(sentence)
        word_cnt = self.word_count(sentence)
        return np.power(lm, self._alpha) * np.power(word_cnt, self._beta)


def ctc_beam_search_decoder(probs_seq,
                            beam_size,
                            vocabulary,
                            blank_id=0,
                            ext_scoring_func=None,
                            nproc=False):
    '''
    Beam search decoder for CTC-trained network, using beam search with width
    beam_size to find many paths to one label, return beam_size labels in
    the order of probabilities. The implementation is based on Prefix Beam
    Search (https://arxiv.org/abs/1408.2873), and the unclear part is
    redesigned, need to be verified.

    :param probs_seq: 2-D list with length num_time_steps, each element
                      is a list of ``len(vocabulary) + 1`` normalized
                      probabilities over vocabulary and blank for one
                      time step.
    :type probs_seq: 2-D list
    :param beam_size: Width for beam search.
    :type beam_size: int
    :param vocabulary: Vocabulary list, not including the blank token.
    :type vocabulary: list
    :param blank_id: Index of blank within each probability list, default 0.
                     Indices before blank_id map to ``vocabulary[index]``,
                     indices after it to ``vocabulary[index - 1]``.
    :type blank_id: int
    :param ext_scoring_func: External defined scoring function for
                            partially decoded sentence, e.g. word count
                            and language model. It is called with the
                            decoded prefix whenever a space is appended.
    :type ext_scoring_func: callable
    :param nproc: Whether the decoder used in multiprocesses. If True, the
                  module-level ``ext_nproc_scorer`` (set by
                  ctc_beam_search_decoder_nproc) replaces ext_scoring_func.
    :type nproc: bool
    :return: List of ``[log_probability, result_string]`` pairs for the
             surviving prefixes with non-zero probability, sorted by
             descending log probability.
    :rtype: list
    :raises ValueError: On a probability-list length mismatch or when
                        blank_id is outside ``[0, len(vocabulary)]``.
    '''
    # dimension check; taking probs_dim from the vocabulary (rather than
    # probs_seq[0]) keeps an empty probs_seq from raising IndexError.
    probs_dim = len(vocabulary) + 1
    for prob_list in probs_seq:
        if len(prob_list) != probs_dim:
            raise ValueError("probs dimension mismatched with vocabulary")

    # blank_id check
    if not 0 <= blank_id < probs_dim:
        raise ValueError("blank_id should be in range [0, probs dimension)")

    # If the decoder called in the multiprocesses, then use the global scorer
    # instantiated in ctc_beam_search_decoder_nproc().
    if nproc is True:
        global ext_nproc_scorer
        ext_scoring_func = ext_nproc_scorer

    # Map each non-blank output index to its character. The vocabulary does
    # not contain blank, so indices after blank_id are shifted down by one.
    index_to_char = {}
    for index in range(probs_dim):
        if index < blank_id:
            index_to_char[index] = vocabulary[index]
        elif index > blank_id:
            index_to_char[index] = vocabulary[index - 1]

    ## initialize
    # Prefixes carry a leading '\t' sentinel so prefix[-1] always exists;
    # it is stripped from the final results.
    prefix_set_prev = {'\t': 1.0}
    probs_b_prev, probs_nb_prev = {'\t': 1.0}, {'\t': 0.0}

    ## extend prefix in loop
    for prob in probs_seq:
        # the set containing candidate prefixes
        prefix_set_next = {}
        probs_b_cur, probs_nb_cur = {}, {}
        for prefix in prefix_set_prev:
            # prefix may already exist as an extension of a shorter prefix
            # handled earlier in this time step; only initialize if new.
            if prefix not in probs_b_cur:
                probs_b_cur[prefix], probs_nb_cur[prefix] = 0.0, 0.0
            prob_prefix_prev = probs_b_prev[prefix] + probs_nb_prev[prefix]
            last_char = prefix[-1]

            # extend prefix by traversing vocabulary
            for c in range(probs_dim):
                if c == blank_id:
                    probs_b_cur[prefix] += prob[c] * prob_prefix_prev
                    continue

                new_char = index_to_char[c]
                l_plus = prefix + new_char
                if l_plus not in probs_b_cur:
                    probs_b_cur[l_plus], probs_nb_cur[l_plus] = 0.0, 0.0

                if new_char == last_char:
                    # A repeated char only starts a new token after a blank;
                    # otherwise it collapses into the current prefix.
                    probs_nb_cur[l_plus] += prob[c] * probs_b_prev[prefix]
                    probs_nb_cur[prefix] += prob[c] * probs_nb_prev[prefix]
                elif new_char == ' ':
                    # A space completes a word: apply the external scorer to
                    # the prefix decoded so far (sentinel removed).
                    if ext_scoring_func is None or len(prefix) == 1:
                        score = 1.0
                    else:
                        score = ext_scoring_func(prefix[1:])
                    probs_nb_cur[l_plus] += score * prob[c] * prob_prefix_prev
                else:
                    probs_nb_cur[l_plus] += prob[c] * prob_prefix_prev
                # add l_plus into prefix_set_next
                prefix_set_next[l_plus] = (
                    probs_nb_cur[l_plus] + probs_b_cur[l_plus])
            # add prefix into prefix_set_next
            prefix_set_next[prefix] = (
                probs_b_cur[prefix] + probs_nb_cur[prefix])
        # update probs
        probs_b_prev, probs_nb_prev = probs_b_cur, probs_nb_cur

        ## store top beam_size prefixes
        ranked = sorted(
            prefix_set_next.items(), key=lambda item: item[1], reverse=True)
        prefix_set_prev = dict(ranked[:beam_size])

    ## output top beam_size decoding results
    beam_result = [[np.log(prefix_prob), seq[1:]]
                   for seq, prefix_prob in prefix_set_prev.items()
                   if prefix_prob > 0.0]
    return sorted(beam_result, key=lambda item: item[0], reverse=True)


def ctc_beam_search_decoder_nproc(probs_split,
                                  beam_size,
                                  vocabulary,
                                  blank_id=0,
                                  ext_scoring_func=None,
                                  num_processes=None):
    '''
    Beam search decoder using multiple processes.

    :param probs_split: 3-D list with length batch_size, each element
                        is a 2-D list of probabilities can be used by
                        ctc_beam_search_decoder.
    :type probs_split: 3-D list
    :param beam_size: Width for beam search.
    :type beam_size: int
    :param vocabulary: Vocabulary list.
    :type vocabulary: list
    :param blank_id: id of blank, default 0.
    :type blank_id: int
    :param ext_scoring_func: External defined scoring function for
                            partially decoded sentence, e.g. word count
                            and language model.
    :type ext_scoring_func: callable
    :param num_processes: Number of processes, default None, equal to the
                          number of CPUs.
    :type num_processes: int
    :return: One ctc_beam_search_decoder result per element of probs_split,
             in the same order.
    :rtype: list
    :raises ValueError: If num_processes is not positive.
    '''
    if num_processes is None:
        num_processes = multiprocessing.cpu_count()
    if not num_processes > 0:
        raise ValueError("Number of processes must be positive!")

    # Use a module-level global to pass the external scorer to the beam search
    # decoder instead of sending it as a task argument.
    # NOTE(review): workers only observe this value if they inherit the
    # parent's memory (fork start method) — confirm on spawn-based platforms.
    global ext_nproc_scorer
    ext_nproc_scorer = ext_scoring_func

    pool = multiprocessing.Pool(processes=num_processes)
    try:
        # ext_scoring_func=None, nproc=True: workers read ext_nproc_scorer.
        async_results = [
            pool.apply_async(ctc_beam_search_decoder,
                             (probs_list, beam_size, vocabulary, blank_id,
                              None, True))
            for probs_list in probs_split
        ]
    finally:
        # Always reap the worker processes, even if task submission fails.
        pool.close()
        pool.join()
    return [async_result.get() for async_result in async_results]