decoder.py 11.2 KB
Newer Older
1 2 3 4
"""Contains various CTC decoder."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
5

6
import os
7 8
from itertools import groupby
import numpy as np
Y
Yibing Liu 已提交
9
import kenlm
10
import multiprocessing
11 12 13


def ctc_best_path_decode(probs_seq, vocabulary):
    """Best path decoding, also called argmax decoding or greedy decoding.

    The path consisting of the most probable token at every time step is
    post-processed to remove consecutive repetitions and then all blanks.
    The blank is expected at the last position of every probability vector,
    i.e. at index ``len(vocabulary)``.

    :param probs_seq: 2-D list of probabilities for each time step. Each
                      element is a list of ``len(vocabulary) + 1`` float
                      probabilities (vocabulary entries followed by blank).
    :type probs_seq: list
    :param vocabulary: Vocabulary list.
    :type vocabulary: list
    :return: Decoding result string; empty when ``probs_seq`` is empty.
    :rtype: str
    :raises ValueError: If any probability vector does not have
                        ``len(vocabulary) + 1`` entries.
    """
    # dimension verification
    for probs in probs_seq:
        if len(probs) != len(vocabulary) + 1:
            raise ValueError("probs_seq dimension mismatched with vocabulary")
    # An empty sequence becomes a 1-D array on which argmax(axis=1) would
    # raise, so answer it directly. len() also works for numpy input.
    if len(probs_seq) == 0:
        return ''
    # argmax to get the best index for each time step
    max_index_list = np.array(probs_seq).argmax(axis=1)
    # collapse consecutive duplicate indexes, then drop the blank index
    blank_index = len(vocabulary)
    index_list = [
        index for index, _ in groupby(max_index_list) if index != blank_index
    ]
    # convert index list to string
    return ''.join(vocabulary[index] for index in index_list)


Y
Yibing Liu 已提交
42
class Scorer(object):
    """External defined scorer to evaluate a sentence in beam search
    decoding, consisting of language model and word count.

    :param alpha: Parameter associated with language model.
    :type alpha: float
    :param beta: Parameter associated with word count.
    :type beta: float
    :param model_path: Path to load language model.
    :type model_path: basestring
    :raises IOError: If ``model_path`` is not an existing file.
    """

    def __init__(self, alpha, beta, model_path):
        self._alpha = alpha
        self._beta = beta
        # Fail early with a clear message instead of inside kenlm.
        if not os.path.isfile(model_path):
            raise IOError("Invalid language model path: %s" % model_path)
        self._language_model = kenlm.LanguageModel(model_path)

    # n-gram language model scoring
    def language_model_score(self, sentence):
        """Return the language-model probability of the last word.

        Only the first field of the last entry yielded by kenlm's
        ``full_scores`` (called with ``eos=False``) is used. That field is
        treated as a base-10 log probability, hence ``np.power(10, ...)``.

        :param sentence: Space-separated sentence to score.
        :type sentence: basestring
        :return: Conditional probability (decimal, not log) of the last word.
        :rtype: float
        """
        log_cond_prob = list(
            self._language_model.full_scores(sentence, eos=False))[-1][0]
        return np.power(10, log_cond_prob)

    # word insertion term
    def word_count(self, sentence):
        """Return the number of single-space separated tokens.

        Leading/trailing whitespace is stripped first; note that an empty
        string still counts as one token.
        """
        words = sentence.strip().split(' ')
        return len(words)

    # execute evaluation
    def __call__(self, sentence, log=False):
        """Evaluation function, gathering all the scores.

        :param sentence: The input sentence for evaluation.
        :type sentence: basestring
        :param log: Whether to return the score in log representation.
        :type log: bool
        :return: ``lm ** alpha * word_cnt ** beta``, or its natural log
                 ``alpha * ln(lm) + beta * ln(word_cnt)`` when ``log`` is set.
        :rtype: float
        """
        lm = self.language_model_score(sentence)
        word_cnt = self.word_count(sentence)
        if log:
            return self._alpha * np.log(lm) + self._beta * np.log(word_cnt)
        return np.power(lm, self._alpha) * np.power(word_cnt, self._beta)


def _prune_time_step(prob, cutoff_prob):
    """Return the ``[index, probability]`` candidates kept for one time step.

    Without pruning (``cutoff_prob`` not below 1.0) every index is kept in
    its natural order. Otherwise candidates are sorted by descending
    probability, and the shortest head whose cumulative probability reaches
    ``cutoff_prob`` is kept.
    """
    prob_idx = [[i, p] for i, p in enumerate(prob)]
    if not cutoff_prob < 1.0:
        return prob_idx
    prob_idx = sorted(prob_idx, key=lambda pair: pair[1], reverse=True)
    cutoff_len = 0
    cum_prob = 0.0
    for _, p in prob_idx:
        cum_prob += p
        cutoff_len += 1
        if cum_prob >= cutoff_prob:
            break
    return prob_idx[:cutoff_len]


def ctc_beam_search_decoder(probs_seq,
                            beam_size,
                            vocabulary,
                            blank_id=0,
                            cutoff_prob=1.0,
                            ext_scoring_func=None,
                            nproc=False):
    '''Beam search decoder for CTC-trained network, using beam search with width
    beam_size to find many paths to one label, return beam_size labels in
    the descending order of probabilities. The implementation is based on Prefix
    Beam Search (https://arxiv.org/abs/1408.2873), and the unclear part is
    redesigned, need to be verified.

    Each probability vector has ``len(vocabulary) + 1`` entries: the blank
    sits at position ``blank_id``, and the vocabulary entries fill the other
    positions in order. Position ``c`` therefore maps to ``vocabulary[c]``
    for ``c < blank_id`` and to ``vocabulary[c - 1]`` for ``c > blank_id``.

    :param probs_seq: 2-D list with length num_time_steps, each element
                      is a list of normalized probabilities over vocabulary
                      and blank for one time step.
    :type probs_seq: 2-D list
    :param beam_size: Width for beam search.
    :type beam_size: int
    :param vocabulary: Vocabulary list.
    :type vocabulary: list
    :param blank_id: ID of blank, default 0.
    :type blank_id: int
    :param cutoff_prob: Cutoff probability in pruning,
                        default 1.0, no pruning.
    :type cutoff_prob: float
    :param ext_scoring_func: External defined scoring function for
                            partially decoded sentence, e.g. word count
                            and language model.
    :type ext_scoring_func: function
    :param nproc: Whether the decoder is used in multiprocesses. If True,
                  ext_scoring_func is replaced by the module-level
                  ``ext_nproc_scorer``.
    :type nproc: bool
    :return: Decoding log probabilities and result sentences in descending
             order, as ``[log_prob, sentence]`` pairs; empty for an empty
             ``probs_seq``.
    :rtype: list
    :raises ValueError: On a probability vector of the wrong length or a
                        ``blank_id`` outside ``[0, len(vocabulary)]``.
    '''
    # dimension check
    for prob_list in probs_seq:
        if not len(prob_list) == len(vocabulary) + 1:
            raise ValueError("probs dimension mismatched with vocabulary")

    # blank_id check; the dimension is derived from the vocabulary so that
    # an empty probs_seq does not fail on probs_seq[0]
    probs_dim = len(vocabulary) + 1
    if blank_id < 0:
        raise ValueError("blank_id shouldn't be negative")
    if not blank_id < probs_dim:
        raise ValueError("blank_id shouldn't be greater than probs dimension")

    # If the decoder called in the multiprocesses, then use the global scorer
    # instantiated in ctc_beam_search_decoder_nproc().
    if nproc is True:
        global ext_nproc_scorer
        ext_scoring_func = ext_nproc_scorer

    # Output position -> character, skipping the blank slot. For
    # blank_id == len(vocabulary) this is the identity vocabulary[c].
    vocab_list = list(vocabulary)
    index_to_char = vocab_list[:blank_id] + [None] + vocab_list[blank_id:]

    ## initialize: '\t' marks the empty prefix (stripped from results)
    prefix_set_prev = {'\t': 1.0}
    probs_b_prev, probs_nb_prev = {'\t': 1.0}, {'\t': 0.0}

    ## extend prefix in loop
    for prob in probs_seq:
        # the set containing candidate prefixes
        prefix_set_next = {}
        probs_b_cur, probs_nb_cur = {}, {}
        candidates = _prune_time_step(prob, cutoff_prob)

        for l in prefix_set_prev:
            # l may already have been created as another prefix's extension
            if l not in prefix_set_next:
                probs_b_cur[l], probs_nb_cur[l] = 0.0, 0.0
            last_char = l[-1]

            # extend prefix by traversing the kept candidates
            for c, prob_c in candidates:
                if c == blank_id:
                    probs_b_cur[l] += prob_c * (
                        probs_b_prev[l] + probs_nb_prev[l])
                    continue

                new_char = index_to_char[c]
                l_plus = l + new_char
                if l_plus not in prefix_set_next:
                    probs_b_cur[l_plus], probs_nb_cur[l_plus] = 0.0, 0.0

                if new_char == last_char:
                    # a repeated char only extends after a blank; otherwise
                    # it collapses into the existing prefix
                    probs_nb_cur[l_plus] += prob_c * probs_b_prev[l]
                    probs_nb_cur[l] += prob_c * probs_nb_prev[l]
                elif new_char == ' ':
                    # a word boundary: score the completed prefix externally
                    if (ext_scoring_func is None) or (len(l) == 1):
                        score = 1.0
                    else:
                        score = ext_scoring_func(l[1:])
                    probs_nb_cur[l_plus] += score * prob_c * (
                        probs_b_prev[l] + probs_nb_prev[l])
                else:
                    probs_nb_cur[l_plus] += prob_c * (
                        probs_b_prev[l] + probs_nb_prev[l])
                # add l_plus into prefix_set_next
                prefix_set_next[l_plus] = (
                    probs_nb_cur[l_plus] + probs_b_cur[l_plus])
            # add l into prefix_set_next
            prefix_set_next[l] = probs_b_cur[l] + probs_nb_cur[l]

        # update probs
        probs_b_prev, probs_nb_prev = probs_b_cur, probs_nb_cur

        ## store top beam_size prefixes
        ranked = sorted(
            prefix_set_next.items(), key=lambda item: item[1], reverse=True)
        prefix_set_prev = dict(ranked[:beam_size])

    beam_result = []
    for seq, prob in prefix_set_prev.items():
        if prob > 0.0 and len(seq) > 1:
            result = seq[1:]
            # score last word by external scorer
            if (ext_scoring_func is not None) and (result[-1] != ' '):
                prob = prob * ext_scoring_func(result)
            beam_result.append([np.log(prob), result])

    ## output top beam_size decoding results
    return sorted(beam_result, key=lambda item: item[0], reverse=True)
232 233 234 235 236 237


def ctc_beam_search_decoder_nproc(probs_split,
                                  beam_size,
                                  vocabulary,
                                  blank_id=0,
                                  cutoff_prob=1.0,
                                  ext_scoring_func=None,
                                  num_processes=None):
    '''Beam search decoder using multiple processes.

    :param probs_split: 3-D list with length batch_size, each element
                        is a 2-D list of probabilities that can be used by
                        ctc_beam_search_decoder.
    :type probs_split: 3-D list
    :param beam_size: Width for beam search.
    :type beam_size: int
    :param vocabulary: Vocabulary list.
    :type vocabulary: list
    :param blank_id: ID of blank, default 0.
    :type blank_id: int
    :param cutoff_prob: Cutoff probability in pruning,
                        default 1.0, no pruning.
    :type cutoff_prob: float
    :param ext_scoring_func: External defined scoring function for
                            partially decoded sentence, e.g. word count
                            and language model.
    :type ext_scoring_func: function
    :param num_processes: Number of processes, default None, equal to the
                          number of CPUs.
    :type num_processes: int
    :return: One ctc_beam_search_decoder result (log probabilities and
             sentences in descending order) per element of probs_split,
             in the same order as probs_split.
    :rtype: list
    :raises ValueError: If num_processes is not positive.
    '''
    if num_processes is None:
        num_processes = multiprocessing.cpu_count()
    if not num_processes > 0:
        raise ValueError("Number of processes must be positive!")

    # Pass the external scorer to the workers through a module-level global
    # instead of as a task argument (presumably to avoid pickling it per task).
    # NOTE(review): workers only see this global if they inherit the parent's
    # memory (the "fork" start method) and the pool is created after it is
    # set; under "spawn" it is undefined in the children. Confirm the platform.
    global ext_nproc_scorer
    ext_nproc_scorer = ext_scoring_func

    pool = multiprocessing.Pool(processes=num_processes)
    try:
        # ext_scoring_func=None plus nproc=True makes each worker read the
        # global scorer set above.
        results = [
            pool.apply_async(ctc_beam_search_decoder,
                             (probs_list, beam_size, vocabulary, blank_id,
                              cutoff_prob, None, True))
            for probs_list in probs_split
        ]
        pool.close()
        pool.join()
        # get() re-raises any exception raised inside a worker
        return [result.get() for result in results]
    finally:
        # Reap worker processes even if submission or collection failed;
        # harmless after a successful close()/join().
        pool.terminate()