"""Contains various CTC decoders."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from itertools import groupby
import numpy as np
from math import log
import multiprocessing


def ctc_greedy_decoder(probs_seq, vocabulary):
    """CTC greedy (best path) decoder.

    The path made of the most probable token at every time step is
    post-processed by collapsing consecutive repetitions and then removing
    all blanks. The blank token is the entry at index ``len(vocabulary)``.

    :param probs_seq: 2-D list of probabilities over the vocabulary for each
                      character. Each element is a list of float probabilities
                      for one character, with the blank as the last entry.
    :type probs_seq: list
    :param vocabulary: Vocabulary list.
    :type vocabulary: list
    :return: Decoding result string (empty if probs_seq is empty).
    :rtype: str
    :raises ValueError: If a row of probs_seq does not contain exactly
                        ``len(vocabulary) + 1`` entries.
    """
    blank_index = len(vocabulary)

    # dimension verification
    for probs in probs_seq:
        if len(probs) != blank_index + 1:
            raise ValueError("probs_seq dimension mismatchedd with vocabulary")

    # An empty sequence has no second axis to take the argmax over, so it is
    # decoded to the empty string directly instead of failing inside numpy.
    if len(probs_seq) == 0:
        return ''

    # argmax to get the best index for each time step
    max_index_list = list(np.array(probs_seq).argmax(axis=1))
    # collapse runs of identical indexes, then drop the blanks
    collapsed = [index for index, _ in groupby(max_index_list)]
    return ''.join(
        [vocabulary[index] for index in collapsed if index != blank_index])


def ctc_beam_search_decoder(probs_seq,
                            beam_size,
                            vocabulary,
                            cutoff_prob=1.0,
                            cutoff_top_n=40,
                            ext_scoring_func=None,
                            nproc=False):
    """CTC Beam search decoder.

    It utilizes beam search to approximately select top best decoding
    labels and returning results in the descending order.
    The implementation is based on Prefix Beam Search
    (https://arxiv.org/abs/1408.2873), and the unclear part is
    redesigned. Two important modifications: 1) in the iterative computation
    of probabilities, the assignment operation is changed to accumulation,
    because one prefix may come from different paths; 2) the if condition
    "if l^+ not in A_prev then" after the probabilities' computation is
    dropped, because it is hard to understand and seems unnecessary.

    Internally every prefix starts with a tab character that stands for the
    empty prefix; it is stripped from the returned sentences.

    :param probs_seq: 2-D list of probability distributions over each time
                      step, with each element being a list of normalized
                      probabilities over vocabulary and blank (blank last).
    :type probs_seq: 2-D list
    :param beam_size: Width for beam search.
    :type beam_size: int
    :param vocabulary: Vocabulary list.
    :type vocabulary: list
    :param cutoff_prob: Cutoff probability in pruning,
                        default 1.0, no pruning.
    :type cutoff_prob: float
    :param cutoff_top_n: Maximum number of most probable tokens kept per
                         time step when pruning, default 40.
    :type cutoff_top_n: int
    :param ext_scoring_func: External scoring function for
                            partially decoded sentence, e.g. word count
                            or language model. It is called with the decoded
                            text whenever a space is appended, and once more
                            for a final result that does not end in a space.
    :type ext_scoring_func: callable
    :param nproc: Whether the decoder used in multiprocesses. When True, the
                  module-level ``ext_nproc_scorer`` replaces ext_scoring_func.
    :type nproc: bool
    :return: List of tuples of log probability and sentence as decoding
             results, in descending order of the probability. Prefixes that
             are empty or have zero probability are reported as
             ``(float('-inf'), '')``.
    :rtype: list
    :raises ValueError: If a row of probs_seq does not contain exactly
                        ``len(vocabulary) + 1`` entries.
    """
    # dimension check
    for prob_list in probs_seq:
        if not len(prob_list) == len(vocabulary) + 1:
            raise ValueError("The shape of prob_seq does not match with the "
                             "shape of the vocabulary.")

    # blank_id assign
    blank_id = len(vocabulary)

    # If the decoder called in the multiprocesses, then use the global scorer
    # instantiated in ctc_beam_search_decoder_batch().
    if nproc is True:
        global ext_nproc_scorer
        ext_scoring_func = ext_nproc_scorer

    ## initialize
    # prefix_set_prev: the set containing selected prefixes
    # probs_b_prev: prefixes' probability ending with blank in previous step
    # probs_nb_prev: prefixes' probability ending with non-blank in previous
    #                step
    prefix_set_prev = {'\t': 1.0}
    probs_b_prev, probs_nb_prev = {'\t': 1.0}, {'\t': 0.0}

    ## extend prefix in loop
    for step_probs in probs_seq:
        # prefix_set_next: the set containing candidate prefixes
        # probs_b_cur: prefixes' probability ending with blank in current step
        # probs_nb_cur: prefixes' probability ending with non-blank in current
        #               step
        prefix_set_next, probs_b_cur, probs_nb_cur = {}, {}, {}

        prob_idx = list(enumerate(step_probs))
        cutoff_len = len(prob_idx)
        # If pruning is enabled, keep the most probable tokens until their
        # cumulative probability reaches cutoff_prob, capped at cutoff_top_n.
        if cutoff_prob < 1.0 or cutoff_top_n < cutoff_len:
            prob_idx = sorted(prob_idx, key=lambda pair: pair[1], reverse=True)
            cutoff_len, cum_prob = 0, 0.0
            for _, token_prob in prob_idx:
                cum_prob += token_prob
                cutoff_len += 1
                if cum_prob >= cutoff_prob:
                    break
            cutoff_len = min(cutoff_len, cutoff_top_n)
            prob_idx = prob_idx[0:cutoff_len]

        for prefix in prefix_set_prev:
            # A prefix already produced as an extension in this step keeps the
            # probability mass accumulated so far.
            if prefix not in prefix_set_next:
                probs_b_cur[prefix], probs_nb_cur[prefix] = 0.0, 0.0

            last_char = prefix[-1]
            prev_total = probs_b_prev[prefix] + probs_nb_prev[prefix]

            # extend prefix by traversing prob_idx
            for index in range(cutoff_len):
                c, prob_c = prob_idx[index]

                if c == blank_id:
                    probs_b_cur[prefix] += prob_c * prev_total
                    continue

                new_char = vocabulary[c]
                l_plus = prefix + new_char
                if l_plus not in prefix_set_next:
                    probs_b_cur[l_plus], probs_nb_cur[l_plus] = 0.0, 0.0

                if new_char == last_char:
                    # A repeated character only extends the prefix when a
                    # blank separated the two; otherwise it collapses into
                    # the unchanged prefix.
                    probs_nb_cur[l_plus] += prob_c * probs_b_prev[prefix]
                    probs_nb_cur[prefix] += prob_c * probs_nb_prev[prefix]
                elif new_char == ' ':
                    # A space closes a word: apply the external score, except
                    # when the prefix is still only the tab sentinel.
                    if (ext_scoring_func is None) or (len(prefix) == 1):
                        score = 1.0
                    else:
                        score = ext_scoring_func(prefix[1:])
                    probs_nb_cur[l_plus] += score * prob_c * prev_total
                else:
                    probs_nb_cur[l_plus] += prob_c * prev_total
                # add l_plus into prefix_set_next
                prefix_set_next[l_plus] = (
                    probs_nb_cur[l_plus] + probs_b_cur[l_plus])
            # add prefix into prefix_set_next
            prefix_set_next[prefix] = (
                probs_b_cur[prefix] + probs_nb_cur[prefix])
        # update probs
        probs_b_prev, probs_nb_prev = probs_b_cur, probs_nb_cur

        ## store top beam_size prefixes
        ranked = sorted(
            prefix_set_next.items(), key=lambda pair: pair[1], reverse=True)
        prefix_set_prev = dict(ranked[:beam_size])

    beam_result = []
    for seq, prob in prefix_set_prev.items():
        log_prob, result = float('-inf'), ''
        if prob > 0.0 and len(seq) > 1:
            sentence = seq[1:]
            # score last word by external scorer
            if (ext_scoring_func is not None) and (sentence[-1] != ' '):
                prob = prob * ext_scoring_func(sentence)
            # The scorer may drive the probability to zero; log() is undefined
            # there, so such a hypothesis is reported like any other
            # zero-probability prefix.
            if prob > 0.0:
                log_prob, result = log(prob), sentence
        beam_result.append((log_prob, result))

    ## output top beam_size decoding results
    beam_result = sorted(beam_result, key=lambda pair: pair[0], reverse=True)
    return beam_result


def ctc_beam_search_decoder_batch(probs_split,
                                  beam_size,
                                  vocabulary,
                                  num_processes,
                                  cutoff_prob=1.0,
                                  cutoff_top_n=40,
                                  ext_scoring_func=None):
    """CTC beam search decoder using multiple processes.

    Every element of probs_split is decoded by ctc_beam_search_decoder() in a
    pool of worker processes.

    :param probs_split: 3-D list with each element as an instance of 2-D list
                        of probabilities used by ctc_beam_search_decoder().
    :type probs_split: 3-D list
    :param beam_size: Width for beam search.
    :type beam_size: int
    :param vocabulary: Vocabulary list.
    :type vocabulary: list
    :param num_processes: Number of parallel processes.
    :type num_processes: int
    :param cutoff_prob: Cutoff probability in pruning,
                        default 1.0, no pruning.
    :type cutoff_prob: float
    :param cutoff_top_n: Maximum number of most probable tokens kept per
                         time step when pruning, default 40.
    :type cutoff_top_n: int
    :param ext_scoring_func: External scoring function for
                            partially decoded sentence, e.g. word count
                            or language model.
    :type ext_scoring_func: callable
    :return: One entry per element of probs_split, in the same order; each
             entry is the list of (log probability, sentence) tuples returned
             by ctc_beam_search_decoder().
    :rtype: list
    :raises ValueError: If num_processes is not positive.
    """
    if not num_processes > 0:
        raise ValueError("Number of processes must be positive!")

    # Use a global variable to pass the external scorer to the beam search
    # decoder: the workers are told (nproc=True) to read it from there rather
    # than receiving it as a task argument.
    # NOTE(review): this presumably relies on the workers inheriting module
    # state from the parent (fork start method) -- confirm on other platforms.
    global ext_nproc_scorer
    ext_nproc_scorer = ext_scoring_func
    nproc = True

    pool = multiprocessing.Pool(processes=num_processes)
    results = []
    try:
        for probs_list in probs_split:
            args = (probs_list, beam_size, vocabulary, cutoff_prob,
                    cutoff_top_n, None, nproc)
            results.append(pool.apply_async(ctc_beam_search_decoder, args))
    finally:
        # Always stop accepting work and wait for the workers, so the pool is
        # not left running if submitting a task fails.
        pool.close()
        pool.join()

    # get() re-raises any exception that occurred inside a worker.
    beam_search_results = [result.get() for result in results]
    return beam_search_results