decoder.py 9.9 KB
Newer Older
Y
Yibing Liu 已提交
1
"""Contains various CTC decoders."""
2 3 4
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
5 6 7

from itertools import groupby
import numpy as np
Y
Yibing Liu 已提交
8
from math import log
9
import multiprocessing
10 11


Y
Yibing Liu 已提交
12 13
def ctc_best_path_decoder(probs_seq, vocabulary):
    """Best path decoder, also called argmax decoder or greedy decoder.

    The most probable token is picked at every time step; the resulting path
    is then post-processed to remove consecutive repetitions and all blanks.
    The blank is expected in the last column (index ``len(vocabulary)``).

    :param probs_seq: 2-D list of probabilities over the vocabulary (plus the
                      blank) for each time step. Each element is a list of
                      ``len(vocabulary) + 1`` float probabilities.
    :type probs_seq: list
    :param vocabulary: Vocabulary list (without the blank).
    :type vocabulary: list
    :return: Decoding result string; empty string for an empty ``probs_seq``.
    :rtype: str
    :raises ValueError: If any row of ``probs_seq`` does not have
                        ``len(vocabulary) + 1`` entries.
    """
    # dimension verification
    for probs in probs_seq:
        if not len(probs) == len(vocabulary) + 1:
            raise ValueError("probs_seq dimension mismatched with vocabulary")
    # Nothing to decode. Guarded explicitly because np.array([]) is 1-D and
    # argmax(axis=1) below would fail on it.
    if len(probs_seq) == 0:
        return ''
    # argmax to get the best index for each time step (ties -> lowest index)
    max_index_list = list(np.array(probs_seq).argmax(axis=1))
    blank_index = len(vocabulary)
    # groupby collapses runs of the same index; blanks are dropped afterwards,
    # so a blank between two equal tokens keeps both of them.
    return ''.join(vocabulary[index]
                   for index, _ in groupby(max_index_list)
                   if index != blank_index)


Y
Yibing Liu 已提交
41 42 43
def ctc_beam_search_decoder(probs_seq,
                            beam_size,
                            vocabulary,
                            blank_id,
                            cutoff_prob=1.0,
                            ext_scoring_func=None,
                            nproc=False):
    """Beam search decoder for CTC-trained network. It uses beam search to
    approximately select the best decoding labels and returns the results
    in descending order. The implementation is based on Prefix Beam Search
    (https://arxiv.org/abs/1408.2873), and the unclear part is redesigned.
    Two important modifications: 1) in the iterative computation of
    probabilities, the assignment operation is changed to accumulation, since
    one prefix may come from different paths; 2) the if condition "if l^+ not
    in A_prev then" after the probabilities' computation is dropped, because
    it is hard to understand and seems unnecessary.

    :param probs_seq: 2-D list of probability distributions over each time
                      step, with each element being a list of normalized
                      probabilities over vocabulary and blank. Column
                      ``blank_id`` is the blank; the remaining columns map,
                      in order, onto ``vocabulary``.
    :type probs_seq: 2-D list
    :param beam_size: Width for beam search.
    :type beam_size: int
    :param vocabulary: Vocabulary list (without the blank).
    :type vocabulary: list
    :param blank_id: ID of blank, in ``[0, len(vocabulary)]``.
    :type blank_id: int
    :param cutoff_prob: Cutoff probability in pruning: at each step only the
                        most probable tokens whose cumulative probability
                        reaches ``cutoff_prob`` are expanded. Default 1.0,
                        no pruning.
    :type cutoff_prob: float
    :param ext_scoring_func: External scoring function for a partially
                             decoded sentence, e.g. word count or language
                             model. Called on the prefix whenever a space is
                             appended, and on the final result unless it
                             ends with a space.
    :type ext_scoring_func: callable
    :param nproc: Whether the decoder is used in multiple processes. If True,
                  ``ext_scoring_func`` is ignored and the module-level
                  ``ext_nproc_scorer`` (set by
                  ctc_beam_search_decoder_batch()) is used instead.
    :type nproc: bool
    :return: List of tuples of log probability and sentence as decoding
             results, in descending order of the probability. Empty
             sentences and zero-probability prefixes are omitted; a result
             the external scorer rates 0 gets log probability ``-inf``.
    :rtype: list
    :raises ValueError: On a row/vocabulary size mismatch or an out-of-range
                        ``blank_id``.
    """
    # dimension check
    for prob_list in probs_seq:
        if not len(prob_list) == len(vocabulary) + 1:
            raise ValueError("The shape of prob_seq does not match with the "
                             "shape of the vocabulary.")

    # blank_id check. Checked against the vocabulary size rather than
    # probs_seq[0] so that an empty probs_seq does not raise IndexError;
    # the two are equal for non-empty input after the check above.
    if not 0 <= blank_id < len(vocabulary) + 1:
        raise ValueError("blank_id should be in the range "
                         "[0, probs dimension)")

    # If the decoder is called in multiple processes, then use the global
    # scorer instantiated in ctc_beam_search_decoder_batch().
    if nproc is True:
        global ext_nproc_scorer
        ext_scoring_func = ext_nproc_scorer

    ## initialize
    # Every prefix starts with the sentinel '\t' so that prefix[-1] is always
    # defined; the sentinel is stripped from the final results.
    # prefix_set_prev: selected prefixes and their total probability
    # probs_b_prev: prefixes' probability ending with blank in previous step
    # probs_nb_prev: prefixes' probability ending with non-blank in previous step
    prefix_set_prev = {'\t': 1.0}
    probs_b_prev, probs_nb_prev = {'\t': 1.0}, {'\t': 0.0}

    ## extend prefix in loop
    for step_probs in probs_seq:
        # probs_b_cur / probs_nb_cur: prefixes' probability ending with blank /
        # non-blank in the current step
        probs_b_cur, probs_nb_cur = {}, {}

        prob_idx = list(enumerate(step_probs))
        # If pruning is enabled, keep only the most probable tokens whose
        # cumulative probability first reaches cutoff_prob.
        if cutoff_prob < 1.0:
            prob_idx = sorted(prob_idx, key=lambda item: item[1], reverse=True)
            cutoff_len, cum_prob = 0, 0.0
            for _, prob in prob_idx:
                cum_prob += prob
                cutoff_len += 1
                if cum_prob >= cutoff_prob:
                    break
            prob_idx = prob_idx[:cutoff_len]

        for l in prefix_set_prev:
            # l may already have been reached as an extension of another
            # prefix in this step; keep what was accumulated so far.
            if l not in probs_b_cur:
                probs_b_cur[l], probs_nb_cur[l] = 0.0, 0.0

            # extend prefix by traversing prob_idx
            for c, prob_c in prob_idx:
                if c == blank_id:
                    probs_b_cur[l] += prob_c * (
                        probs_b_prev[l] + probs_nb_prev[l])
                    continue

                last_char = l[-1]
                # Skip over the blank column when mapping to the vocabulary.
                new_char = vocabulary[c if c < blank_id else c - 1]
                l_plus = l + new_char
                if l_plus not in probs_b_cur:
                    probs_b_cur[l_plus], probs_nb_cur[l_plus] = 0.0, 0.0

                if new_char == last_char:
                    # A repeated token only starts a new character after a
                    # blank; otherwise CTC collapses it into the prefix l.
                    probs_nb_cur[l_plus] += prob_c * probs_b_prev[l]
                    probs_nb_cur[l] += prob_c * probs_nb_prev[l]
                elif new_char == ' ':
                    # A space completes a word: apply the external scorer to
                    # the prefix (without the '\t' sentinel).
                    if (ext_scoring_func is None) or (len(l) == 1):
                        score = 1.0
                    else:
                        score = ext_scoring_func(l[1:])
                    probs_nb_cur[l_plus] += score * prob_c * (
                        probs_b_prev[l] + probs_nb_prev[l])
                else:
                    probs_nb_cur[l_plus] += prob_c * (
                        probs_b_prev[l] + probs_nb_prev[l])

        # total probability of every candidate prefix of this step
        prefix_set_next = dict((prefix, probs_b_cur[prefix] +
                                probs_nb_cur[prefix])
                               for prefix in probs_b_cur)
        # update probs
        probs_b_prev, probs_nb_prev = probs_b_cur, probs_nb_cur

        ## store top beam_size prefixes
        ranked = sorted(
            prefix_set_next.items(), key=lambda item: item[1], reverse=True)
        prefix_set_prev = dict(ranked[:beam_size])

    beam_result = []
    for seq, prob in prefix_set_prev.items():
        if prob > 0.0 and len(seq) > 1:
            result = seq[1:]
            # score last word by external scorer
            if (ext_scoring_func is not None) and (result[-1] != ' '):
                prob = prob * ext_scoring_func(result)
            # The scorer may return 0, and log(0) raises; report -inf instead.
            log_prob = log(prob) if prob > 0.0 else float('-inf')
            beam_result.append((log_prob, result))

    ## output decoding results, best first
    beam_result.sort(key=lambda item: item[0], reverse=True)
    return beam_result
184 185


Y
Yibing Liu 已提交
186
def _set_ext_nproc_scorer(scorer):
    """Store ``scorer`` in the module-level ``ext_nproc_scorer`` global."""
    global ext_nproc_scorer
    ext_nproc_scorer = scorer


def ctc_beam_search_decoder_batch(probs_split,
                                  beam_size,
                                  vocabulary,
                                  blank_id,
                                  num_processes,
                                  cutoff_prob=1.0,
                                  ext_scoring_func=None):
    """CTC beam search decoder using multiple processes.

    Each element of ``probs_split`` is decoded by ctc_beam_search_decoder()
    in a worker of a process pool.

    :param probs_split: 3-D list with each element as an instance of 2-D list
                        of probabilities used by ctc_beam_search_decoder().
    :type probs_split: 3-D list
    :param beam_size: Width for beam search.
    :type beam_size: int
    :param vocabulary: Vocabulary list.
    :type vocabulary: list
    :param blank_id: ID of blank.
    :type blank_id: int
    :param num_processes: Number of parallel processes.
    :type num_processes: int
    :param cutoff_prob: Cutoff probability in pruning,
                        default 1.0, no pruning.
    :type cutoff_prob: float
    :param ext_scoring_func: External scoring function for
                             partially decoded sentence, e.g. word count
                             or language model.
    :type ext_scoring_func: callable
    :return: One entry per element of ``probs_split``, in the same order; each
             entry is a list of tuples of log probability and sentence as
             decoding results, in descending order of the probability.
    :rtype: list
    :raises ValueError: If ``num_processes`` is not positive.
    """
    if not num_processes > 0:
        raise ValueError("Number of processes must be positive!")

    # The scorer is not sent with each task; ctc_beam_search_decoder(...,
    # nproc=True) reads it from the module-level global instead. Setting it
    # here covers fork-started workers, which inherit the parent's globals.
    # The pool initializer also sets it inside every worker, so workers
    # started with "spawn"/"forkserver" (which re-import this module instead
    # of inheriting its state) see it as well.
    _set_ext_nproc_scorer(ext_scoring_func)

    pool = multiprocessing.Pool(processes=num_processes,
                                initializer=_set_ext_nproc_scorer,
                                initargs=(ext_scoring_func, ))
    try:
        results = [
            pool.apply_async(ctc_beam_search_decoder,
                             (probs_list, beam_size, vocabulary, blank_id,
                              cutoff_prob, None, True))
            for probs_list in probs_split
        ]
        pool.close()
        pool.join()
        # get() re-raises any exception that occurred inside a worker.
        return [result.get() for result in results]
    finally:
        # After a successful join the workers have already exited; on an
        # error path this stops any workers that are still running.
        pool.terminate()