"""
    CTC-like decoder utilities.
"""

import os
from itertools import groupby
import numpy as np
import kenlm
import multiprocessing


def ctc_best_path_decode(probs_seq, vocabulary):
    """
    Best path decoding, also called argmax decoding or greedy decoding.
    The path consisting of the most probable tokens is further post-processed
    to remove consecutive repetitions and all blanks.

    :param probs_seq: 2-D list of probabilities over the vocabulary and blank
                      for each time step. Each element is a list of float
                      probabilities for one time step.
    :type probs_seq: list
    :param vocabulary: Vocabulary list.
    :type vocabulary: list
    :return: Decoding result string.
    :rtype: basestring
    """
    # dimension verification
    for probs in probs_seq:
        if not len(probs) == len(vocabulary) + 1:
            raise ValueError("probs_seq dimension mismatchedd with vocabulary")
    # argmax to get the best index for each time step
    max_index_list = list(np.array(probs_seq).argmax(axis=1))
    # remove consecutive duplicate indexes
    index_list = [index_group[0] for index_group in groupby(max_index_list)]
    # remove blank indexes
    blank_index = len(vocabulary)
    index_list = [index for index in index_list if index != blank_index]
    # convert index list to string
    return ''.join([vocabulary[index] for index in index_list])
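

def _example_best_path_decode():
    """A minimal usage sketch for ctc_best_path_decode. Illustrative only:
    the toy probabilities below are made up and not part of this module.
    """
    # two-character vocabulary; the blank takes the last index (2)
    vocabulary = ['a', 'b']
    probs_seq = [
        [0.6, 0.1, 0.3],  # argmax -> 0 ('a')
        [0.7, 0.2, 0.1],  # argmax -> 0 (repeated 'a', merged)
        [0.1, 0.2, 0.7],  # argmax -> 2 (blank, dropped)
        [0.2, 0.7, 0.1],  # argmax -> 1 ('b')
    ]
    # repeats are merged and blanks removed, yielding "ab"
    assert ctc_best_path_decode(probs_seq, vocabulary) == 'ab'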


class Scorer(object):
    """
    Externally defined scorer to evaluate a sentence in beam search
    decoding, consisting of a language model and word count.

    :param alpha: Parameter associated with language model.
    :type alpha: float
    :param beta: Parameter associated with word count.
    :type beta: float
    :param model_path: Path to load language model.
    :type model_path: basestring
    """

    def __init__(self, alpha, beta, model_path):
        self._alpha = alpha
        self._beta = beta
        if not os.path.isfile(model_path):
            raise IOError("Invalid language model path: %s" % model_path)
        self._language_model = kenlm.LanguageModel(model_path)

    # n-gram language model scoring
    def language_model_score(self, sentence):
        # log10 probability of the last word given its history
        log_cond_prob = list(
            self._language_model.full_scores(sentence, eos=False))[-1][0]
        return np.power(10, log_cond_prob)

    # word insertion term
    def word_count(self, sentence):
        words = sentence.strip().split(' ')
        return len(words)

    # execute evaluation: lm_score^alpha * word_count^beta
    def evaluate(self, sentence):
        lm = self.language_model_score(sentence)
        word_cnt = self.word_count(sentence)
        score = np.power(lm, self._alpha) \
                * np.power(word_cnt, self._beta)
        return score
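

def _example_scorer():
    """A usage sketch for Scorer. Illustrative only: 'lm/common_crawl.arpa'
    is a hypothetical KenLM model path and the alpha/beta values are
    arbitrary, not tuned settings from this module.
    """
    scorer = Scorer(alpha=0.26, beta=0.1, model_path='lm/common_crawl.arpa')
    # evaluate combines the two terms: P_lm(sentence)^alpha * num_words^beta
    return scorer.evaluate("hello world")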


def ctc_beam_search_decoder(probs_seq,
                            beam_size,
                            vocabulary,
                            ext_scoring_func=None,
                            blank_id=0):
    '''
    Beam search decoder for a CTC-trained network, using a beam of width
    beam_size to search for the most probable label sequences; returns the
    top beam_size labels in descending order of probability. The
    implementation is based on Prefix Beam Search
    (https://arxiv.org/abs/1408.2873); parts left unclear in the paper have
    been redesigned and still need to be verified.

    :param probs_seq: 2-D list with length num_time_steps, each element
                      is a list of normalized probabilities over vocabulary
                      and blank for one time step.
    :type probs_seq: 2-D list
    :param beam_size: Width for beam search.
    :type beam_size: int
    :param vocabulary: Vocabulary list.
    :type vocabulary: list
    :param ext_scoring_func: Externally defined scoring function for
                             partially decoded sentences, e.g. word count
                             or language model.
    :type ext_scoring_func: function
    :param blank_id: ID of the blank, default 0.
    :type blank_id: int
    :return: List of [log probability, decoded string] pairs, in descending
             order of probability.
    :rtype: list

    '''
    # dimension check
    for prob_list in probs_seq:
        if not len(prob_list) == len(vocabulary) + 1:
            raise ValueError("probs dimension mismatchedd with vocabulary")
    num_time_steps = len(probs_seq)

    # blank_id check
    probs_dim = len(probs_seq[0])
    if not blank_id < probs_dim:
        raise ValueError("blank_id shouldn't be greater than probs dimension")

    ## initialize
    # the set containing selected prefixes; '\t' acts as a sentence-start
    # sentinel (stripped from the final results), so every prefix is
    # non-empty and l[-1] below is always defined
    prefix_set_prev = {'\t': 1.0}
    probs_b_prev, probs_nb_prev = {'\t': 1.0}, {'\t': 0.0}

    ## extend prefix in loop
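    # For each prefix l, two probabilities are tracked per time step:
    #   probs_b[l]  -- total probability of paths for l that end in a blank
    #   probs_nb[l] -- total probability of paths for l ending in a non-blank
    # The updates below follow the prefix beam search recurrence; e.g.
    # extending by a blank gives
    #   probs_b_cur[l] += prob[blank_id] * (probs_b_prev[l] + probs_nb_prev[l])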
    for time_step in range(num_time_steps):
        # the set containing candidate prefixes
        prefix_set_next = {}
        probs_b_cur, probs_nb_cur = {}, {}
        for l in prefix_set_prev:
            prob = probs_seq[time_step]
            if l not in prefix_set_next:
                probs_b_cur[l], probs_nb_cur[l] = 0.0, 0.0

            # extend prefix by traversing vocabulary
            for c in range(0, probs_dim):
                if c == blank_id:
                    probs_b_cur[l] += prob[c] * (
                        probs_b_prev[l] + probs_nb_prev[l])
                else:
                    last_char = l[-1]
                    # map the probs index to a vocabulary index (skip blank)
                    new_char = vocabulary[c if c < blank_id else c - 1]
                    l_plus = l + new_char
                    if l_plus not in prefix_set_next:
                        probs_b_cur[l_plus], probs_nb_cur[l_plus] = 0.0, 0.0

                    if new_char == last_char:
                        probs_nb_cur[l_plus] += prob[c] * probs_b_prev[l]
                        probs_nb_cur[l] += prob[c] * probs_nb_prev[l]
                    elif new_char == ' ':
                        if (ext_scoring_func is None) or (len(l) == 1):
                            score = 1.0
                        else:
                            prefix = l[1:]
                            score = ext_scoring_func(prefix)
                        probs_nb_cur[l_plus] += score * prob[c] * (
                            probs_b_prev[l] + probs_nb_prev[l])
                    else:
                        probs_nb_cur[l_plus] += prob[c] * (
                            probs_b_prev[l] + probs_nb_prev[l])
                    # add l_plus into prefix_set_next
                    prefix_set_next[l_plus] = probs_nb_cur[
                        l_plus] + probs_b_cur[l_plus]
            # add l into prefix_set_next
            prefix_set_next[l] = probs_b_cur[l] + probs_nb_cur[l]
        # update probs
        probs_b_prev, probs_nb_prev = probs_b_cur, probs_nb_cur

        ## store top beam_size prefixes
        prefix_set_prev = sorted(
            prefix_set_next.items(), key=lambda entry: entry[1], reverse=True)
        if beam_size < len(prefix_set_prev):
            prefix_set_prev = prefix_set_prev[:beam_size]
        prefix_set_prev = dict(prefix_set_prev)

    beam_result = []
    for (seq, prob) in prefix_set_prev.items():
        if prob > 0.0:
            result = seq[1:]
            log_prob = np.log(prob)
            beam_result.append([log_prob, result])

    ## output top beam_size decoding results
    beam_result = sorted(beam_result, key=lambda entry: entry[0], reverse=True)
    return beam_result
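

def _example_beam_search_decode():
    """A usage sketch for ctc_beam_search_decoder. Illustrative only: the
    probabilities are made up and no external scorer is plugged in.
    """
    vocabulary = ['a', 'b', ' ']
    # each row is [P(blank), P('a'), P('b'), P(' ')] since blank_id=0
    probs_seq = [
        [0.1, 0.6, 0.2, 0.1],
        [0.3, 0.1, 0.5, 0.1],
    ]
    beam_results = ctc_beam_search_decoder(
        probs_seq, beam_size=4, vocabulary=vocabulary, blank_id=0)
    # beam_results holds [log probability, sentence] pairs in descending
    # order of probability; the best hypothesis here is 'ab'
    best_log_prob, best_sentence = beam_results[0]
    return best_sentence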


def ctc_beam_search_decoder_nproc(probs_split,
                                  beam_size,
                                  vocabulary,
                                  ext_scoring_func=None,
                                  blank_id=0,
                                  num_processes=None):
    '''
    Beam search decoder using multiple processes.

    :param probs_split: 3-D list with length equal to the batch size, each
                        element being a 2-D list of probabilities that can be
                        consumed by ctc_beam_search_decoder.
    :type probs_split: 3-D list
    :param beam_size: Width for beam search.
    :type beam_size: int
    :param vocabulary: Vocabulary list.
    :type vocabulary: list
    :param ext_scoring_func: Externally defined scoring function for
                             partially decoded sentences, e.g. word count
                             or language model.
    :type ext_scoring_func: function
    :param blank_id: id of blank, default 0.
    :type blank_id: int
    :param num_processes: Number of processes; default None, which uses the
                          number of CPUs.
    :type num_processes: int
    :return: List of beam search results, one per element of probs_split.
    :rtype: list

    '''

    if num_processes is None:
        num_processes = multiprocessing.cpu_count()
    if not num_processes > 0:
        raise ValueError("Number of processes must be positive!")

    pool = multiprocessing.Pool(processes=num_processes)
    results = []
    for probs_list in probs_split:
        args = (probs_list, beam_size, vocabulary, ext_scoring_func, blank_id)
        results.append(pool.apply_async(ctc_beam_search_decoder, args))

    pool.close()
    pool.join()
    beam_search_results = []
    for result in results:
        beam_search_results.append(result.get())
    return beam_search_results
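

def _example_beam_search_decode_nproc():
    """A usage sketch for ctc_beam_search_decoder_nproc. Illustrative only:
    the toy batch is made up. Note that any ext_scoring_func passed here
    must be picklable to cross process boundaries.
    """
    vocabulary = ['a', 'b', ' ']
    # a batch of two utterances, each a 2-D probs_seq with blank at index 0
    probs_split = [
        [[0.1, 0.6, 0.2, 0.1], [0.3, 0.1, 0.5, 0.1]],
        [[0.2, 0.2, 0.5, 0.1], [0.1, 0.7, 0.1, 0.1]],
    ]
    batch_results = ctc_beam_search_decoder_nproc(
        probs_split, beam_size=4, vocabulary=vocabulary, blank_id=0)
    # each element is a list of [log probability, sentence] pairs
    return [results[0][1] for results in batch_results]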