"""
    CTC-like decoder utilities.
"""

from itertools import groupby
import numpy as np
import copy
import kenlm
import os


def ctc_best_path_decode(probs_seq, vocabulary):
    """
    Best path decoding, also called argmax decoding or greedy decoding.
    The path consisting of the most probable tokens is further post-processed
    to remove consecutive repetitions and all blanks.

    :param probs_seq: 2-D list of probabilities over the vocabulary for each
                      character. Each element is a list of float probabilities
                      for one character.
    :type probs_seq: list
    :param vocabulary: Vocabulary list.
    :type vocabulary: list
    :return: Decoding result string.
    :rtype: basestring
    """
    # dimension verification
    for probs in probs_seq:
        if not len(probs) == len(vocabulary) + 1:
            raise ValueError("probs_seq dimension mismatchedd with vocabulary")
    # argmax to get the best index for each time step
    max_index_list = list(np.array(probs_seq).argmax(axis=1))
    # remove consecutive duplicate indexes
    index_list = [index_group[0] for index_group in groupby(max_index_list)]
    # remove blank indexes
    blank_index = len(vocabulary)
    index_list = [index for index in index_list if index != blank_index]
    # convert index list to string
    return ''.join([vocabulary[index] for index in index_list])
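
# A minimal usage sketch of ctc_best_path_decode (illustrative only; the
# two-symbol vocabulary and probabilities below are made up, with the blank
# implicitly at the last index):
#
#     vocabulary = ['a', 'b']
#     probs_seq = [[0.1, 0.1, 0.8],   # argmax -> blank
#                  [0.7, 0.2, 0.1],   # argmax -> 'a'
#                  [0.6, 0.3, 0.1],   # argmax -> 'a' (repetition, collapsed)
#                  [0.1, 0.8, 0.1]]   # argmax -> 'b'
#     ctc_best_path_decode(probs_seq, vocabulary)   # returns 'ab'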


class Scorer(object):
    """
    Externally defined scorer to evaluate a sentence in beam search decoding,
    consisting of a language model and a word count term.

    :param alpha: Weight (exponent) applied to the language model score.
    :type alpha: float
    :param beta: Weight (exponent) applied to the word count.
    :type beta: float
    :param model_path: Path to the language model file.
    :type model_path: basestring
    """

    def __init__(self, alpha, beta, model_path):
        self._alpha = alpha
        self._beta = beta
        if not os.path.isfile(model_path):
            raise IOError("Invaid language model path: %s" % model_path)
        self._language_model = kenlm.LanguageModel(model_path)

    # n-gram language model scoring
    def language_model_score(self, sentence):
        # log10 conditional probability of the last word
        log_cond_prob = list(
            self._language_model.full_scores(sentence, eos=False))[-1][0]
        return np.power(10, log_cond_prob)

    # word insertion term
    def word_count(self, sentence):
        words = sentence.strip().split(' ')
        return len(words)

    # execute evaluation
    def evaluate(self, sentence):
        lm = self.language_model_score(sentence)
        word_cnt = self.word_count(sentence)
        score = np.power(lm, self._alpha) \
                * np.power(word_cnt, self._beta)
        return score
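
# A minimal usage sketch of Scorer (illustrative only; the alpha/beta values
# and the 'my_lm.klm' path are placeholders, and a real KenLM binary model
# file must exist at that path):
#
#     ext_scorer = Scorer(alpha=0.26, beta=0.1, model_path='my_lm.klm')
#     ext_scorer.evaluate('hello world')
#     # -> (conditional LM probability of the last word) ** 0.26 * 2 ** 0.1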


def ctc_beam_search_decoder(probs_seq,
                            beam_size,
                            vocabulary,
                            ext_scoring_func=None,
                            blank_id=0):
    '''
    Beam search decoder for a CTC-trained network. It searches with a beam of
    width beam_size and returns the beam_size most probable label sequences,
    ordered by probability. The implementation is based on Prefix Beam Search
    (https://arxiv.org/abs/1408.2873); parts left unclear in the paper have
    been redesigned and still need to be verified.

    :param probs_seq: 2-D list with length num_time_steps, each element
                      is a list of normalized probabilities over vocabulary
                      and blank for one time step.
    :type probs_seq: 2-D list
    :param beam_size: Width for beam search.
    :type beam_size: int
    :param vocabulary: Vocabulary list.
    :type vocabulary: list
    :param ext_scoring_func: Externally defined scoring function for a
                            partially decoded sentence, e.g. word count
                            or language model.
    :type ext_scoring_func: function
    :param blank_id: ID of the blank symbol, default 0.
    :type blank_id: int
    :return: List of decoding results, each a pair of log probability and
             result string.
    :rtype: list

    '''
    # dimension check
    for prob_list in probs_seq:
        if not len(prob_list) == len(vocabulary) + 1:
            raise ValueError("probs dimension mismatchedd with vocabulary")
    num_time_steps = len(probs_seq)

    # blank_id check
    probs_dim = len(probs_seq[0])
    if not blank_id < probs_dim:
        raise ValueError("blank_id shouldn't be greater than probs dimension")

    ## initialize
    # the set containing selected prefixes
    prefix_set_prev = {'\t': 1.0}
    probs_b_prev, probs_nb_prev = {'\t': 1.0}, {'\t': 0.0}

    ## extend prefix in loop
    for time_step in range(num_time_steps):
        # the set containing candidate prefixes
        prefix_set_next = {}
        probs_b_cur, probs_nb_cur = {}, {}
        for l in prefix_set_prev:
            prob = probs_seq[time_step]
            if l not in prefix_set_next:
                probs_b_cur[l], probs_nb_cur[l] = 0.0, 0.0

            # extend prefix by traversing the vocabulary
            for c in range(0, probs_dim):
                if c == blank_id:
                    probs_b_cur[l] += prob[c] * (
                        probs_b_prev[l] + probs_nb_prev[l])
                else:
                    last_char = l[-1]
                    new_char = vocabulary[c]
                    l_plus = l + new_char
                    if l_plus not in prefix_set_next:
                        probs_b_cur[l_plus], probs_nb_cur[l_plus] = 0.0, 0.0

                    if new_char == last_char:
                        probs_nb_cur[l_plus] += prob[c] * probs_b_prev[l]
                        probs_nb_cur[l] += prob[c] * probs_nb_prev[l]
                    elif new_char == ' ':
                        if (ext_scoring_func is None) or (len(l) == 1):
                            score = 1.0
                        else:
                            prefix = l[1:]
                            score = ext_scoring_func(prefix)
                        probs_nb_cur[l_plus] += score * prob[c] * (
                            probs_b_prev[l] + probs_nb_prev[l])
                    else:
                        probs_nb_cur[l_plus] += prob[c] * (
                            probs_b_prev[l] + probs_nb_prev[l])
                    # add l_plus into prefix_set_next
                    prefix_set_next[l_plus] = probs_nb_cur[
                        l_plus] + probs_b_cur[l_plus]
            # add l into prefix_set_next
            prefix_set_next[l] = probs_b_cur[l] + probs_nb_cur[l]
        # update probs
        probs_b_prev, probs_nb_prev = probs_b_cur, probs_nb_cur

        ## store top beam_size prefixes
        prefix_set_prev = sorted(
            prefix_set_next.items(), key=lambda asd: asd[1], reverse=True)
        if beam_size < len(prefix_set_prev):
            prefix_set_prev = prefix_set_prev[:beam_size]
        prefix_set_prev = dict(prefix_set_prev)

    beam_result = []
    for (seq, prob) in prefix_set_prev.items():
        if prob > 0.0:
            result = seq[1:]
            log_prob = np.log(prob)
            beam_result.append([log_prob, result])

    ## output top beam_size decoding results
    beam_result = sorted(beam_result, key=lambda asd: asd[0], reverse=True)
    return beam_result
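
# An end-to-end usage sketch combining the beam search decoder with the KenLM
# scorer (illustrative only; the vocabulary, probabilities, beam size and
# 'my_lm.klm' path are made-up placeholders). The scorer's evaluate method is
# passed as ext_scoring_func, so it is applied whenever a prefix is extended by
# a space; blank_id is set to the last index to match the probs layout:
#
#     vocabulary = [' ', 'a', 'b']
#     probs_seq = [[0.2, 0.3, 0.3, 0.2]] * 4   # 4 time steps, blank at index 3
#     ext_scorer = Scorer(alpha=0.26, beta=0.1, model_path='my_lm.klm')
#     beam_results = ctc_beam_search_decoder(
#         probs_seq,
#         beam_size=20,
#         vocabulary=vocabulary,
#         ext_scoring_func=ext_scorer.evaluate,
#         blank_id=3)
#     top_log_prob, top_sentence = beam_results[0]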